import asyncio
import json
import os
import re
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Optional

from langchain.output_parsers import OutputFixingParser
from langchain_core.output_parsers import JsonOutputParser
from langchain_openai import ChatOpenAI
from loguru import logger
from pydantic import BaseModel

# Cap on concurrent LLM calls, shared by every caller of call_llm().
CONCURRENCY_LIMIT = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENCY", "8")))

# Timeout and retry policy for LLM calls; all overridable via environment variables.
LLM_TIMEOUT = int(os.getenv("LLM_TIMEOUT", "180"))  # seconds per attempt
LLM_MAX_RETRIES = int(os.getenv("LLM_MAX_RETRIES", "1"))  # retries after the first attempt
LLM_RETRY_BACKOFF = float(os.getenv("LLM_RETRY_BACKOFF", "2.0"))  # linear backoff step in seconds

# File extensions treated as code (or code-adjacent docs) when scanning repositories.
CODE_EXTENSIONS = {
    ".py", ".ipynb",                                    # Python / Jupyter
    ".c", ".cpp", ".cc", ".cxx", ".h", ".hpp", ".hh",   # C / C++
    ".f", ".f90", ".f95", ".for",                       # Fortran
    ".jl",                                              # Julia
    ".r", ".R",                                         # R
    ".java",                                            # Java
    ".m",                                               # MATLAB / Objective-C
    ".sh", ".bash",                                     # shell
    ".rs",                                              # Rust
    ".go",                                              # Go
    ".md", ".markdown",                                 # Markdown docs
}
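
# A minimal helper sketch (the name is_code_file is an illustrative assumption,
# not part of the original API) showing one way CODE_EXTENSIONS might be used:
def is_code_file(path: Path) -> bool:
    """Return True if the file's extension marks it as code (or code-adjacent)."""
    return path.suffix in CODE_EXTENSIONS

# e.g. code_files = [p for p in Path("repo").rglob("*") if p.is_file() and is_code_file(p)]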


def init_logger(log_file: str, level: str = "INFO"):
    """Initialize loguru with a colorized console sink and a daily-rotating file sink."""
    log_dir = os.path.dirname(log_file)
    if log_dir:  # guard: dirname is "" for bare filenames, which os.makedirs rejects
        os.makedirs(log_dir, exist_ok=True)
    logger.remove()

    # Console sink with color.
    logger.add(
        sink=lambda msg: print(msg, end=""),
        level=level,
        colorize=True,
        format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
    )

    # Plain-text file sink, rotated daily.
    logger.add(
        Path(log_file),
        level=level,
        rotation="1 day",
        encoding="utf-8",
        format="{time:YYYY-MM-DD HH:mm:ss} [{level}] ({name}:{function}:{line}) {message}",
    )
    return logger
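
# Example usage (a sketch; the log path is illustrative):
#     log = init_logger("workdir/app.log", level="DEBUG")
#     log.info("pipeline started")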


def log_api(log_file: str, data: Dict):
    """Append one API-call record to a JSONL log file."""
    with open(log_file, "a", encoding="utf-8") as f:
        record = {"time": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), "data": data}
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
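
# Example (illustrative payload; appends one JSON line per call):
#     log_api("workdir/calls_llm.jsonl", {"input": [...], "parsed": {"relevant": "YES"}})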


def extract_final_answer_from_reasoning(text: str, pydantic_object: Optional[type[BaseModel]] = None) -> Dict:
    """
    Extract the final answer from a reasoning model's output.

    Reasoning models such as Qwen3 respond in the format:
    <think>reasoning content</think>final result

    Args:
        text: Raw response text from the reasoning model
        pydantic_object: Expected Pydantic model for structured output
            (currently unused; kept for signature parity with call_llm)

    Returns:
        Dict with extracted "relevant" and "reason" fields
    """
    # Split the response into the reasoning block and the final answer.
    reasoning_pattern = r'<think>(.*?)</think>'
    reasoning_match = re.search(reasoning_pattern, text, re.DOTALL | re.IGNORECASE)
    reasoning_content = reasoning_match.group(1).strip() if reasoning_match else ""

    if reasoning_match:
        final_result = text[reasoning_match.end():].strip()
    else:
        final_result = text.strip()

    # Strategy 1: look for a fenced ```json ... ``` block containing the answer.
    json_block_pattern = r'```(?:json)?\s*(\{.*?\})\s*```'
    json_block_matches = re.findall(json_block_pattern, final_result, re.DOTALL | re.IGNORECASE)
    for match in json_block_matches:
        try:
            parsed = json.loads(match)
            if isinstance(parsed, dict) and "relevant" in parsed:
                return {
                    "relevant": str(parsed.get("relevant", "")).upper(),
                    "reason": (reasoning_content or parsed.get("reason", ""))[:1000],
                }
        except json.JSONDecodeError:
            continue

    # Strategy 2: scan for a balanced top-level {...} object and try to parse it.
    brace_count = 0
    start_pos = -1
    for i, char in enumerate(final_result):
        if char == '{':
            if brace_count == 0:
                start_pos = i
            brace_count += 1
        elif char == '}':
            brace_count -= 1
            if brace_count == 0 and start_pos >= 0:
                json_str = final_result[start_pos:i + 1]
                try:
                    parsed = json.loads(json_str)
                    if isinstance(parsed, dict) and "relevant" in parsed:
                        return {
                            "relevant": str(parsed.get("relevant", "")).upper(),
                            "reason": (reasoning_content or parsed.get("reason", ""))[:1000],
                        }
                except json.JSONDecodeError:
                    pass
                start_pos = -1

    # Strategy 3: regex fallbacks, ordered from most to least specific.
    relevant_patterns = [
        r'"relevant"\s*:\s*["\']?(YES|NO)["\']?',
        r'"relevant"\s*:\s*(YES|NO)',
        r'relevant\s*[:=]\s*["\']?(YES|NO)["\']?',
        r'answer\s*[:=]\s*["\']?(YES|NO)["\']?',
        r'final\s+answer\s*[:=]\s*["\']?(YES|NO)["\']?',
        r'\b(YES|NO)\b',
    ]

    relevant = None
    for pattern in relevant_patterns:
        match = re.search(pattern, final_result, re.IGNORECASE)
        if match:
            relevant = match.group(1).upper()
            break

    # Prefer the reasoning text as the reason; fall back to the raw answer tail.
    reason = reasoning_content if reasoning_content else final_result[:1000]

    return {
        "relevant": relevant or "NO",
        "reason": reason[:1000],
    }
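
# Expected behavior, sketched as a doctest-style example:
#     >>> extract_final_answer_from_reasoning(
#     ...     '<think>the repo trains a CNN</think>{"relevant": "yes"}')
#     {'relevant': 'YES', 'reason': 'the repo trains a CNN'}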


async def call_llm(
    messages: List[Dict[str, str]],
    model: str,
    base_url: str,
    api_key: str,
    pydantic_object: Optional[type[BaseModel]] = None,
    log_file: str = "workdir/calls_llm.jsonl",
    **kwargs,
) -> Optional[Dict]:
    """Async LLM call with LangChain structured output."""
    # Debug instrumentation: append API-key metadata (length, truncated prefix/suffix)
    # to .cursor/debug.log; failures here are deliberately ignored.
    debug_log_path = Path(__file__).parent.parent.parent / ".cursor" / "debug.log"
    try:
        with open(debug_log_path, "a", encoding="utf-8") as f:
            log_entry = {
                "sessionId": "debug-session",
                "runId": "api-key-llm-call",
                "hypothesisId": "B",
                "location": "util.py:180",
                "message": "API key passed to LLM",
                "data": {
                    "base_url": base_url,
                    "model": model,
                    "api_key_length": len(api_key) if api_key else 0,
                    "api_key_prefix": api_key[:20] + "..." if api_key and len(api_key) > 20 else api_key,
                    "api_key_suffix": "..." + api_key[-10:] if api_key and len(api_key) > 10 else api_key,
                    "api_key_is_none": api_key == "none",
                },
                "timestamp": int(time.time() * 1000),
            }
            f.write(json.dumps(log_entry) + "\n")
    except Exception:
        pass

    # Pop the timeout before constructing the client so it is not forwarded to ChatOpenAI.
    timeout = kwargs.pop("timeout", LLM_TIMEOUT)
    llm = ChatOpenAI(model=model, base_url=base_url, api_key=api_key, **kwargs)

    # JSON parser, with an LLM-backed fixing parser as fallback for malformed output.
    parser = JsonOutputParser(pydantic_object=pydantic_object) if pydantic_object else JsonOutputParser()
    fixing_parser = OutputFixingParser.from_llm(parser=parser, llm=llm)

    # Log the user-visible prompt.
    user_msgs = [msg for msg in messages if msg["role"] == "user"]
    if user_msgs:
        logger.info("=" * 80)
        logger.info(f"📤 INPUT | model: {model}")
        for msg in user_msgs:
            logger.info(f"\n{msg['content']}")
        logger.info("=" * 80)

    response = None
    output = None
    last_exc: Optional[BaseException] = None

    # Retry loop: LLM_MAX_RETRIES retries after the first attempt, with linear backoff.
    for attempt in range(1, LLM_MAX_RETRIES + 2):
        try:
            async with CONCURRENCY_LIMIT:
                response = await asyncio.wait_for(
                    llm.ainvoke(messages),
                    timeout=timeout,
                )
            output = response.content
            break
        except asyncio.TimeoutError:
            last_exc = asyncio.TimeoutError(f"LLM call timed out ({timeout}s)")
            logger.warning(f"⏱️ LLM call timed out (attempt {attempt}/{LLM_MAX_RETRIES + 1}): {base_url} | model: {model}")
            if attempt <= LLM_MAX_RETRIES:
                wait_time = LLM_RETRY_BACKOFF * attempt
                logger.info(f"🔄 Waiting {wait_time}s before retrying ({LLM_MAX_RETRIES - attempt + 1} attempts left)...")
                await asyncio.sleep(wait_time)
            else:
                logger.error(f"❌ LLM call timed out for good (after {attempt} attempts), giving up: {base_url} | model: {model}")
                return None
        except Exception as e:
            last_exc = e
            logger.warning(f"⚠️ LLM call failed (attempt {attempt}/{LLM_MAX_RETRIES + 1}): {base_url} | model: {model} | error: {e}")
            if attempt <= LLM_MAX_RETRIES:
                wait_time = LLM_RETRY_BACKOFF * attempt
                logger.info(f"🔄 Waiting {wait_time}s before retrying ({LLM_MAX_RETRIES - attempt + 1} attempts left)...")
                await asyncio.sleep(wait_time)
            else:
                logger.error(f"❌ LLM call failed for good (after {attempt} attempts), giving up: {base_url} | model: {model} | error: {e}")
                return None

    if response is None:
        logger.error("❌ LLM call failed: all retries exhausted")
        return None

    try:
        # Log the raw output together with token usage (usage_metadata may be absent).
        logger.info("=" * 80)
        usage = getattr(response, "usage_metadata", None) or {}
        total_tokens = usage.get("total_tokens", "N/A")
        input_tokens = usage.get("input_tokens", "N/A")
        output_tokens = usage.get("output_tokens", "N/A")
        logger.info(
            f"📥 OUTPUT | total_tokens: {total_tokens} | input_tokens: {input_tokens} | output_tokens: {output_tokens}"
        )
        logger.info(f"\n{output}")
        logger.info("=" * 80)

        # Heuristic: treat Qwen-family and explicitly "reasoning" models as reasoning models.
        is_reasoning_model = "qwen" in model.lower() or "reasoning" in model.lower()

        parsed = None
        try:
            parsed = parser.invoke(response)
            # Reasoning models sometimes emit a "relevant" value other than YES/NO;
            # fall back to post-processing the raw text in that case.
            if is_reasoning_model and isinstance(parsed, dict):
                relevant = parsed.get("relevant", "")
                if relevant and isinstance(relevant, str):
                    relevant_upper = relevant.upper()
                    if relevant_upper not in ["YES", "NO"]:
                        logger.warning(f"Invalid reasoning-model response format, post-processing: {relevant}")
                        parsed = extract_final_answer_from_reasoning(output, pydantic_object)
                    else:
                        parsed["relevant"] = relevant_upper
        except Exception as e:
            logger.warning(f"Direct parse failed, trying the fixing parser: {e}")
            try:
                parsed = fixing_parser.invoke(response)
                if is_reasoning_model and isinstance(parsed, dict):
                    relevant = parsed.get("relevant", "")
                    if relevant and isinstance(relevant, str) and relevant.upper() not in ["YES", "NO"]:
                        logger.warning(f"Format still invalid after fixing, post-processing: {relevant}")
                        parsed = extract_final_answer_from_reasoning(output, pydantic_object)
            except Exception as e2:
                logger.warning(f"Fixing parser also failed: {e2}")
                if is_reasoning_model:
                    logger.info("Extracting the reasoning model's final answer via post-processing")
                    parsed = extract_final_answer_from_reasoning(output, pydantic_object)
                else:
                    raise e2

        log_api(
            log_file,
            {
                "input": messages,
                "output": response.dict() if hasattr(response, "dict") else str(response),
                "parsed": parsed,
            },
        )

        return parsed

    except Exception as e:
        logger.error(f"❌ LLM call failed: {e}")
        raise
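
# Minimal smoke test (a sketch: the model name, base URL, and LLM_API_KEY env var
# are illustrative assumptions, not part of the original configuration):
if __name__ == "__main__":
    async def _demo():
        result = await call_llm(
            messages=[{"role": "user", "content": 'Reply with JSON: {"relevant": "YES", "reason": "demo"}'}],
            model="qwen3-8b",
            base_url="http://localhost:8000/v1",
            api_key=os.getenv("LLM_API_KEY", "none"),
        )
        print(result)

    asyncio.run(_demo())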