# util.py
from loguru import logger
import os
import json
from datetime import datetime, timezone
from typing import List, Dict, Optional
from pathlib import Path
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import JsonOutputParser
from langchain.output_parsers import OutputFixingParser
from pydantic import BaseModel
import asyncio
import re
# Global concurrency limit
CONCURRENCY_LIMIT = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENCY", "8")))
# Timeout and retry configuration for LLM calls (overridable via environment variables)
LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT", "180"))  # per-call timeout in seconds (default: 3 minutes)
LLM_MAX_RETRIES = int(os.getenv("LLM_MAX_RETRIES", "1"))  # maximum number of retries (default: 1)
LLM_RETRY_BACKOFF = float(os.getenv("LLM_RETRY_BACKOFF", "2.0"))  # base wait between retries, in seconds
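
# Example (sketch): the knobs above can be overridden per run via environment
# variables, e.g. `MAX_CONCURRENCY=16 LLM_TIMEOUT=300 python main.py`
# (`main.py` is a hypothetical entry point).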
# Common code file extensions
CODE_EXTENSIONS = {
# Python
".py", ".ipynb",
# C/C++
".c", ".cpp", ".cc", ".cxx", ".h", ".hpp", ".hh",
# Fortran
".f", ".f90", ".f95", ".for",
# Julia
".jl",
# R
".r", ".R",
# Java
".java",
# MATLAB/Octave
".m",
    # Shell scripts
".sh", ".bash",
".rs", # Rust
".go", # Go
# Markdown
".md", ".markdown",
}
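
# Example (sketch): one way to use CODE_EXTENSIONS to filter a checkout for
# code files; `repo_root` is a hypothetical path. Lower-casing the suffix
# covers case variants such as ".R" vs ".r".
#
#   code_files = [
#       p for p in Path(repo_root).rglob("*")
#       if p.is_file() and p.suffix.lower() in CODE_EXTENSIONS
#   ]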
def init_logger(log_file: str, level: str = "INFO"):
"""Initialize logger with color output"""
    log_dir = os.path.dirname(log_file)
    if log_dir:  # dirname is empty for bare filenames; os.makedirs("") would raise
        os.makedirs(log_dir, exist_ok=True)
logger.remove()
# Console output with color
logger.add(
sink=lambda msg: print(msg, end=""),
level=level,
colorize=True,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
)
# File output
logger.add(
Path(log_file),
level=level,
rotation="1 day",
encoding="utf-8",
format="{time:YYYY-MM-DD HH:mm:ss} [{level}] ({name}:{function}:{line}) {message}",
)
return logger
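
# Example usage (sketch; "workdir/logs/run.log" is a hypothetical path). Since
# loguru's logger is a module-level singleton, a single init_logger call
# configures it for every module that does `from loguru import logger`:
#
#   init_logger("workdir/logs/run.log", level="DEBUG")
#   logger.info("pipeline started")  # written to both console and the log file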
def log_api(log_file: str, data: Dict):
"""Log API call to file"""
with open(log_file, "a", encoding="utf-8") as f:
        record = {"time": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), "data": data}
f.write(json.dumps(record, ensure_ascii=False) + "\n")
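
# Example (sketch; the path and payload are illustrative): each call appends
# one JSON line, so the file can be replayed with line-by-line json.loads:
#
#   log_api("workdir/calls_llm.jsonl", {"input": [...], "output": "..."})
#   # appends: {"time": "2025-01-01T00:00:00Z", "data": {"input": [...], ...}}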
def extract_final_answer_from_reasoning(text: str, pydantic_object: Optional[type[BaseModel]] = None) -> Dict:
"""
Extract final answer from reasoning model output.
For reasoning models like Qwen3, the response format is:
<think>reasoning content</think>final result
Args:
text: Raw response text from reasoning model
        pydantic_object: Expected Pydantic model for structured output (currently unused; kept for interface symmetry)
Returns:
Dict with extracted relevant and reason fields
"""
# Extract reasoning content from <think>...</think> tags
reasoning_pattern = r'<think>(.*?)</think>'
reasoning_match = re.search(reasoning_pattern, text, re.DOTALL | re.IGNORECASE)
reasoning_content = reasoning_match.group(1).strip() if reasoning_match else ""
# Extract final result (content after </think> tag)
final_result = ""
if reasoning_match:
# Get everything after the closing tag
final_result = text[reasoning_match.end():].strip()
else:
# If no tags found, use the whole text as final result
final_result = text.strip()
# Now extract JSON or YES/NO from the final result
# First, try to extract JSON from the final result
json_block_pattern = r'```(?:json)?\s*(\{.*?\})\s*```'
json_block_matches = re.findall(json_block_pattern, final_result, re.DOTALL | re.IGNORECASE)
for match in json_block_matches:
try:
parsed = json.loads(match)
if isinstance(parsed, dict) and "relevant" in parsed:
return {
"relevant": str(parsed.get("relevant", "")).upper(),
"reason": reasoning_content or parsed.get("reason", "")[:1000]
}
except json.JSONDecodeError:
continue
# Try to find JSON objects in the final result (handle nested braces)
brace_count = 0
start_pos = -1
for i, char in enumerate(final_result):
if char == '{':
if brace_count == 0:
start_pos = i
brace_count += 1
elif char == '}':
brace_count -= 1
if brace_count == 0 and start_pos >= 0:
# Found a complete JSON object
json_str = final_result[start_pos:i+1]
try:
parsed = json.loads(json_str)
if isinstance(parsed, dict) and "relevant" in parsed:
return {
"relevant": str(parsed.get("relevant", "")).upper(),
"reason": reasoning_content or parsed.get("reason", "")[:1000]
}
except json.JSONDecodeError:
pass
start_pos = -1
# If no JSON found, try to extract YES/NO from final result
relevant_patterns = [
r'"relevant"\s*:\s*["\']?(YES|NO)["\']?',
r'"relevant"\s*:\s*(YES|NO)',
r'relevant\s*[:=]\s*["\']?(YES|NO)["\']?',
r'answer\s*[:=]\s*["\']?(YES|NO)["\']?',
r'final\s+answer\s*[:=]\s*["\']?(YES|NO)["\']?',
r'\b(YES|NO)\b', # Last resort: find YES or NO in final result
]
relevant = None
for pattern in relevant_patterns:
match = re.search(pattern, final_result, re.IGNORECASE)
if match:
relevant = match.group(1).upper()
break
# Use reasoning content as reason, or final result if no reasoning found
reason = reasoning_content if reasoning_content else final_result[:1000]
    return {
        "relevant": relevant or "NO",  # Default to NO if not found
        "reason": reason[:1000],  # Limit the reason length
    }
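
# Example (sketch): for a typical Qwen3-style response the helper above yields
# a dict like the following; the text is illustrative only.
#
#   text = '<think>The repo trains on code.</think>{"relevant": "YES", "reason": "code data"}'
#   extract_final_answer_from_reasoning(text)
#   # -> {"relevant": "YES", "reason": "The repo trains on code."}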
async def call_llm(
messages: List[Dict[str, str]],
model: str,
base_url: str,
api_key: str,
pydantic_object: Optional[type[BaseModel]] = None,
log_file: str = "workdir/calls_llm.jsonl",
**kwargs,
) -> Optional[Dict]:
"""异步LLM调用,使用langchain结构化输出"""
    # Pop the per-call timeout before building the client so it is not also
    # forwarded to ChatOpenAI; asyncio.wait_for enforces it below.
    timeout = kwargs.pop("timeout", LLM_TIMEOUT)
    llm = ChatOpenAI(model=model, base_url=base_url, api_key=api_key, **kwargs)
    # Build the JSON parser, plus an LLM-backed fixing parser as a fallback
    parser = JsonOutputParser(pydantic_object=pydantic_object) if pydantic_object else JsonOutputParser()
    fixing_parser = OutputFixingParser.from_llm(parser=parser, llm=llm)
    # Log the input
user_msgs = [msg for msg in messages if msg["role"] == "user"]
if user_msgs:
logger.info("=" * 80)
logger.info(f"📤 INPUT | 模型: {model}")
for msg in user_msgs:
logger.info(f"\n{msg['content']}")
logger.info("=" * 80)
    # Retry loop with timeout protection
    response = None
    for attempt in range(1, LLM_MAX_RETRIES + 2):  # LLM_MAX_RETRIES + 1 attempts in total (initial try + retries)
try:
async with CONCURRENCY_LIMIT:
                # Guard the call with a timeout via asyncio.wait_for
response = await asyncio.wait_for(
llm.ainvoke(messages),
timeout=timeout
)
output = response.content
            break  # Success: exit the retry loop
        except asyncio.TimeoutError:
            logger.warning(f"⏱️ LLM call timed out (attempt {attempt}/{LLM_MAX_RETRIES + 1}): {base_url} | model: {model}")
            if attempt <= LLM_MAX_RETRIES:
                wait_time = LLM_RETRY_BACKOFF * attempt  # linearly increasing backoff
                logger.info(f"🔄 Waiting {wait_time}s before retrying ({LLM_MAX_RETRIES - attempt + 1} retries left)...")
                await asyncio.sleep(wait_time)
            else:
                logger.error(f"❌ LLM call timed out after {attempt} attempts, giving up: {base_url} | model: {model}")
                # Return None instead of raising, so the caller decides how to proceed
                return None
        except Exception as e:
            logger.warning(f"⚠️ LLM call failed (attempt {attempt}/{LLM_MAX_RETRIES + 1}): {base_url} | model: {model} | error: {e}")
            if attempt <= LLM_MAX_RETRIES:
                wait_time = LLM_RETRY_BACKOFF * attempt  # linearly increasing backoff
                logger.info(f"🔄 Waiting {wait_time}s before retrying ({LLM_MAX_RETRIES - attempt + 1} retries left)...")
                await asyncio.sleep(wait_time)
            else:
                logger.error(f"❌ LLM call failed after {attempt} attempts, giving up: {base_url} | model: {model} | error: {e}")
                # Return None instead of raising
                return None
    # If response is still None here, every attempt failed
    if response is None:
        logger.error("❌ LLM call failed: all attempts exhausted")
        return None
try:
        # Log the output
        logger.info("=" * 80)
        usage = getattr(response, "usage_metadata", None) or {}  # may be None for some providers
        total_tokens = usage.get("total_tokens", "N/A")
        input_tokens = usage.get("input_tokens", "N/A")
        output_tokens = usage.get("output_tokens", "N/A")
logger.info(
f"📥 OUTPUT | total_tokens: {total_tokens} | input_tokens: {input_tokens} | output_tokens: {output_tokens}"
)
logger.info(f"\n{output}")
logger.info("=" * 80)
# Check if this is a reasoning model
is_reasoning_model = "qwen" in model.lower() or "reasoning" in model.lower()
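        # Note (assumption): this heuristic treats every Qwen model as a
        # reasoning model; Qwen3-style checkpoints may or may not emit
        # <think> blocks depending on serving mode, so the post-processing
        # below is deliberately best-effort.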
        # Parse the JSON (try direct parsing first; fall back to automatic fixing)
parsed = None
try:
parsed = parser.invoke(response)
# Validate parsed result for reasoning models
if is_reasoning_model and isinstance(parsed, dict):
# Check if parsed result has valid relevant field (for RelevanceResult type)
relevant = parsed.get("relevant", "")
if relevant and isinstance(relevant, str):
relevant_upper = relevant.upper()
if relevant_upper not in ["YES", "NO"]:
# Invalid format, need post-processing
logger.warning(f"推理模型响应格式无效,进行后处理: {relevant}")
parsed = extract_final_answer_from_reasoning(output, pydantic_object)
else:
# Update to ensure uppercase
parsed["relevant"] = relevant_upper
except Exception as e:
logger.warning(f"直接解析失败,尝试修复: {e}")
try:
parsed = fixing_parser.invoke(response)
# Validate again after fixing
if is_reasoning_model and isinstance(parsed, dict):
relevant = parsed.get("relevant", "")
if relevant and isinstance(relevant, str) and relevant.upper() not in ["YES", "NO"]:
logger.warning(f"修复后格式仍无效,使用后处理: {relevant}")
parsed = extract_final_answer_from_reasoning(output, pydantic_object)
except Exception as e2:
logger.warning(f"修复解析也失败: {e2}")
# For reasoning models, try post-processing
if is_reasoning_model:
logger.info("使用后处理函数提取推理模型的最终答案")
parsed = extract_final_answer_from_reasoning(output, pydantic_object)
else:
raise e2
        # Write the full record to the API log
log_api(
log_file,
{
"input": messages,
"output": response.dict() if hasattr(response, "dict") else str(response),
"parsed": parsed,
},
)
        return parsed  # Return the parsed dict directly
except Exception as e:
logger.error(f"❌ LLM调用失败: {e}")
raise
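
# Example usage (sketch): the model name, endpoint, and key below are
# hypothetical placeholders, and RelevanceResult merely mirrors the shape the
# parsing logic above expects; BaseModel and asyncio are already imported.
#
#   class RelevanceResult(BaseModel):
#       relevant: str
#       reason: str
#
#   async def main():
#       messages = [{"role": "user", "content": "Is this repo relevant? Reply in JSON."}]
#       parsed = await call_llm(
#           messages,
#           model="qwen3-32b",
#           base_url="http://localhost:8000/v1",
#           api_key="none",
#           pydantic_object=RelevanceResult,
#       )
#       print(parsed)  # dict on success, None after exhausting retries
#
#   asyncio.run(main())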