SunDou committed on
Commit
8629ccf
·
verified ·
1 Parent(s): f0b48c5

Upload data1/util.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. data1/util.py +330 -0
data1/util.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # util.py
2
+ from loguru import logger
3
+ import os
4
+ import json
5
+
6
+ from datetime import datetime
7
+ from typing import List, Dict, Any, Optional
8
+ from pathlib import Path
9
+ from langchain_openai import ChatOpenAI
10
+ from langchain_core.output_parsers import JsonOutputParser
11
+ from langchain.output_parsers import OutputFixingParser
12
+ from pydantic import BaseModel
13
+ import asyncio
14
+ import re
15
+
16
+ # Global concurrency limit
17
+ CONCURRENCY_LIMIT = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENCY", "8")))
18
+
19
+ # LLM 调用的超时与重试配置(可通过环境变量覆盖)
20
+ LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT", "180")) # 单次 LLM 调用超时秒数(默认 3 分钟)
21
+ LLM_MAX_RETRIES = int(os.getenv("LLM_MAX_RETRIES", "1")) # 最多重试次数(默认 1 次)
22
+ LLM_RETRY_BACKOFF = float(os.getenv("LLM_RETRY_BACKOFF", "2.0")) # 重试等待时间(秒)
23
+
24
+ # Common code file extensions
25
+ CODE_EXTENSIONS = {
26
+ # Python
27
+ ".py", ".ipynb",
28
+ # C/C++
29
+ ".c", ".cpp", ".cc", ".cxx", ".h", ".hpp", ".hh",
30
+ # Fortran
31
+ ".f", ".f90", ".f95", ".for",
32
+ # Julia
33
+ ".jl",
34
+ # R
35
+ ".r", ".R",
36
+ # Java
37
+ ".java",
38
+ # MATLAB/Octave
39
+ ".m",
40
+ # Shell脚本
41
+ ".sh", ".bash",
42
+ ".rs", # Rust
43
+ ".go", # Go
44
+ # Markdown
45
+ ".md", ".markdown",
46
+ }
47
+
48
+
49
def init_logger(log_file: str, level: str = "INFO"):
    """Configure loguru with a colorized console sink and a daily-rotated file sink.

    Args:
        log_file: Path of the log file; parent directories are created if needed.
        level: Minimum log level applied to both sinks (default "INFO").

    Returns:
        The configured loguru ``logger`` instance.
    """
    # BUGFIX: os.path.dirname() returns "" for a bare filename, and
    # os.makedirs("", exist_ok=True) raises FileNotFoundError — only create
    # the directory when a parent path is actually present.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    logger.remove()  # drop loguru's default stderr handler before adding ours
    # Console output with color
    logger.add(
        sink=lambda msg: print(msg, end=""),
        level=level,
        colorize=True,
        format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
    )
    # File output, rotated once per day
    logger.add(
        Path(log_file),
        level=level,
        rotation="1 day",
        encoding="utf-8",
        format="{time:YYYY-MM-DD HH:mm:ss} [{level}] ({name}:{function}:{line}) {message}",
    )
    return logger
69
+
70
+
71
def log_api(log_file: str, data: Dict):
    """Append one API-call record to *log_file* as a single JSON line.

    Each record has the shape ``{"time": <UTC ISO-8601 ending in "Z">, "data": data}``.

    Args:
        log_file: Path of the JSONL log file; opened in append mode (must be writable).
        data: Arbitrary JSON-serializable payload describing the call.
    """
    # Local import keeps this fix self-contained; the module only imports `datetime`.
    from datetime import timezone

    # datetime.utcnow() is deprecated (Python 3.12+) and produces a naive object;
    # use an aware UTC timestamp but keep the original "...Z" suffix format.
    timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    with open(log_file, "a", encoding="utf-8") as f:
        record = {"time": timestamp, "data": data}
        # ensure_ascii=False keeps non-ASCII payloads (e.g. Chinese text) readable.
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
76
+
77
+
78
def extract_final_answer_from_reasoning(text: str, pydantic_object: Optional["type[BaseModel]"] = None) -> Dict:
    """
    Extract the final answer from reasoning-model output.

    For reasoning models like Qwen3, the response format is:
    <think>reasoning content</think>final result

    Args:
        text: Raw response text from the reasoning model.
        pydantic_object: Expected Pydantic model for structured output.
            Currently unused; kept for interface compatibility.

    Returns:
        Dict with ``relevant`` ("YES"/"NO", defaulting to "NO") and ``reason``
        (reasoning content or final text, capped at 1000 characters).
    """
    # Extract reasoning content from <think>...</think> tags.
    reasoning_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL | re.IGNORECASE)
    reasoning_content = reasoning_match.group(1).strip() if reasoning_match else ""

    # The final result is everything after the closing tag; if no tags were
    # found, treat the whole text as the final result.
    if reasoning_match:
        final_result = text[reasoning_match.end():].strip()
    else:
        final_result = text.strip()

    # First, try fenced ```json blocks in the final result.
    json_block_pattern = r'```(?:json)?\s*(\{.*?\})\s*```'
    for block in re.findall(json_block_pattern, final_result, re.DOTALL | re.IGNORECASE):
        try:
            parsed = json.loads(block)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict) and "relevant" in parsed:
            return {
                "relevant": str(parsed.get("relevant", "")).upper(),
                # BUGFIX: truncate whichever reason is used. Previously the
                # [:1000] slice bound only to parsed.get(...), so a long
                # reasoning_content escaped the cap enforced everywhere else.
                "reason": (reasoning_content or parsed.get("reason", ""))[:1000],
            }

    # Next, scan for bare JSON objects in the final result (handles nested braces).
    brace_count = 0
    start_pos = -1
    for i, char in enumerate(final_result):
        if char == '{':
            if brace_count == 0:
                start_pos = i
            brace_count += 1
        elif char == '}':
            if brace_count == 0:
                # BUGFIX: a stray '}' before any '{' used to drive the counter
                # negative and desynchronize all subsequent brace matching.
                continue
            brace_count -= 1
            if brace_count == 0 and start_pos >= 0:
                # Found a complete JSON object candidate.
                json_str = final_result[start_pos:i + 1]
                try:
                    parsed = json.loads(json_str)
                    if isinstance(parsed, dict) and "relevant" in parsed:
                        return {
                            "relevant": str(parsed.get("relevant", "")).upper(),
                            "reason": (reasoning_content or parsed.get("reason", ""))[:1000],
                        }
                except json.JSONDecodeError:
                    pass
                start_pos = -1

    # No JSON found: fall back to extracting a YES/NO token from the final result.
    relevant_patterns = [
        r'"relevant"\s*:\s*["\']?(YES|NO)["\']?',
        r'"relevant"\s*:\s*(YES|NO)',
        r'relevant\s*[:=]\s*["\']?(YES|NO)["\']?',
        r'answer\s*[:=]\s*["\']?(YES|NO)["\']?',
        r'final\s+answer\s*[:=]\s*["\']?(YES|NO)["\']?',
        r'\b(YES|NO)\b',  # last resort: any bare YES or NO in the final result
    ]

    relevant = None
    for pattern in relevant_patterns:
        match = re.search(pattern, final_result, re.IGNORECASE)
        if match:
            relevant = match.group(1).upper()
            break

    # Prefer the reasoning content as the reason, else the (truncated) final text.
    reason = reasoning_content if reasoning_content else final_result[:1000]

    return {
        "relevant": relevant or "NO",  # default to NO if nothing matched
        "reason": reason[:1000],  # enforce the length cap uniformly
    }
168
+
169
+
170
async def call_llm(
    messages: List[Dict[str, str]],
    model: str,
    base_url: str,
    api_key: str,
    pydantic_object: Optional[type[BaseModel]] = None,
    log_file: str = "workdir/calls_llm.jsonl",
    **kwargs,
) -> Optional[Dict]:
    """Asynchronous LLM call with langchain structured (JSON) output.

    Runs the request under the global CONCURRENCY_LIMIT semaphore, with a
    per-call timeout (LLM_TIMEOUT) and up to LLM_MAX_RETRIES retries. Waits
    LLM_RETRY_BACKOFF * attempt seconds between attempts (linear backoff).

    Args:
        messages: Chat messages as {"role": ..., "content": ...} dicts.
        model: Model name passed to ChatOpenAI.
        base_url: OpenAI-compatible endpoint URL.
        api_key: API key for the endpoint.
        pydantic_object: Optional Pydantic model describing the expected JSON.
        log_file: JSONL file where the full call (input/output/parsed) is logged.
        **kwargs: Extra arguments forwarded to ChatOpenAI. A "timeout" key is
            also popped and used as the asyncio.wait_for timeout — note the pop
            happens AFTER ChatOpenAI(**kwargs), so a caller-supplied timeout is
            forwarded to the client constructor as well.

    Returns:
        The parsed JSON dict, or None if every attempt timed out or failed.

    Raises:
        Exception: Re-raised when output parsing fails for a non-reasoning
            model, or when post-call logging/parsing itself errors out.
    """
    # region agent log
    # NOTE(review): this debug block writes API-key prefix/suffix fragments to
    # .cursor/debug.log — a credential-leak risk; presumably temporary
    # instrumentation that should be removed before production use.
    debug_log_path = Path(__file__).parent.parent.parent / ".cursor" / "debug.log"
    try:
        with open(debug_log_path, "a", encoding="utf-8") as f:
            log_entry = {
                "sessionId": "debug-session",
                "runId": "api-key-llm-call",
                "hypothesisId": "B",
                "location": "util.py:180",
                "message": "API key passed to LLM",
                "data": {
                    "base_url": base_url,
                    "model": model,
                    "api_key_length": len(api_key) if api_key else 0,
                    "api_key_prefix": api_key[:20] + "..." if api_key and len(api_key) > 20 else api_key,
                    "api_key_suffix": "..." + api_key[-10:] if api_key and len(api_key) > 10 else api_key,
                    "api_key_is_none": api_key == "none",
                },
                "timestamp": int(__import__("time").time() * 1000)
            }
            f.write(json.dumps(log_entry) + "\n")
    except Exception:
        # Best-effort debug logging: never let it break the actual LLM call.
        pass
    # endregion

    llm = ChatOpenAI(model=model, base_url=base_url, api_key=api_key, **kwargs)

    # Create the output parser (schema-aware when a Pydantic model is given),
    # plus a fixing parser that asks the LLM to repair malformed JSON.
    parser = JsonOutputParser(pydantic_object=pydantic_object) if pydantic_object else JsonOutputParser()
    fixing_parser = OutputFixingParser.from_llm(parser=parser, llm=llm)

    # Log the input prompt(s).
    user_msgs = [msg for msg in messages if msg["role"] == "user"]
    if user_msgs:
        logger.info("=" * 80)
        logger.info(f"📤 INPUT | 模型: {model}")
        for msg in user_msgs:
            logger.info(f"\n{msg['content']}")
        logger.info("=" * 80)

    # Timeout and retry machinery.
    timeout = kwargs.pop("timeout", LLM_TIMEOUT)
    response = None
    last_exc: Optional[BaseException] = None

    for attempt in range(1, LLM_MAX_RETRIES + 2):  # range(1, n+2) runs n+1 times (initial attempt + n retries)
        try:
            async with CONCURRENCY_LIMIT:
                # Guard the call with asyncio.wait_for so a hung request cannot block forever.
                response = await asyncio.wait_for(
                    llm.ainvoke(messages),
                    timeout=timeout
                )
            output = response.content
            break  # success: leave the retry loop
        except asyncio.TimeoutError:
            last_exc = asyncio.TimeoutError(f"LLM 调用超时({timeout}秒)")
            logger.warning(f"⏱️ LLM 调用超时(第 {attempt}/{LLM_MAX_RETRIES + 1} 次): {base_url} | 模型: {model}")
            if attempt <= LLM_MAX_RETRIES:
                wait_time = LLM_RETRY_BACKOFF * attempt
                logger.info(f"🔄 等待 {wait_time} 秒后重试(剩余 {LLM_MAX_RETRIES - attempt + 1} 次)...")
                await asyncio.sleep(wait_time)  # backoff grows linearly with attempt number
            else:
                logger.error(f"❌ LLM 调用最终超时(已尝试 {attempt} 次),放弃: {base_url} | 模型: {model}")
                # Do not raise; return None and let the caller decide how to handle it.
                return None
        except Exception as e:
            last_exc = e
            logger.warning(f"⚠️ LLM 调用失败(第 {attempt}/{LLM_MAX_RETRIES + 1} 次): {base_url} | 模型: {model} | 错误: {e}")
            if attempt <= LLM_MAX_RETRIES:
                wait_time = LLM_RETRY_BACKOFF * attempt
                logger.info(f"🔄 等待 {wait_time} 秒后重试(剩余 {LLM_MAX_RETRIES - attempt + 1} 次)...")
                await asyncio.sleep(wait_time)  # backoff grows linearly with attempt number
            else:
                logger.error(f"❌ LLM 调用最终失败(已尝试 {attempt} 次),放弃: {base_url} | 模型: {model} | 错误: {e}")
                # Do not raise; return None instead.
                return None

    # If response is still None at this point, every retry failed.
    if response is None:
        logger.error("❌ LLM 调用失败:所有重试都失败")
        return None

    try:

        # Log the output and token usage.
        # NOTE(review): assumes usage_metadata, when present, is a dict — a
        # None value here would raise AttributeError; confirm against the client.
        logger.info("=" * 80)
        total_tokens = getattr(response, "usage_metadata", {}).get("total_tokens", "N/A")
        input_tokens = getattr(response, "usage_metadata", {}).get("input_tokens", "N/A")
        output_tokens = getattr(response, "usage_metadata", {}).get("output_tokens", "N/A")

        logger.info(
            f"📥 OUTPUT | total_tokens: {total_tokens} | input_tokens: {input_tokens} | output_tokens: {output_tokens}"
        )
        logger.info(f"\n{output}")
        logger.info("=" * 80)

        # Heuristic reasoning-model check: any model name containing "qwen" or "reasoning".
        is_reasoning_model = "qwen" in model.lower() or "reasoning" in model.lower()

        # JSON parsing: try the plain parser first; on failure, fall back to the fixing parser.
        parsed = None
        try:
            parsed = parser.invoke(response)
            # Validate parsed result for reasoning models.
            if is_reasoning_model and isinstance(parsed, dict):
                # Check if the parsed result has a valid "relevant" field (for RelevanceResult type).
                relevant = parsed.get("relevant", "")
                if relevant and isinstance(relevant, str):
                    relevant_upper = relevant.upper()
                    if relevant_upper not in ["YES", "NO"]:
                        # Invalid format: re-extract the answer from the raw text.
                        logger.warning(f"推理模型响应格式无效,进行后处理: {relevant}")
                        parsed = extract_final_answer_from_reasoning(output, pydantic_object)
                    else:
                        # Normalize to uppercase.
                        parsed["relevant"] = relevant_upper
        except Exception as e:
            logger.warning(f"直接解析失败,尝试修复: {e}")
            try:
                parsed = fixing_parser.invoke(response)
                # Validate again after fixing.
                if is_reasoning_model and isinstance(parsed, dict):
                    relevant = parsed.get("relevant", "")
                    if relevant and isinstance(relevant, str) and relevant.upper() not in ["YES", "NO"]:
                        logger.warning(f"修复后格式仍无效,使用后处理: {relevant}")
                        parsed = extract_final_answer_from_reasoning(output, pydantic_object)
            except Exception as e2:
                logger.warning(f"修复解析也失败: {e2}")
                # For reasoning models, fall back to text post-processing.
                if is_reasoning_model:
                    logger.info("使用后处理函数提取推理模型的最终答案")
                    parsed = extract_final_answer_from_reasoning(output, pydantic_object)
                else:
                    raise e2

        # Record the complete call (input, raw output, parsed result) to the JSONL log.
        log_api(
            log_file,
            {
                "input": messages,
                "output": response.dict() if hasattr(response, "dict") else str(response),
                "parsed": parsed,
            },
        )

        return parsed  # return the parsed dict directly

    except Exception as e:
        logger.error(f"❌ LLM调用失败: {e}")
        raise