|
|
""" |
|
|
Model inference and client management for AnyCoder. |
|
|
Handles different model providers and inference clients. |
|
|
""" |
|
|
import os |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
import re |
|
|
from http import HTTPStatus |
|
|
|
|
|
from huggingface_hub import InferenceClient |
|
|
from openai import OpenAI |
|
|
from mistralai import Mistral |
|
|
import dashscope |
|
|
|
|
|
from .config import HF_TOKEN, AVAILABLE_MODELS |
|
|
|
|
|
|
|
|
# Type aliases for chat transcripts.
# NOTE(review): History is redefined later in this module as
# List[Tuple[str, str]]; that later definition is the one in effect at
# runtime for the history helpers — consider removing this duplicate.
History = List[Dict[str, str]]


Messages = List[Dict[str, str]]
|
|
|
|
|
def get_inference_client(model_id, provider="auto"):
    """Return an inference client with provider based on model_id and user selection.

    Depending on the model this returns:
      * an ``OpenAI`` client pointed at an OpenAI-compatible endpoint
        (DashScope, Poe, OpenRouter, StepFun, Gemini, Moonshot, or the
        stealth endpoint),
      * a ``Mistral`` client for Mistral-hosted models, or
      * a Hugging Face ``InferenceClient`` (the default), with the provider
        forced for models that are only served by a specific provider.

    Args:
        model_id: Identifier of the model to route.
        provider: User-selected Hugging Face inference provider; only used
            for the ``InferenceClient`` fallback and may be overridden.

    Raises:
        ValueError: if the stealth model's required environment variables
            are missing.
    """
    dashscope_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    poe_base = "https://api.poe.com/v1"
    openrouter_base = "https://openrouter.ai/api/v1"
    gemini_base = "https://generativelanguage.googleapis.com/v1beta/openai/"

    # model_id -> (API-key environment variable, base URL) for models served
    # through OpenAI-compatible endpoints.
    openai_endpoints = {
        "qwen3-30b-a3b-instruct-2507": ("DASHSCOPE_API_KEY", dashscope_base),
        "qwen3-30b-a3b-thinking-2507": ("DASHSCOPE_API_KEY", dashscope_base),
        "qwen3-coder-30b-a3b-instruct": ("DASHSCOPE_API_KEY", dashscope_base),
        "qwen3-max-preview": ("DASHSCOPE_API_KEY", dashscope_base),
        "gpt-5": ("POE_API_KEY", poe_base),
        "grok-4": ("POE_API_KEY", poe_base),
        "Grok-Code-Fast-1": ("POE_API_KEY", poe_base),
        "claude-opus-4.1": ("POE_API_KEY", poe_base),
        "claude-sonnet-4.5": ("POE_API_KEY", poe_base),
        "claude-haiku-4.5": ("POE_API_KEY", poe_base),
        "openrouter/sonoma-dusk-alpha": ("OPENROUTER_API_KEY", openrouter_base),
        "openrouter/sonoma-sky-alpha": ("OPENROUTER_API_KEY", openrouter_base),
        "step-3": ("STEP_API_KEY", "https://api.stepfun.com/v1"),
        "gemini-2.5-flash": ("GEMINI_API_KEY", gemini_base),
        "gemini-2.5-pro": ("GEMINI_API_KEY", gemini_base),
        "gemini-flash-latest": ("GEMINI_API_KEY", gemini_base),
        "gemini-flash-lite-latest": ("GEMINI_API_KEY", gemini_base),
        "kimi-k2-turbo-preview": ("MOONSHOT_API_KEY", "https://api.moonshot.ai/v1"),
    }
    if model_id in openai_endpoints:
        env_var, base_url = openai_endpoints[model_id]
        return OpenAI(api_key=os.getenv(env_var), base_url=base_url)

    if model_id in ("codestral-2508", "mistral-medium-2508"):
        return Mistral(api_key=os.getenv("MISTRAL_API_KEY"))

    if model_id == "stealth-model-1":
        # Both the key and the endpoint come from the environment; fail fast
        # with an explicit error if either is missing.
        api_key = os.getenv("STEALTH_MODEL_1_API_KEY")
        if not api_key:
            raise ValueError("STEALTH_MODEL_1_API_KEY environment variable is required for Carrot model")

        base_url = os.getenv("STEALTH_MODEL_1_BASE_URL")
        if not base_url:
            raise ValueError("STEALTH_MODEL_1_BASE_URL environment variable is required for Carrot model")

        return OpenAI(api_key=api_key, base_url=base_url)

    # Models that are only served by a specific HF inference provider
    # override the user's selection (GLM-4.6 deliberately resets it to
    # "auto"; GLM-4.5 pins fireworks-ai).
    provider_overrides = {
        "MiniMaxAI/MiniMax-M2": "novita",
        "moonshotai/Kimi-K2-Thinking": "novita",
        "moonshotai/Kimi-K2-Instruct": "groq",
        "deepseek-ai/DeepSeek-V3.1": "novita",
        "deepseek-ai/DeepSeek-V3.1-Terminus": "novita",
        "deepseek-ai/DeepSeek-V3.2-Exp": "novita",
        "zai-org/GLM-4.5": "fireworks-ai",
        "zai-org/GLM-4.6": "auto",
    }
    provider = provider_overrides.get(model_id, provider)

    return InferenceClient(
        provider=provider,
        api_key=HF_TOKEN,
        bill_to="huggingface"
    )
|
|
|
|
|
|
|
|
def get_real_model_id(model_id: str) -> str:
    """Get the real model ID, checking environment variables for stealth models and handling special model formats"""
    if model_id == "stealth-model-1":
        # The concrete model behind the stealth alias is configured via env.
        resolved = os.getenv("STEALTH_MODEL_1_ID")
        if not resolved:
            raise ValueError("STEALTH_MODEL_1_ID environment variable is required for Carrot model")
        return resolved

    # GLM-4.6 is addressed with an explicit provider suffix.
    if model_id == "zai-org/GLM-4.6":
        return "zai-org/GLM-4.6:zai-org"

    return model_id
|
|
|
|
|
|
|
|
# Aliases used by the history helpers below: a History is a sequence of
# (user, assistant) pairs; a Messages is an OpenAI-style list of
# {'role': ..., 'content': ...} dicts.
# NOTE(review): this shadows the earlier History alias defined near the
# top of the module with a different element type.
History = List[Tuple[str, str]]


Messages = List[Dict[str, str]]
|
|
|
|
|
def history_to_messages(history: History, system: str) -> Messages:
    """Convert (user, assistant) history pairs into an OpenAI-style message list.

    The system prompt becomes the first message. Multimodal user turns that
    arrive as a list of parts are reduced to their concatenated text parts
    (falling back to ``str()`` of the whole list when no text part exists).
    """
    converted = [{'role': 'system', 'content': system}]
    for pair in history:
        user_turn = pair[0]
        # Multimodal content: keep only the "text"-typed fragments.
        if isinstance(user_turn, list):
            fragments = [
                part.get("text", "")
                for part in user_turn
                if isinstance(part, dict) and part.get("type") == "text"
            ]
            joined = "".join(fragments)
            user_turn = joined if joined else str(user_turn)

        converted.append({'role': 'user', 'content': user_turn})
        converted.append({'role': 'assistant', 'content': pair[1]})
    return converted
|
|
|
|
|
def history_to_chatbot_messages(history: History) -> List[Dict[str, str]]:
    """Convert history tuples to chatbot message format"""
    out = []
    for user_part, assistant_part in history:
        # Multimodal content: collapse to the concatenated text fragments,
        # or a string rendering of the list when no text fragment exists.
        if isinstance(user_part, list):
            pieces = [
                item.get("text", "")
                for item in user_part
                if isinstance(item, dict) and item.get("type") == "text"
            ]
            combined = "".join(pieces)
            user_part = combined if combined else str(user_part)

        out.append({"role": "user", "content": user_part})
        out.append({"role": "assistant", "content": assistant_part})
    return out
|
|
|
|
|
def strip_tool_call_markers(text):
    """Remove TOOL_CALL markers that some LLMs (like Qwen) add to their output."""
    if not text:
        return text

    # Drop [TOOL_CALL] / [/TOOL_CALL] tags, case-insensitively.
    cleaned = re.sub(r'\[/?TOOL_CALL\]', '', text, flags=re.IGNORECASE)

    # Drop lines that consist solely of a stray closing "}}".
    cleaned = re.sub(r'^\s*\}\}\s*$', '', cleaned, flags=re.MULTILINE)
    return cleaned.strip()
|
|
|
|
|
def remove_code_block(text):
    """Extract the code payload from a model response.

    Handles fenced ``` blocks (optionally labeled with a language name on
    the first line), raw HTML documents returned without fences, and an
    unterminated ```python fence. Falls back to returning the stripped
    text unchanged.
    """
    # Language labels that may appear on the first line inside a fence.
    # Entries MUST be lowercase because they are compared against
    # .strip().lower() output (the previous list contained 'sql-plSQL',
    # which could never match; it is fixed to 'sql-plsql' here).
    known_languages = {
        'python', 'html', 'css', 'javascript', 'json', 'c', 'cpp',
        'markdown', 'latex', 'jinja2', 'typescript', 'yaml', 'dockerfile',
        'shell', 'r', 'sql', 'sql-mssql', 'sql-mysql', 'sql-mariadb',
        'sql-sqlite', 'sql-cassandra', 'sql-plsql', 'sql-hive', 'sql-pgsql',
        'sql-gql', 'sql-gpsql', 'sql-sparksql', 'sql-esper',
    }

    # Remove TOOL_CALL markers before any pattern matching.
    text = strip_tool_call_markers(text)

    # Try fenced blocks first, from most to least specific.
    patterns = [
        r'```(?:html|HTML)\n([\s\S]+?)\n```',
        r'```\n([\s\S]+?)\n```',
        r'```([\s\S]+?)```'
    ]
    for pattern in patterns:
        match = re.search(pattern, text, re.DOTALL)
        if match:
            extracted = match.group(1).strip()

            # If the model put the language name on its own first line, drop it.
            if extracted.split('\n', 1)[0].strip().lower() in known_languages:
                return extracted.split('\n', 1)[1] if '\n' in extracted else ''

            # Trim any leading junk before the HTML document root.
            html_root_idx = None
            for tag in ['<!DOCTYPE html', '<html']:
                idx = extracted.find(tag)
                if idx != -1:
                    html_root_idx = idx if html_root_idx is None else min(html_root_idx, idx)
            if html_root_idx is not None and html_root_idx > 0:
                return extracted[html_root_idx:].strip()

            return extracted

    # No fenced block: the response may be a raw HTML document.
    stripped = text.strip()
    if stripped.startswith('<!DOCTYPE html>') or stripped.startswith('<html') or stripped.startswith('<'):
        for tag in ['<!DOCTYPE html', '<html']:
            idx = stripped.find(tag)
            if idx > 0:
                return stripped[idx:].strip()
        return stripped

    # Unterminated ```python fence: strip the markers heuristically.
    if text.strip().startswith('```python'):
        return text.strip()[9:-3].strip()

    # Bare language label on the first line, without any fence.
    lines = text.strip().split('\n', 1)
    if lines[0].strip().lower() in known_languages:
        return lines[1] if len(lines) > 1 else ''
    return text.strip()
|
|
|
|
|
|
|
|
|
|
|
def strip_thinking_tags(text: str) -> str:
    """Strip <think> tags and [TOOL_CALL] markers from streaming output."""
    if not text:
        return text

    # Delete each marker family in turn; matching is case-insensitive.
    for marker in (r'<think>', r'</think>', r'\[/?TOOL_CALL\]'):
        text = re.sub(marker, '', text, flags=re.IGNORECASE)
    return text
|
|
|
|
|
def strip_placeholder_thinking(text: str) -> str:
    """Remove placeholder 'Thinking...' status lines from streamed text."""
    if not text:
        return text

    # Whole lines of "Thinking..." (optionally with "(Ns elapsed)") are
    # removed together with their trailing newline.
    placeholder = r"(?mi)^[\t ]*Thinking\.\.\.(?:\s*\(\d+s elapsed\))?[\t ]*$\n?"
    return re.sub(placeholder, "", text)
|
|
|
|
|
def is_placeholder_thinking_only(text: str) -> bool:
    """Return True if text contains only 'Thinking...' placeholder lines (with optional elapsed)."""
    trimmed = text.strip() if text else ""
    if not trimmed:
        return False
    only_placeholders = r"(?s)(?:\s*Thinking\.\.\.(?:\s*\(\d+s elapsed\))?\s*)+"
    return bool(re.fullmatch(only_placeholders, trimmed))
|
|
|
|
|
def extract_last_thinking_line(text: str) -> str:
    """Extract the last 'Thinking...' line to display as status."""
    last = None
    # Scan all placeholder occurrences, remembering only the final one.
    for found in re.finditer(r"Thinking\.\.\.(?:\s*\(\d+s elapsed\))?", text):
        last = found.group(0)
    return last if last is not None else "Thinking..."
|
|
|
|
|
|