fix: greedy decoding

modules/ai_model.py  CHANGED  (+24 -11)

@@ -143,43 +143,56 @@ class AIModel:
Before (old lines 143-185; removed lines marked "-"):

     def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str,temperature: float = 0.7) -> str:

         try:
             if len(prompt) > 500:
                 prompt = prompt[:500] + "..."
-

             if input_type == "image" and isinstance(formatted_input, Image.Image):
-
                 image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
                 if image_token not in prompt:
-                    prompt = f"{image_token}\n{prompt}"
-
                 inputs = self.processor(
                     text=prompt,
                     images=formatted_input,
                     return_tensors="pt"
                 ).to(self.model.device, dtype=torch.bfloat16)
             else:
-
                 inputs = self.processor(
                     text=prompt,
                     return_tensors="pt"
                 ).to(self.model.device, dtype=torch.bfloat16)

             if hasattr(inputs, 'input_ids') and inputs.input_ids.shape[-1] > 512:
                 log.warning(f"⚠️ Truncating overly long input: {inputs.input_ids.shape[-1]} -> 512")
                 inputs.input_ids = inputs.input_ids[:, :512]
                 if hasattr(inputs, 'attention_mask'):
                     inputs.attention_mask = inputs.attention_mask[:, :512]

             with torch.inference_mode():
                 outputs = self.model.generate(
                     **inputs,
-
-                    do_sample=True,
-                    temperature=temperature,
-                    top_p=0.9,
-                    pad_token_id=self.processor.tokenizer.eos_token_id,
-                    use_cache=True
                 )

             decoded = self.processor.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
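The problem with the old version: sampling is always on, so a caller who needs repeatable output (for example for classification) has no way to get it; passing temperature=0.0 makes recent transformers releases raise a ValueError, and any small positive temperature still samples. A minimal sketch of that failure mode, assuming a generic causal LM (the gpt2 checkpoint and the surrounding setup are placeholders, not part of this Space):

# Sketch only; assumed setup, not code from this repo.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Sentiment of 'great movie', one word:", return_tensors="pt")

try:
    # Old code path: do_sample=True is hard-coded, so temperature=0.0 is invalid;
    # recent transformers versions reject it with a ValueError at generate time.
    model.generate(**inputs, do_sample=True, temperature=0.0, top_p=0.9, max_new_tokens=8)
except ValueError as err:
    print(f"sampling with temperature=0.0 rejected: {err}")

# Greedy decoding, which the fix selects when temperature ~ 0, is deterministic.
out = model.generate(**inputs, do_sample=False, max_new_tokens=8,
                     pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(out[0], skip_special_tokens=True))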
After (new lines 143-198; added lines marked "+"):

     def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str,temperature: float = 0.7) -> str:

         try:
+            # Truncate an overly long prompt
             if len(prompt) > 500:
                 prompt = prompt[:500] + "..."

+            # Prepare the inputs (image or text)
             if input_type == "image" and isinstance(formatted_input, Image.Image):
                 image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
                 if image_token not in prompt:
+                    prompt = f"{image_token}\n{prompt}"

                 inputs = self.processor(
                     text=prompt,
                     images=formatted_input,
                     return_tensors="pt"
                 ).to(self.model.device, dtype=torch.bfloat16)
             else:
                 inputs = self.processor(
                     text=prompt,
                     return_tensors="pt"
                 ).to(self.model.device, dtype=torch.bfloat16)

+            # Truncate an overly long token sequence
             if hasattr(inputs, 'input_ids') and inputs.input_ids.shape[-1] > 512:
                 log.warning(f"⚠️ Truncating overly long input: {inputs.input_ids.shape[-1]} -> 512")
                 inputs.input_ids = inputs.input_ids[:, :512]
                 if hasattr(inputs, 'attention_mask'):
                     inputs.attention_mask = inputs.attention_mask[:, :512]

+            # --- This is the key change ---
             with torch.inference_mode():
+                generation_args = {
+                    "max_new_tokens": 256,
+                    "pad_token_id": self.processor.tokenizer.eos_token_id,
+                    "use_cache": True
+                }
+
+                # If temperature is close to 0, use greedy decoding (for deterministic tasks such as classification)
+                if temperature < 1e-6:  # compare against a tiny epsilon instead of float equality
+                    log.info("▶️ Using greedy decoding (do_sample=False) for deterministic output.")
+                    generation_args["do_sample"] = False
+                # Otherwise, use sampling (for creative generation tasks)
+                else:
+                    log.info(f"▶️ Using sampling (do_sample=True) with temperature={temperature}.")
+                    generation_args["do_sample"] = True
+                    generation_args["temperature"] = temperature
+                    generation_args["top_p"] = 0.9  # top_p only has an effect when sampling
+
+                # Call generate with the assembled argument dict
                 outputs = self.model.generate(
                     **inputs,
+                    **generation_args
                 )

             decoded = self.processor.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
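After the change, the decoding mode is selected entirely by the temperature argument. A rough usage sketch (how AIModel is constructed is outside this diff, so the setup below is an assumption):

# Rough usage sketch; the AIModel constructor and file paths are assumptions.
from PIL import Image
from modules.ai_model import AIModel

model = AIModel()  # assumed constructor; not shown in this diff

# temperature ~ 0 -> greedy decoding: repeatable answers for classification-style prompts
label = model.run_inference(
    input_type="text",
    formatted_input="",          # unused on the text path in the code above
    prompt="Answer with one word, positive or negative: I loved this film.",
    temperature=0.0,
)

# temperature > 0 -> sampling with top_p=0.9: varied output for open-ended prompts
caption = model.run_inference(
    input_type="image",
    formatted_input=Image.open("photo.jpg"),   # any PIL image
    prompt="Describe this image in one sentence.",
    temperature=0.7,
)
print(label)
print(caption)

Comparing against 1e-6 rather than testing temperature == 0.0 sidesteps float-equality pitfalls, while any meaningfully positive temperature keeps the previous sampling behaviour (do_sample=True, top_p=0.9).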