Eliot0110 committed
Commit 96512ae · 1 Parent(s): 3589840

fix: greedy decoding

Files changed (1):
  1. modules/ai_model.py  +24 -11
modules/ai_model.py CHANGED
@@ -143,43 +143,56 @@ class AIModel:
     def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str, temperature: float = 0.7) -> str:
 
         try:
+            # Truncate over-long prompts
             if len(prompt) > 500:
                 prompt = prompt[:500] + "..."
-
 
+            # Prepare the inputs (image or text)
             if input_type == "image" and isinstance(formatted_input, Image.Image):
-
                 image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
                 if image_token not in prompt:
-                    prompt = f"{image_token}\n{prompt}"
-
+                    prompt = f"{image_token}\\n{prompt}"
                 inputs = self.processor(
                     text=prompt,
                     images=formatted_input,
                     return_tensors="pt"
                 ).to(self.model.device, dtype=torch.bfloat16)
             else:
-
                 inputs = self.processor(
                     text=prompt,
                     return_tensors="pt"
                 ).to(self.model.device, dtype=torch.bfloat16)
 
+            # Truncate over-long token sequences
             if hasattr(inputs, 'input_ids') and inputs.input_ids.shape[-1] > 512:
                 log.warning(f"⚠️ Truncating over-long input: {inputs.input_ids.shape[-1]} -> 512")
                 inputs.input_ids = inputs.input_ids[:, :512]
                 if hasattr(inputs, 'attention_mask'):
                     inputs.attention_mask = inputs.attention_mask[:, :512]
 
+            # --- This is the key change ---
             with torch.inference_mode():
+                generation_args = {
+                    "max_new_tokens": 256,
+                    "pad_token_id": self.processor.tokenizer.eos_token_id,
+                    "use_cache": True
+                }
+
+                # If temperature is near 0, use greedy decoding (for deterministic tasks such as classification)
+                if temperature < 1e-6:  # compare floats against a small epsilon
+                    log.info("▶️ Using greedy decoding (do_sample=False) for deterministic output.")
+                    generation_args["do_sample"] = False
+                # Otherwise, use sampling (for creative generation tasks)
+                else:
+                    log.info(f"▶️ Using sampling (do_sample=True), temperature={temperature}.")
+                    generation_args["do_sample"] = True
+                    generation_args["temperature"] = temperature
+                    generation_args["top_p"] = 0.9  # top_p only matters when sampling
+
+                # Call generate with the assembled argument dict
                 outputs = self.model.generate(
                     **inputs,
-                    max_new_tokens=256,
-                    do_sample=True,
-                    temperature=temperature,
-                    top_p=0.9,
-                    pad_token_id=self.processor.tokenizer.eos_token_id,
-                    use_cache=True
+                    **generation_args
                 )
 
             decoded = self.processor.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
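
For context, the decoding dispatch this commit introduces can be reproduced in isolation. The sketch below is illustrative only: the helper name build_generation_args and the bare eos_token_id parameter are assumptions, not part of the repository. It mirrors the conditional in the diff: greedy decoding when temperature is near zero, nucleus sampling otherwise.

    import logging

    log = logging.getLogger(__name__)

    def build_generation_args(temperature: float, eos_token_id: int,
                              max_new_tokens: int = 256) -> dict:
        # Assemble kwargs for model.generate(), mirroring the commit's dispatch:
        # greedy decoding when temperature is (near) zero, sampling otherwise.
        # (Hypothetical helper, not part of modules/ai_model.py.)
        args = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": eos_token_id,
            "use_cache": True,
        }
        if temperature < 1e-6:  # float comparison against a small epsilon
            log.info("Using greedy decoding (do_sample=False) for deterministic output.")
            args["do_sample"] = False
        else:
            log.info("Using sampling (do_sample=True), temperature=%s.", temperature)
            args["do_sample"] = True
            args["temperature"] = temperature
            args["top_p"] = 0.9  # top_p only applies when sampling
        return args

    # temperature=0.0 -> deterministic decoding (classification-style tasks)
    assert build_generation_args(0.0, eos_token_id=2)["do_sample"] is False
    # temperature=0.7 -> stochastic decoding (creative generation)
    assert build_generation_args(0.7, eos_token_id=2)["top_p"] == 0.9

Inside run_inference the same dict is spread into self.model.generate(**inputs, **generation_args), so a caller selects greedy decoding simply by passing temperature=0.0 and keeps the default 0.7 for open-ended generation.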