Commit 9b9e15b · sparkleman committed · Parent(s): d761c6a

UPDATE: support stop tokens

Files changed:
- Dockerfile +2 -1
- app.py +26 -6
- config.production.yaml +4 -0
- config.py +3 -2
- pyproject.toml +1 -0
- uv.lock +2 -0
Dockerfile
CHANGED
@@ -15,7 +15,8 @@ RUN ["cargo", "install", "wasm-pack"]
 WORKDIR /app
 ENV PATH=/root/.cargo/bin:$PATH
 RUN npm install -g pnpm
-RUN pnpm install
+RUN pnpm install
+RUN pnpm run build --mode target-rwkv-hf-space
 
 FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 AS Backend
 
app.py
CHANGED
@@ -3,6 +3,7 @@ from config import CONFIG, ModelConfig
 import os, copy, types, gc, sys, re, time, collections, asyncio
 from huggingface_hub import hf_hub_download
 from loguru import logger
+from rich import print
 
 from snowflake import SnowflakeGenerator
 
@@ -92,6 +93,8 @@ for model_config in CONFIG.MODELS:
     else:
         DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
 
+    print(model_config.DEFAULT_SAMPLER)
+
     MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
     MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
     MODEL_STORAGE[model_config.SERVICE_NAME].model = tmp_model
@@ -119,6 +122,7 @@ class ChatCompletionRequest(BaseModel):
     state_name: Optional[str] = Field(default=None)
     include_usage: Optional[bool] = Field(default=False)
     stop: Optional[list[str]] = Field(["\n\n"])
+    stop_tokens: Optional[list[int]] = Field([0])
 
     @model_validator(mode="before")
     @classmethod
@@ -169,7 +173,7 @@ async def runPrefill(
 def generate(
     request: ChatCompletionRequest,
     out,
-    model_tokens,
+    model_tokens: List[int],
     model_state,
     stops=["\n\n"],
     max_tokens=2048,
@@ -184,7 +188,7 @@ def generate(
     )  # stop generation whenever you see any token here
 
     occurrence = {}
-    out_tokens = []
+    out_tokens: List[int] = []
     out_last = 0
 
     output_cache = collections.deque(maxlen=5)
@@ -192,7 +196,7 @@ def generate(
     for i in range(max_tokens):
         for n in occurrence:
             out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
-        out[0] -= 1e10  # disable END_OF_TEXT
+        # out[0] -= 1e10  # disable END_OF_TEXT
 
         token = MODEL_STORAGE[request.model].pipeline.sample_logits(
             out, temperature=args.temperature, top_p=args.top_p
@@ -201,9 +205,21 @@ def generate(
         out, model_state = MODEL_STORAGE[request.model].model.forward(
             [token], model_state
         )
-        model_tokens += [token]
+        model_tokens.append(token)
+
+        out_tokens.append(token)
+
+        if token in request.stop_tokens:
+            yield {
+                "content": "",
+                "tokens": out_tokens[out_last:],
+                "finish_reason": "stop",
+                "state": model_state,
+            }
 
-        out_tokens += [token]
+            del out
+            gc.collect()
+            return
 
         for xxx in occurrence:
             occurrence[xxx] *= request.penalty_decay
@@ -260,6 +276,7 @@ async def chatResponse(
         if request.prompt == None
         else request.prompt.strip()
     )
+    logger.info(f"[REQ] {completionId} - prompt - {prompt}")
 
     out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
 
@@ -343,6 +360,8 @@ async def chatResponseStream(
         else request.prompt.strip()
     )
 
+    # logger.info(f"[REQ] {completionId} - prompt - {prompt}")
+
     out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
 
     prefillTime = time.time()
@@ -465,7 +484,7 @@
                 streamConfig["fullTextCursor"] = len(fullText)
 
             markEnd = fullText.find(">", streamConfig["fullTextCursor"])
-            if streamConfig["isChecking"] and markEnd != -1:
+            if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None:
                 streamConfig["isChecking"] = False
 
                 if (
@@ -626,6 +645,7 @@ async def chat_completions(request: ChatCompletionRequest):
 
     return r
 
+
 app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
 
 if __name__ == "__main__":
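The heart of the commit is the new early-exit in generate(): END_OF_TEXT (token 0) is no longer masked out of the logits, and any sampled token listed in request.stop_tokens now terminates generation with finish_reason "stop". Below is a minimal, runnable sketch of that control flow; sample_next() is a hypothetical stand-in for the real pipeline.sample_logits / model.forward round trip, and the chunk dict is simplified (no model_state).

import random
from typing import Iterator, List

END_OF_TEXT = 0  # token ID 0; previously suppressed via `out[0] -= 1e10`

def sample_next() -> int:
    # Hypothetical sampler standing in for pipeline.sample_logits + model.forward.
    return random.choice([END_OF_TEXT, 11, 33, 261])

def generate_sketch(stop_tokens: List[int], max_tokens: int = 16) -> Iterator[dict]:
    out_tokens: List[int] = []
    for _ in range(max_tokens):
        token = sample_next()
        out_tokens.append(token)
        if token in stop_tokens:
            # Mirrors the new branch: emit one final, empty chunk whose
            # finish_reason tells the caller why the stream ended.
            yield {"content": "", "tokens": out_tokens, "finish_reason": "stop"}
            return
        yield {"content": f"<tok {token}>", "tokens": [token], "finish_reason": None}

for chunk in generate_sketch(stop_tokens=[END_OF_TEXT]):
    print(chunk)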
config.production.yaml
CHANGED
@@ -18,6 +18,8 @@ MODELS:
       penalty_decay: 0.996
       stop:
         - "\n\n"
+      stop_tokens:
+        - 0
   - SERVICE_NAME: "rwkv7-g1-0.1b-20250307-ctx4096"
     DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1-0.1b-20250307-ctx4096.pth"
     DOWNLOAD_MODEL_REPO_ID: "BlinkDL/rwkv7-g1"
@@ -32,3 +34,5 @@ MODELS:
       penalty_decay: 0.996
       stop:
         - "\n\n"
+      stop_tokens:
+        - 0
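Each model's sampler block now carries stop_tokens alongside stop. A quick way to sanity-check such a block against the updated SamplerConfig (shown in config.py below) is to parse it with PyYAML, assuming the remaining SamplerConfig fields all carry defaults as the visible ones do. PyYAML is not a project dependency, so this is illustration only:

import yaml  # assumption: PyYAML is available; the app itself loads config via pydantic-settings

from config import SamplerConfig  # the model extended in config.py below

BLOCK = """
penalty_decay: 0.996
stop:
  - "\\n\\n"
stop_tokens:
  - 0
"""

sampler = SamplerConfig(**yaml.safe_load(BLOCK))
print(sampler.stop_tokens)  # [0] -> token 0 ends generation for this model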
config.py
CHANGED
@@ -23,8 +23,9 @@ class SamplerConfig(BaseModel):
     top_p: float = Field(0.3, description="Top-p sampling threshold.")
     presence_penalty: float = Field(0.5, description="Presence penalty.")
     count_penalty: float = Field(0.5, description="Count penalty.")
-    penalty_decay: float = Field(0.
-    stop: List[str] = Field(
+    penalty_decay: float = Field(0.996, description="Penalty decay factor.")
+    stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
+    stop_tokens: List[int] = Field([0], description="List of stop tokens.")
 
 
 class ModelConfig(BaseModel):
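One detail worth noting about the new fields: both stop and stop_tokens use mutable list literals as Field defaults. That is safe under pydantic v2 (the project pins pydantic>=2.10.6) because defaults are deep-copied per instance. A trimmed sketch with just the two list fields:

from typing import List
from pydantic import BaseModel, Field

# Trimmed to the two list-valued fields; the real SamplerConfig has more.
class StopDefaults(BaseModel):
    stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
    stop_tokens: List[int] = Field([0], description="List of stop tokens.")

a, b = StopDefaults(), StopDefaults()
a.stop_tokens.append(261)
print(b.stop_tokens)  # [0] -- each instance gets its own copy of the default list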
pyproject.toml
CHANGED
@@ -13,6 +13,7 @@ dependencies = [
     "pydantic>=2.10.6",
     "pydantic-settings>=2.8.1",
     "pynvml>=12.0.0",
+    "rich>=13.9.4",
     "rwkv==0.8.28",
     "setuptools>=75.8.2",
     "snowflake-id>=1.0.2",
uv.lock
CHANGED
@@ -944,6 +944,7 @@ dependencies = [
     { name = "pydantic" },
     { name = "pydantic-settings" },
     { name = "pynvml" },
+    { name = "rich" },
     { name = "rwkv" },
     { name = "setuptools" },
     { name = "snowflake-id" },
@@ -971,6 +972,7 @@ requires-dist = [
     { name = "pydantic", specifier = ">=2.10.6" },
     { name = "pydantic-settings", specifier = ">=2.8.1" },
     { name = "pynvml", specifier = ">=12.0.0" },
+    { name = "rich", specifier = ">=13.9.4" },
     { name = "rwkv", specifier = "==0.8.28" },
     { name = "setuptools", specifier = ">=75.8.2" },
     { name = "snowflake-id", specifier = ">=1.0.2" },