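"""LiteLLM-backed provider.

Exposes a single provider ("litellm") that routes chat completions to multiple
backends (Anthropic, OpenAI, Google, Groq, DeepSeek, OpenRouter, Z.ai) through
the litellm library, with a curated catalog of default models.
"""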
from typing import Dict, Any, List, Optional, AsyncGenerator
import json
import os
from .provider import BaseProvider, ModelInfo, Message, StreamChunk, ToolCall
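# Default model catalog; cost_input/cost_output are USD per million tokens,
# matching the providers' published pricing.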
DEFAULT_MODELS = {
"claude-sonnet-4-20250514": ModelInfo(
id="claude-sonnet-4-20250514",
name="Claude Sonnet 4",
provider_id="litellm",
context_limit=200000,
output_limit=64000,
supports_tools=True,
supports_streaming=True,
cost_input=3.0,
cost_output=15.0,
),
"claude-opus-4-20250514": ModelInfo(
id="claude-opus-4-20250514",
name="Claude Opus 4",
provider_id="litellm",
context_limit=200000,
output_limit=32000,
supports_tools=True,
supports_streaming=True,
cost_input=15.0,
cost_output=75.0,
),
"claude-3-5-haiku-20241022": ModelInfo(
id="claude-3-5-haiku-20241022",
name="Claude 3.5 Haiku",
provider_id="litellm",
context_limit=200000,
output_limit=8192,
supports_tools=True,
supports_streaming=True,
cost_input=0.8,
cost_output=4.0,
),
"gpt-4o": ModelInfo(
id="gpt-4o",
name="GPT-4o",
provider_id="litellm",
context_limit=128000,
output_limit=16384,
supports_tools=True,
supports_streaming=True,
cost_input=2.5,
cost_output=10.0,
),
"gpt-4o-mini": ModelInfo(
id="gpt-4o-mini",
name="GPT-4o Mini",
provider_id="litellm",
context_limit=128000,
output_limit=16384,
supports_tools=True,
supports_streaming=True,
cost_input=0.15,
cost_output=0.6,
),
"o1": ModelInfo(
id="o1",
name="O1",
provider_id="litellm",
context_limit=200000,
output_limit=100000,
supports_tools=True,
supports_streaming=True,
cost_input=15.0,
cost_output=60.0,
),
"gemini/gemini-2.0-flash": ModelInfo(
id="gemini/gemini-2.0-flash",
name="Gemini 2.0 Flash",
provider_id="litellm",
context_limit=1000000,
output_limit=8192,
supports_tools=True,
supports_streaming=True,
cost_input=0.075,
cost_output=0.3,
),
"gemini/gemini-2.5-pro-preview-05-06": ModelInfo(
id="gemini/gemini-2.5-pro-preview-05-06",
name="Gemini 2.5 Pro",
provider_id="litellm",
context_limit=1000000,
output_limit=65536,
supports_tools=True,
supports_streaming=True,
cost_input=1.25,
cost_output=10.0,
),
"groq/llama-3.3-70b-versatile": ModelInfo(
id="groq/llama-3.3-70b-versatile",
name="Llama 3.3 70B (Groq)",
provider_id="litellm",
context_limit=128000,
output_limit=32768,
supports_tools=True,
supports_streaming=True,
cost_input=0.59,
cost_output=0.79,
),
"deepseek/deepseek-chat": ModelInfo(
id="deepseek/deepseek-chat",
name="DeepSeek Chat",
provider_id="litellm",
context_limit=64000,
output_limit=8192,
supports_tools=True,
supports_streaming=True,
cost_input=0.14,
cost_output=0.28,
),
"openrouter/anthropic/claude-sonnet-4": ModelInfo(
id="openrouter/anthropic/claude-sonnet-4",
name="Claude Sonnet 4 (OpenRouter)",
provider_id="litellm",
context_limit=200000,
output_limit=64000,
supports_tools=True,
supports_streaming=True,
cost_input=3.0,
cost_output=15.0,
),
# Z.ai Free Flash Models
"zai/glm-4.7-flash": ModelInfo(
id="zai/glm-4.7-flash",
name="GLM-4.7 Flash (Free)",
provider_id="litellm",
context_limit=128000,
output_limit=8192,
supports_tools=True,
supports_streaming=True,
cost_input=0.0,
cost_output=0.0,
),
"zai/glm-4.6v-flash": ModelInfo(
id="zai/glm-4.6v-flash",
name="GLM-4.6V Flash (Free)",
provider_id="litellm",
context_limit=128000,
output_limit=8192,
supports_tools=True,
supports_streaming=True,
cost_input=0.0,
cost_output=0.0,
),
"zai/glm-4.5-flash": ModelInfo(
id="zai/glm-4.5-flash",
name="GLM-4.5 Flash (Free)",
provider_id="litellm",
context_limit=128000,
output_limit=8192,
supports_tools=True,
supports_streaming=True,
cost_input=0.0,
cost_output=0.0,
),
}
class LiteLLMProvider(BaseProvider):
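    """Provider that delegates chat completions to the litellm library.

    Ships with the DEFAULT_MODELS catalog above; additional models can be
    registered at runtime via add_model(). Z.ai models (the "zai/" prefix) are
    rewritten onto litellm's OpenAI-compatible path with a custom api_base.
    """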
def __init__(self):
self._litellm = None
self._models = dict(DEFAULT_MODELS)
@property
def id(self) -> str:
return "litellm"
@property
def name(self) -> str:
return "LiteLLM (Multi-Provider)"
@property
def models(self) -> Dict[str, ModelInfo]:
return self._models
def add_model(self, model: ModelInfo) -> None:
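        """Register a model (or override an existing entry) at runtime."""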
self._models[model.id] = model
def _get_litellm(self):
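        """Lazily import litellm, enabling drop_params so unsupported kwargs are dropped."""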
if self._litellm is None:
try:
import litellm
litellm.drop_params = True
self._litellm = litellm
except ImportError:
raise ImportError("litellm package is required. Install with: pip install litellm")
return self._litellm
async def stream(
self,
model_id: str,
messages: List[Message],
tools: Optional[List[Dict[str, Any]]] = None,
system: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> AsyncGenerator[StreamChunk, None]:
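        """Stream a completion as StreamChunk events.

        Yields "text" chunks for content deltas, "tool_call" chunks once a tool
        call's arguments are fully accumulated, a final "done" chunk with usage
        and a normalized stop reason, and an "error" chunk on failure.
        """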
litellm = self._get_litellm()
litellm_messages = []
if system:
litellm_messages.append({"role": "system", "content": system})
for msg in messages:
content = msg.content
if isinstance(content, str):
litellm_messages.append({"role": msg.role, "content": content})
else:
litellm_messages.append({
"role": msg.role,
"content": [{"type": c.type, "text": c.text} for c in content if c.text]
})
        # Z.ai models: route through litellm's OpenAI-compatible path with a custom api_base.
actual_model = model_id
if model_id.startswith("zai/"):
# zai/glm-4.7-flash -> openai/glm-4.7-flash with custom api_base
actual_model = "openai/" + model_id[4:]
kwargs: Dict[str, Any] = {
"model": actual_model,
"messages": litellm_messages,
"stream": True,
}
        # Z.ai-specific settings: point at the Z.ai endpoint and authenticate with its API key.
if model_id.startswith("zai/"):
kwargs["api_base"] = os.environ.get("ZAI_API_BASE", "https://api.z.ai/api/paas/v4")
kwargs["api_key"] = os.environ.get("ZAI_API_KEY")
if temperature is not None:
kwargs["temperature"] = temperature
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
else:
kwargs["max_tokens"] = 8192
if tools:
kwargs["tools"] = [
{
"type": "function",
"function": {
"name": t["name"],
"description": t.get("description", ""),
"parameters": t.get("parameters", t.get("input_schema", {}))
}
}
for t in tools
]
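        # Tool-call arguments stream in as partial JSON fragments keyed by index;
        # accumulate them here and emit complete ToolCall chunks at finish time.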
current_tool_calls: Dict[int, Dict[str, Any]] = {}
try:
response = await litellm.acompletion(**kwargs)
async for chunk in response:
if hasattr(chunk, 'choices') and chunk.choices:
choice = chunk.choices[0]
delta = getattr(choice, 'delta', None)
if delta:
if hasattr(delta, 'content') and delta.content:
yield StreamChunk(type="text", text=delta.content)
if hasattr(delta, 'tool_calls') and delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index if hasattr(tc, 'index') else 0
if idx not in current_tool_calls:
current_tool_calls[idx] = {
"id": tc.id if hasattr(tc, 'id') and tc.id else f"call_{idx}",
"name": "",
"arguments_json": ""
}
if hasattr(tc, 'function'):
if hasattr(tc.function, 'name') and tc.function.name:
current_tool_calls[idx]["name"] = tc.function.name
if hasattr(tc.function, 'arguments') and tc.function.arguments:
current_tool_calls[idx]["arguments_json"] += tc.function.arguments
finish_reason = getattr(choice, 'finish_reason', None)
if finish_reason:
for idx, tc_data in current_tool_calls.items():
if tc_data["name"]:
try:
args = json.loads(tc_data["arguments_json"]) if tc_data["arguments_json"] else {}
except json.JSONDecodeError:
args = {}
yield StreamChunk(
type="tool_call",
tool_call=ToolCall(
id=tc_data["id"],
name=tc_data["name"],
arguments=args
)
)
usage = None
if hasattr(chunk, 'usage') and chunk.usage:
usage = {
"input_tokens": getattr(chunk.usage, 'prompt_tokens', 0),
"output_tokens": getattr(chunk.usage, 'completion_tokens', 0),
}
stop_reason = self._map_stop_reason(finish_reason)
yield StreamChunk(type="done", usage=usage, stop_reason=stop_reason)
except Exception as e:
yield StreamChunk(type="error", error=str(e))
async def complete(
self,
model_id: str,
prompt: str,
max_tokens: int = 100,
) -> str:
"""단일 완료 요청 (스트리밍 없음)"""
litellm = self._get_litellm()
actual_model = model_id
kwargs: Dict[str, Any] = {
"model": actual_model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": max_tokens,
}
        # Z.ai models: apply the same OpenAI-compatible rewrite as in stream().
if model_id.startswith("zai/"):
actual_model = "openai/" + model_id[4:]
kwargs["model"] = actual_model
kwargs["api_base"] = os.environ.get("ZAI_API_BASE", "https://api.z.ai/api/paas/v4")
kwargs["api_key"] = os.environ.get("ZAI_API_KEY")
response = await litellm.acompletion(**kwargs)
return response.choices[0].message.content or ""
def _map_stop_reason(self, finish_reason: Optional[str]) -> str:
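        """Normalize LiteLLM/OpenAI finish reasons to this provider's stop-reason names."""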
if not finish_reason:
return "end_turn"
mapping = {
"stop": "end_turn",
"end_turn": "end_turn",
"tool_calls": "tool_calls",
"function_call": "tool_calls",
"length": "max_tokens",
"max_tokens": "max_tokens",
"content_filter": "content_filter",
}
return mapping.get(finish_reason, "end_turn")
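# Usage sketch (illustrative only, not exercised by this module): stream a short
# reply through the default "gpt-4o-mini" entry. It assumes the relevant provider
# API key (e.g. OPENAI_API_KEY) is set in the environment and that Message can be
# constructed with role/content keyword arguments, as its fields are used above.
#
#     import asyncio
#     from .provider import Message
#
#     async def _demo() -> None:
#         provider = LiteLLMProvider()
#         async for chunk in provider.stream(
#             model_id="gpt-4o-mini",
#             messages=[Message(role="user", content="Say hello in one sentence.")],
#             system="You are a terse assistant.",
#             max_tokens=64,
#         ):
#             if chunk.type == "text":
#                 print(chunk.text, end="", flush=True)
#             elif chunk.type == "error":
#                 raise RuntimeError(chunk.error)
#
#     asyncio.run(_demo())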