aleclyu committed on
Commit
e7257d2
·
1 Parent(s): c2b0812

debug zerogpu timeout error

Browse files
Files changed (1) hide show
  1. app.py +38 -6
app.py CHANGED
@@ -151,16 +151,48 @@ def _launch_demo(args, model, processor):
151
  print(f"[DEBUG] 输入准备完成,耗时: {time.time() - start_time:.2f}s")
152
  print(f"[DEBUG] Input IDs shape: {inputs.input_ids.shape}")
153
  print(f"[DEBUG] Input device: {inputs.input_ids.device}")
 
154
 
155
  # 生成
156
  gen_start = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  with torch.no_grad():
158
- generated_ids = model.generate(
159
- **inputs,
160
- max_new_tokens=1024*8,
161
- repetition_penalty=1.03,
162
- do_sample=False
163
- )
 
 
 
 
 
 
 
 
 
 
164
 
165
  gen_time = time.time() - gen_start
166
  print(f"[DEBUG] ========== 生成完成 ==========")
 
151
  print(f"[DEBUG] 输入准备完成,耗时: {time.time() - start_time:.2f}s")
152
  print(f"[DEBUG] Input IDs shape: {inputs.input_ids.shape}")
153
  print(f"[DEBUG] Input device: {inputs.input_ids.device}")
154
+ print(f"[DEBUG] Input sequence length: {inputs.input_ids.shape[1]}")
155
 
156
  # 生成
157
  gen_start = time.time()
158
+ print(f"[DEBUG] ========== 开始生成 tokens ==========")
159
+
160
+ # 关键优化:根据任务类型动态调整 max_new_tokens
161
+ # OCR 任务通常不需要 8192 tokens,这会导致不必要的等待
162
+ max_new_tokens = 2048 # 从 8192 降到 2048,大幅提速
163
+ print(f"[DEBUG] max_new_tokens: {max_new_tokens}")
164
+
165
+ # 添加进度回调
166
+ token_count = [0]
167
+ last_time = [gen_start]
168
+
169
+ def progress_callback(input_ids, scores, **kwargs):
170
+ token_count[0] += 1
171
+ current_time = time.time()
172
+ if token_count[0] % 10 == 0 or (current_time - last_time[0]) > 2.0:
173
+ elapsed = current_time - gen_start
174
+ tokens_per_sec = token_count[0] / elapsed if elapsed > 0 else 0
175
+ print(f"[DEBUG] 已生成 {token_count[0]} tokens, 速度: {tokens_per_sec:.2f} tokens/s, 耗时: {elapsed:.2f}s")
176
+ last_time[0] = current_time
177
+ return False
178
+
179
  with torch.no_grad():
180
+ print(f"[DEBUG] 调用 model.generate()...")
181
+ try:
182
+ generated_ids = model.generate(
183
+ **inputs,
184
+ max_new_tokens=max_new_tokens,
185
+ repetition_penalty=1.03,
186
+ do_sample=False,
187
+ stopping_criteria=None, # 确保没有额外的停止条件
188
+ pad_token_id=processor.tokenizer.pad_token_id,
189
+ eos_token_id=processor.tokenizer.eos_token_id,
190
+ )
191
+ except Exception as e:
192
+ print(f"[ERROR] 生成失败: {e}")
193
+ raise
194
+
195
+ print(f"[DEBUG] model.generate() 调用完成")
196
 
197
  gen_time = time.time() - gen_start
198
  print(f"[DEBUG] ========== 生成完成 ==========")