from typing import Dict, Any, List, Optional, AsyncGenerator
import os
import json
from .provider import BaseProvider, ModelInfo, Message, StreamChunk, ToolCall


class OpenAIProvider(BaseProvider):
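    """Chat-completions provider backed by the OpenAI API.

    Streams text deltas as they arrive, accumulates fragmented tool-call
    deltas, and normalizes OpenAI finish reasons to the provider-agnostic
    stop reasons used by BaseProvider implementations.

    Minimal usage sketch (assumes OPENAI_API_KEY is set in the environment):

        provider = OpenAIProvider()
        async for chunk in provider.stream("gpt-4o-mini", messages):
            if chunk.type == "text":
                print(chunk.text, end="")
    """
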
def __init__(self, api_key: Optional[str] = None):
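        """Store the API key (falling back to the OPENAI_API_KEY environment
        variable) and defer SDK client creation until first use."""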
self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self._client = None

@property
def id(self) -> str:
return "openai"
@property
def name(self) -> str:
return "OpenAI"
@property
def models(self) -> Dict[str, ModelInfo]:
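        """Catalog of supported models.

        cost_input/cost_output are taken to be USD per million tokens,
        matching OpenAI's published pricing for these models.
        """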
return {
"gpt-4o": ModelInfo(
id="gpt-4o",
name="GPT-4o",
provider_id="openai",
context_limit=128000,
output_limit=16384,
supports_tools=True,
supports_streaming=True,
cost_input=2.5,
cost_output=10.0,
),
"gpt-4o-mini": ModelInfo(
id="gpt-4o-mini",
name="GPT-4o Mini",
provider_id="openai",
context_limit=128000,
output_limit=16384,
supports_tools=True,
supports_streaming=True,
cost_input=0.15,
cost_output=0.6,
),
"o1": ModelInfo(
id="o1",
name="o1",
provider_id="openai",
context_limit=200000,
output_limit=100000,
supports_tools=True,
supports_streaming=True,
cost_input=15.0,
cost_output=60.0,
),
        }

def _get_client(self):
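        """Lazily construct an AsyncOpenAI client so that importing this
        module does not require the openai package to be installed."""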
if self._client is None:
try:
from openai import AsyncOpenAI
self._client = AsyncOpenAI(api_key=self._api_key)
            except ImportError as e:
                raise ImportError(
                    "openai package is required. Install with: pip install openai"
                ) from e
        return self._client

async def stream(
self,
model_id: str,
messages: List[Message],
tools: Optional[List[Dict[str, Any]]] = None,
system: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> AsyncGenerator[StreamChunk, None]:
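        """Stream a chat completion as StreamChunk events.

        Emits "text" chunks for content deltas as they arrive, then one
        "tool_call" chunk per completed tool call, and finally a single
        "done" chunk carrying token usage and a normalized stop reason.
        """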
client = self._get_client()
openai_messages = []
if system:
openai_messages.append({"role": "system", "content": system})
        for msg in messages:
            content = msg.content
            if isinstance(content, str):
                openai_messages.append({"role": msg.role, "content": content})
            else:
                # Structured content: keep only the parts that carry text and
                # map them to OpenAI's content-part shape.
                openai_messages.append({
                    "role": msg.role,
                    "content": [{"type": c.type, "text": c.text} for c in content if c.text],
                })
        kwargs: Dict[str, Any] = {
            "model": model_id,
            "messages": openai_messages,
            "stream": True,
            # Without stream_options.include_usage the API attaches no token
            # counts to a streamed response, so usage would always be None.
            "stream_options": {"include_usage": True},
        }
        if max_tokens:
            # Reasoning models such as o1 reject "max_tokens" and expect
            # "max_completion_tokens" instead.
            key = "max_completion_tokens" if model_id.startswith("o1") else "max_tokens"
            kwargs[key] = max_tokens
if temperature is not None:
kwargs["temperature"] = temperature
        if tools:
            # Accept either OpenAI-style "parameters" or Anthropic-style
            # "input_schema" as the tool's JSON schema.
            kwargs["tools"] = [
                {
                    "type": "function",
                    "function": {
                        "name": t["name"],
                        "description": t.get("description", ""),
                        "parameters": t.get("parameters", t.get("input_schema", {})),
                    },
                }
                for t in tools
            ]
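        # Tool-call deltas arrive fragmented across chunks (id and name
        # first, then argument text piece by piece), keyed by choice index;
        # accumulate them here and emit complete calls once the stream ends.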
tool_calls: Dict[int, Dict[str, Any]] = {}
usage_data = None
finish_reason = None
async for chunk in await client.chat.completions.create(**kwargs):
if chunk.choices and chunk.choices[0].delta:
delta = chunk.choices[0].delta
if delta.content:
yield StreamChunk(type="text", text=delta.content)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls:
tool_calls[idx] = {
"id": tc.id or "",
"name": tc.function.name if tc.function else "",
"arguments": ""
}
if tc.id:
tool_calls[idx]["id"] = tc.id
if tc.function:
if tc.function.name:
tool_calls[idx]["name"] = tc.function.name
if tc.function.arguments:
tool_calls[idx]["arguments"] += tc.function.arguments
if chunk.choices and chunk.choices[0].finish_reason:
finish_reason = chunk.choices[0].finish_reason
if chunk.usage:
usage_data = {
"input_tokens": chunk.usage.prompt_tokens,
"output_tokens": chunk.usage.completion_tokens,
}
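        # Stream exhausted: flush each accumulated tool call, tolerating
        # malformed argument JSON by falling back to empty arguments.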
for tc_data in tool_calls.values():
try:
args = json.loads(tc_data["arguments"]) if tc_data["arguments"] else {}
except json.JSONDecodeError:
args = {}
yield StreamChunk(
type="tool_call",
tool_call=ToolCall(
id=tc_data["id"],
name=tc_data["name"],
arguments=args
)
)
stop_reason = self._map_stop_reason(finish_reason)
yield StreamChunk(type="done", usage=usage_data, stop_reason=stop_reason)
def _map_stop_reason(self, openai_finish_reason: Optional[str]) -> str:
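        """Translate an OpenAI finish_reason into the package's normalized
        stop reasons; unknown values and content filtering both fall back to
        "end_turn"."""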
mapping = {
"stop": "end_turn",
"tool_calls": "tool_calls",
"length": "max_tokens",
"content_filter": "end_turn",
}
return mapping.get(openai_finish_reason or "", "end_turn")