Ksjsjjdj commited on
Commit
a8d60da
·
verified ·
1 Parent(s): 29fae62

Update config.py

Browse files
Files changed (1) hide show
  1. config.py +84 -84
config.py CHANGED
@@ -1,84 +1,84 @@
1
- from pydantic import BaseModel, Field
2
- from typing import List, Optional
3
- from typing import List, Optional, Union, Any
4
-
5
- import sys
6
-
7
-
8
- from pydantic_settings import BaseSettings
9
-
10
-
11
- class CliConfig(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
12
- CONFIG_FILE: str = Field("./config.local.yaml", description="Config file path")
13
-
14
-
15
- CLI_CONFIG = CliConfig()
16
-
17
-
18
- class SamplerConfig(BaseModel):
19
- """Default sampler configuration for each model."""
20
-
21
- max_tokens: int = Field(512, description="Maximum number of tokens to generate.")
22
- temperature: float = Field(1.0, description="Sampling temperature.")
23
- top_p: float = Field(0.3, description="Top-p sampling threshold.")
24
- presence_penalty: float = Field(0.5, description="Presence penalty.")
25
- count_penalty: float = Field(0.5, description="Count penalty.")
26
- penalty_decay: float = Field(0.996, description="Penalty decay factor.")
27
- stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
28
- stop_tokens: List[int] = Field([0], description="List of stop tokens.")
29
-
30
-
31
- class ModelConfig(BaseModel):
32
- """Configuration for each individual model."""
33
-
34
- SERVICE_NAME: str = Field(..., description="Service name of the model.")
35
-
36
- MODEL_FILE_PATH: Optional[str] = Field(None, description="Model file path.")
37
-
38
- DOWNLOAD_MODEL_FILE_NAME: Optional[str] = Field(
39
- None, description="Model name, should end with .pth"
40
- )
41
- DOWNLOAD_MODEL_REPO_ID: Optional[str] = Field(
42
- None, description="Model repository ID on Hugging Face Hub."
43
- )
44
- DOWNLOAD_MODEL_DIR: Optional[str] = Field(
45
- None, description="Directory to download the model to."
46
- )
47
-
48
- REASONING: bool = Field(
49
- False, description="Whether reasoning is enabled for this model."
50
- )
51
-
52
- DEFAULT_CHAT: bool = Field(False, description="Whether this model is the default chat model.")
53
- DEFAULT_REASONING: bool = Field(False, description="Whether this model is the default reasoning model.")
54
- DEFAULT_SAMPLER: SamplerConfig = Field(
55
- SamplerConfig(), description="Default sampler configuration for this model."
56
- )
57
- VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
58
-
59
-
60
- class RootConfig(BaseModel):
61
- """Root configuration for the RWKV service."""
62
-
63
- HOST: Optional[str] = Field(
64
- "127.0.0.1", description="Host IP address to bind to."
65
- ) # 注释掉可选的HOST和PORT
66
- PORT: Optional[int] = Field(
67
- 8000, description="Port number to listen on."
68
- ) # 因为YAML示例中被注释掉了
69
- STRATEGY: str = Field(
70
- "cpu", description="Strategy for model execution (e.g., 'cuda fp16')."
71
- )
72
- RWKV_CUDA_ON: bool = Field(False, description="Whether to enable RWKV CUDA kernel.")
73
- CHUNK_LEN: int = Field(256, description="Chunk length for processing.")
74
- MODELS: List[ModelConfig] = Field(..., description="List of model configurations.")
75
-
76
-
77
- import yaml
78
-
79
- try:
80
- with open(CLI_CONFIG.CONFIG_FILE, "r", encoding="utf-8") as f:
81
- CONFIG = RootConfig.model_validate(yaml.safe_load(f.read()))
82
- except Exception as e:
83
- print(f"Pydantic Model Validation Failed: {e}")
84
- sys.exit(0)
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import List, Optional
3
+ from typing import List, Optional, Union, Any
4
+
5
+ import sys
6
+
7
+
8
+ from pydantic_settings import BaseSettings
9
+
10
+
11
+ class CliConfig(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
12
+ CONFIG_FILE: str = Field("./config.local.yaml", description="Config file path")
13
+
14
+
15
+ CLI_CONFIG = CliConfig()
16
+
17
+
18
+ class SamplerConfig(BaseModel):
19
+ """Default sampler configuration for each model."""
20
+
21
+ max_tokens: int = Field(512, description="Maximum number of tokens to generate.")
22
+ temperature: float = Field(0.1, description="Sampling temperature.")
23
+ top_p: float = Field(0.3, description="Top-p sampling threshold.")
24
+ presence_penalty: float = Field(0.5, description="Presence penalty.")
25
+ count_penalty: float = Field(0.5, description="Count penalty.")
26
+ penalty_decay: float = Field(0.996, description="Penalty decay factor.")
27
+ stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
28
+ stop_tokens: List[int] = Field([0], description="List of stop tokens.")
29
+
30
+
31
+ class ModelConfig(BaseModel):
32
+ """Configuration for each individual model."""
33
+
34
+ SERVICE_NAME: str = Field(..., description="Service name of the model.")
35
+
36
+ MODEL_FILE_PATH: Optional[str] = Field(None, description="Model file path.")
37
+
38
+ DOWNLOAD_MODEL_FILE_NAME: Optional[str] = Field(
39
+ None, description="Model name, should end with .pth"
40
+ )
41
+ DOWNLOAD_MODEL_REPO_ID: Optional[str] = Field(
42
+ None, description="Model repository ID on Hugging Face Hub."
43
+ )
44
+ DOWNLOAD_MODEL_DIR: Optional[str] = Field(
45
+ None, description="Directory to download the model to."
46
+ )
47
+
48
+ REASONING: bool = Field(
49
+ False, description="Whether reasoning is enabled for this model."
50
+ )
51
+
52
+ DEFAULT_CHAT: bool = Field(False, description="Whether this model is the default chat model.")
53
+ DEFAULT_REASONING: bool = Field(False, description="Whether this model is the default reasoning model.")
54
+ DEFAULT_SAMPLER: SamplerConfig = Field(
55
+ SamplerConfig(), description="Default sampler configuration for this model."
56
+ )
57
+ VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
58
+
59
+
60
+ class RootConfig(BaseModel):
61
+ """Root configuration for the RWKV service."""
62
+
63
+ HOST: Optional[str] = Field(
64
+ "127.0.0.1", description="Host IP address to bind to."
65
+ ) # 注释掉可选的HOST和PORT
66
+ PORT: Optional[int] = Field(
67
+ 8000, description="Port number to listen on."
68
+ ) # 因为YAML示例中被注释掉了
69
+ STRATEGY: str = Field(
70
+ "cpu", description="Strategy for model execution (e.g., 'cuda fp16')."
71
+ )
72
+ RWKV_CUDA_ON: bool = Field(False, description="Whether to enable RWKV CUDA kernel.")
73
+ CHUNK_LEN: int = Field(256, description="Chunk length for processing.")
74
+ MODELS: List[ModelConfig] = Field(..., description="List of model configurations.")
75
+
76
+
77
+ import yaml
78
+
79
+ try:
80
+ with open(CLI_CONFIG.CONFIG_FILE, "r", encoding="utf-8") as f:
81
+ CONFIG = RootConfig.model_validate(yaml.safe_load(f.read()))
82
+ except Exception as e:
83
+ print(f"Pydantic Model Validation Failed: {e}")
84
+ sys.exit(0)