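"""LiteLLM-backed provider that exposes Anthropic, OpenAI, Gemini, Groq, DeepSeek,
OpenRouter, and Z.ai models behind the shared BaseProvider interface."""
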
from typing import Dict, Any, List, Optional, AsyncGenerator
import json
import os

from .provider import BaseProvider, ModelInfo, Message, StreamChunk, ToolCall

# Default model registry (costs are USD per million tokens).
DEFAULT_MODELS = {
    "claude-sonnet-4-20250514": ModelInfo(
        id="claude-sonnet-4-20250514",
        name="Claude Sonnet 4",
        provider_id="litellm",
        context_limit=200000,
        output_limit=64000,
        supports_tools=True,
        supports_streaming=True,
        cost_input=3.0,
        cost_output=15.0,
    ),
    "claude-opus-4-20250514": ModelInfo(
        id="claude-opus-4-20250514",
        name="Claude Opus 4",
        provider_id="litellm",
        context_limit=200000,
        output_limit=32000,
        supports_tools=True,
        supports_streaming=True,
        cost_input=15.0,
        cost_output=75.0,
    ),
    "claude-3-5-haiku-20241022": ModelInfo(
        id="claude-3-5-haiku-20241022",
        name="Claude 3.5 Haiku",
        provider_id="litellm",
        context_limit=200000,
        output_limit=8192,
        supports_tools=True,
        supports_streaming=True,
        cost_input=0.8,
        cost_output=4.0,
    ),
    "gpt-4o": ModelInfo(
        id="gpt-4o",
        name="GPT-4o",
        provider_id="litellm",
        context_limit=128000,
        output_limit=16384,
        supports_tools=True,
        supports_streaming=True,
        cost_input=2.5,
        cost_output=10.0,
    ),
    "gpt-4o-mini": ModelInfo(
        id="gpt-4o-mini",
        name="GPT-4o Mini",
        provider_id="litellm",
        context_limit=128000,
        output_limit=16384,
        supports_tools=True,
        supports_streaming=True,
        cost_input=0.15,
        cost_output=0.6,
    ),
    "o1": ModelInfo(
        id="o1",
        name="O1",
        provider_id="litellm",
        context_limit=200000,
        output_limit=100000,
        supports_tools=True,
        supports_streaming=True,
        cost_input=15.0,
        cost_output=60.0,
    ),
    "gemini/gemini-2.0-flash": ModelInfo(
        id="gemini/gemini-2.0-flash",
        name="Gemini 2.0 Flash",
        provider_id="litellm",
        context_limit=1000000,
        output_limit=8192,
        supports_tools=True,
        supports_streaming=True,
        cost_input=0.075,
        cost_output=0.3,
    ),
    "gemini/gemini-2.5-pro-preview-05-06": ModelInfo(
        id="gemini/gemini-2.5-pro-preview-05-06",
        name="Gemini 2.5 Pro",
        provider_id="litellm",
        context_limit=1000000,
        output_limit=65536,
        supports_tools=True,
        supports_streaming=True,
        cost_input=1.25,
        cost_output=10.0,
    ),
    "groq/llama-3.3-70b-versatile": ModelInfo(
        id="groq/llama-3.3-70b-versatile",
        name="Llama 3.3 70B (Groq)",
        provider_id="litellm",
        context_limit=128000,
        output_limit=32768,
        supports_tools=True,
        supports_streaming=True,
        cost_input=0.59,
        cost_output=0.79,
    ),
    "deepseek/deepseek-chat": ModelInfo(
        id="deepseek/deepseek-chat",
        name="DeepSeek Chat",
        provider_id="litellm",
        context_limit=64000,
        output_limit=8192,
        supports_tools=True,
        supports_streaming=True,
        cost_input=0.14,
        cost_output=0.28,
    ),
    "openrouter/anthropic/claude-sonnet-4": ModelInfo(
        id="openrouter/anthropic/claude-sonnet-4",
        name="Claude Sonnet 4 (OpenRouter)",
        provider_id="litellm",
        context_limit=200000,
        output_limit=64000,
        supports_tools=True,
        supports_streaming=True,
        cost_input=3.0,
        cost_output=15.0,
    ),
    # Z.ai Free Flash Models
    "zai/glm-4.7-flash": ModelInfo(
        id="zai/glm-4.7-flash",
        name="GLM-4.7 Flash (Free)",
        provider_id="litellm",
        context_limit=128000,
        output_limit=8192,
        supports_tools=True,
        supports_streaming=True,
        cost_input=0.0,
        cost_output=0.0,
    ),
    "zai/glm-4.6v-flash": ModelInfo(
        id="zai/glm-4.6v-flash",
        name="GLM-4.6V Flash (Free)",
        provider_id="litellm",
        context_limit=128000,
        output_limit=8192,
        supports_tools=True,
        supports_streaming=True,
        cost_input=0.0,
        cost_output=0.0,
    ),
    "zai/glm-4.5-flash": ModelInfo(
        id="zai/glm-4.5-flash",
        name="GLM-4.5 Flash (Free)",
        provider_id="litellm",
        context_limit=128000,
        output_limit=8192,
        supports_tools=True,
        supports_streaming=True,
        cost_input=0.0,
        cost_output=0.0,
    ),
}

class LiteLLMProvider(BaseProvider):
    """Provider implementation that delegates all model calls to litellm."""

    def __init__(self):
        self._litellm = None
        self._models = dict(DEFAULT_MODELS)

    def id(self) -> str:
        return "litellm"

    def name(self) -> str:
        return "LiteLLM (Multi-Provider)"

    def models(self) -> Dict[str, ModelInfo]:
        return self._models

    def add_model(self, model: ModelInfo) -> None:
        self._models[model.id] = model

    def _get_litellm(self):
        # Import lazily so the dependency is only required when the provider is used.
        if self._litellm is None:
            try:
                import litellm

                litellm.drop_params = True  # silently drop params a backend does not support
                self._litellm = litellm
            except ImportError:
                raise ImportError("litellm package is required. Install with: pip install litellm")
        return self._litellm
    async def stream(
        self,
        model_id: str,
        messages: List[Message],
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> AsyncGenerator[StreamChunk, None]:
        litellm = self._get_litellm()

        # Convert provider-neutral messages into the OpenAI-style format litellm expects.
        litellm_messages = []
        if system:
            litellm_messages.append({"role": "system", "content": system})
        for msg in messages:
            content = msg.content
            if isinstance(content, str):
                litellm_messages.append({"role": msg.role, "content": content})
            else:
                litellm_messages.append({
                    "role": msg.role,
                    "content": [{"type": c.type, "text": c.text} for c in content if c.text]
                })

        # Z.ai model handling: served through an OpenAI-compatible API.
        actual_model = model_id
        if model_id.startswith("zai/"):
            # zai/glm-4.7-flash -> openai/glm-4.7-flash with a custom api_base
            actual_model = "openai/" + model_id[4:]

        kwargs: Dict[str, Any] = {
            "model": actual_model,
            "messages": litellm_messages,
            "stream": True,
        }

        # Z.ai-specific settings
        if model_id.startswith("zai/"):
            kwargs["api_base"] = os.environ.get("ZAI_API_BASE", "https://api.z.ai/api/paas/v4")
            kwargs["api_key"] = os.environ.get("ZAI_API_KEY")

        if temperature is not None:
            kwargs["temperature"] = temperature
        if max_tokens is not None:
            kwargs["max_tokens"] = max_tokens
        else:
            kwargs["max_tokens"] = 8192

        if tools:
            # Map generic tool specs onto the OpenAI function-calling schema.
            kwargs["tools"] = [
                {
                    "type": "function",
                    "function": {
                        "name": t["name"],
                        "description": t.get("description", ""),
                        "parameters": t.get("parameters", t.get("input_schema", {}))
                    }
                }
                for t in tools
            ]
        current_tool_calls: Dict[int, Dict[str, Any]] = {}
        finish_reason: Optional[str] = None  # stays defined even if the stream yields no choices

        try:
            response = await litellm.acompletion(**kwargs)
            async for chunk in response:
                if hasattr(chunk, 'choices') and chunk.choices:
                    choice = chunk.choices[0]
                    delta = getattr(choice, 'delta', None)
                    if delta:
                        if hasattr(delta, 'content') and delta.content:
                            yield StreamChunk(type="text", text=delta.content)
                        if hasattr(delta, 'tool_calls') and delta.tool_calls:
                            # Tool-call arguments arrive as JSON fragments; accumulate them per index.
                            for tc in delta.tool_calls:
                                idx = tc.index if hasattr(tc, 'index') else 0
                                if idx not in current_tool_calls:
                                    current_tool_calls[idx] = {
                                        "id": tc.id if hasattr(tc, 'id') and tc.id else f"call_{idx}",
                                        "name": "",
                                        "arguments_json": ""
                                    }
                                if hasattr(tc, 'function'):
                                    if hasattr(tc.function, 'name') and tc.function.name:
                                        current_tool_calls[idx]["name"] = tc.function.name
                                    if hasattr(tc.function, 'arguments') and tc.function.arguments:
                                        current_tool_calls[idx]["arguments_json"] += tc.function.arguments
                    finish_reason = getattr(choice, 'finish_reason', None)
                    if finish_reason:
                        # The model has finished; emit any tool calls accumulated so far.
                        for idx, tc_data in current_tool_calls.items():
                            if tc_data["name"]:
                                try:
                                    args = json.loads(tc_data["arguments_json"]) if tc_data["arguments_json"] else {}
                                except json.JSONDecodeError:
                                    args = {}
                                yield StreamChunk(
                                    type="tool_call",
                                    tool_call=ToolCall(
                                        id=tc_data["id"],
                                        name=tc_data["name"],
                                        arguments=args
                                    )
                                )

            # Token usage, when reported, arrives on the final chunk of the stream.
            usage = None
            if hasattr(chunk, 'usage') and chunk.usage:
                usage = {
                    "input_tokens": getattr(chunk.usage, 'prompt_tokens', 0),
                    "output_tokens": getattr(chunk.usage, 'completion_tokens', 0),
                }

            stop_reason = self._map_stop_reason(finish_reason)
            yield StreamChunk(type="done", usage=usage, stop_reason=stop_reason)
        except Exception as e:
            yield StreamChunk(type="error", error=str(e))
    async def complete(
        self,
        model_id: str,
        prompt: str,
        max_tokens: int = 100,
    ) -> str:
        """Single completion request (no streaming)."""
        litellm = self._get_litellm()
        actual_model = model_id
        kwargs: Dict[str, Any] = {
            "model": actual_model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
        }
        # Z.ai model handling (OpenAI-compatible endpoint)
        if model_id.startswith("zai/"):
            actual_model = "openai/" + model_id[4:]
            kwargs["model"] = actual_model
            kwargs["api_base"] = os.environ.get("ZAI_API_BASE", "https://api.z.ai/api/paas/v4")
            kwargs["api_key"] = os.environ.get("ZAI_API_KEY")
        response = await litellm.acompletion(**kwargs)
        return response.choices[0].message.content or ""
    def _map_stop_reason(self, finish_reason: Optional[str]) -> str:
        if not finish_reason:
            return "end_turn"
        mapping = {
            "stop": "end_turn",
            "end_turn": "end_turn",
            "tool_calls": "tool_calls",
            "function_call": "tool_calls",
            "length": "max_tokens",
            "max_tokens": "max_tokens",
            "content_filter": "content_filter",
        }
        return mapping.get(finish_reason, "end_turn")
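
# --- Usage sketch (illustrative only, not part of the provider API) ---
# A minimal example of driving this provider. It assumes that the Message type
# imported from .provider can be constructed with role/content keyword arguments
# and that the chosen backend's API key (e.g. OPENAI_API_KEY) is set in the
# environment; adjust the model id and message construction to the real types.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        provider = LiteLLMProvider()
        msgs = [Message(role="user", content="Say hello in one short sentence.")]
        async for chunk in provider.stream("gpt-4o-mini", msgs):
            if chunk.type == "text":
                print(chunk.text, end="", flush=True)
            elif chunk.type == "error":
                print(f"\n[error] {chunk.error}")
            elif chunk.type == "done":
                print(f"\n[done] stop_reason={chunk.stop_reason}")

    asyncio.run(_demo())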