# Provider abstractions and registry for streaming chat-model backends.
from typing import Dict, Any, List, Optional, AsyncIterator, AsyncGenerator, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from abc import ABC, abstractmethod
class ModelInfo(BaseModel):
    """Metadata describing a single model offered by a provider."""
    id: str  # model identifier, unique within its provider
    name: str  # human-readable display name
    provider_id: str  # id of the provider that serves this model
    context_limit: int = 128000  # input context window limit (tokens, presumably) — TODO confirm
    output_limit: int = 8192  # maximum generated output per response (tokens, presumably)
    supports_tools: bool = True  # whether the model accepts tool definitions
    supports_streaming: bool = True  # whether incremental streaming output is available
    cost_input: float = 0.0  # price per 1M input tokens
    cost_output: float = 0.0  # price per 1M output tokens
class ProviderInfo(BaseModel):
    """Serializable snapshot of a provider and its model catalog (see BaseProvider.get_info)."""
    id: str  # provider identifier
    name: str  # human-readable provider name
    models: Dict[str, ModelInfo] = Field(default_factory=dict)  # model id -> ModelInfo
class MessageContent(BaseModel):
    """One typed content part of a structured message."""
    type: str = "text"  # content kind; only "text" is modeled here
    text: Optional[str] = None  # text payload when type == "text"
class Message(BaseModel):
    """A single conversation turn sent to or produced by a model."""
    role: str # "user", "assistant", "system"
    content: str | List[MessageContent]  # plain text, or a list of structured content parts
class ToolCall(BaseModel):
    """A model-issued request to invoke a named tool."""
    id: str  # unique call id; ToolResult.tool_call_id refers back to it
    name: str  # name of the tool to invoke
    arguments: Dict[str, Any]  # parsed tool arguments
class ToolResult(BaseModel):
    """Output produced by executing a ToolCall."""
    tool_call_id: str  # id of the ToolCall this result answers
    output: str  # tool output serialized as text
class StreamChunk(BaseModel):
    """One incremental event yielded by a provider's stream(); fields are
    populated according to `type`, the rest stay None."""
    type: str # "text", "reasoning", "tool_call", "tool_result", "done", "error"
    text: Optional[str] = None  # text delta for "text"/"reasoning" chunks
    tool_call: Optional[ToolCall] = None  # populated for "tool_call" chunks
    error: Optional[str] = None  # error description for "error" chunks
    usage: Optional[Dict[str, int]] = None  # token-usage counters, when reported
    stop_reason: Optional[str] = None # "end_turn", "tool_calls", "max_tokens", etc.
@runtime_checkable
class Provider(Protocol):
    """Structural interface a chat-model provider must satisfy.

    @runtime_checkable allows isinstance() checks against this Protocol;
    note such checks verify attribute presence only, not signatures.
    """
    # Stable unique provider identifier.
    @property
    def id(self) -> str: ...
    # Human-readable provider name.
    @property
    def name(self) -> str: ...
    # Mapping of model id -> ModelInfo for the models this provider serves.
    @property
    def models(self) -> Dict[str, ModelInfo]: ...
    # Stream model output for a conversation as StreamChunk events.
    def stream(
        self,
        model_id: str,
        messages: List[Message],
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> AsyncGenerator[StreamChunk, None]: ...
class BaseProvider(ABC):
    """Abstract base class for chat-model providers.

    Concrete subclasses supply the identifying properties and the
    streaming entry point; ``get_info()`` comes for free.
    """

    @property
    @abstractmethod
    def id(self) -> str:
        """Stable unique identifier for this provider."""
        ...

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable provider name."""
        ...

    @property
    @abstractmethod
    def models(self) -> Dict[str, ModelInfo]:
        """Mapping of model id -> ModelInfo for every model offered."""
        ...

    @abstractmethod
    def stream(
        self,
        model_id: str,
        messages: List[Message],
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> AsyncGenerator[StreamChunk, None]:
        """Yield StreamChunk events for one model turn over *messages*."""
        ...

    def get_info(self) -> ProviderInfo:
        """Package this provider's id, name, and model catalog as a ProviderInfo."""
        snapshot = ProviderInfo(
            id=self.id,
            name=self.name,
            models=self.models,
        )
        return snapshot
# Module-level registry: provider id -> registered provider instance.
_providers: Dict[str, BaseProvider] = {}


def register_provider(provider: BaseProvider) -> None:
    """Register *provider* under its own id, replacing any prior entry."""
    _providers[provider.id] = provider


def get_provider(provider_id: str) -> Optional[BaseProvider]:
    """Look up a registered provider by id; None when unknown."""
    try:
        return _providers[provider_id]
    except KeyError:
        return None


def list_providers() -> List[ProviderInfo]:
    """Return a ProviderInfo snapshot for every registered provider."""
    infos: List[ProviderInfo] = []
    for registered in _providers.values():
        infos.append(registered.get_info())
    return infos


def get_model(provider_id: str, model_id: str) -> Optional[ModelInfo]:
    """Resolve a (provider id, model id) pair to a ModelInfo, or None if
    either the provider or the model is unknown."""
    found = get_provider(provider_id)
    return found.models.get(model_id) if found is not None else None