"""
Multi-Model Router - intelligent model selection for optimal performance.

Integrates Claude, Gemini, and GPT-4 with automatic task-based routing.
"""

import asyncio
import os
from enum import Enum
from typing import Any, Dict, Optional

from dotenv import load_dotenv
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
import google.generativeai as genai
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI

load_dotenv()


class ModelType(Enum):
    """Available AI models."""
    CLAUDE_SONNET = "claude-sonnet-4-20250514"
    GEMINI_2_FLASH = "gemini-2.0-flash-exp"
    GPT4O_MINI = "gpt-4o-mini"


class TaskType(Enum):
    """Task types for intelligent routing."""
    REASONING = "reasoning"
    CODE_GEN = "code_generation"
    MULTIMODAL = "multimodal"
    PLANNING = "planning"
    FAST_QUERY = "fast_query"
    VISION = "vision"
    AUDIO = "audio"


class MultiModelRouter:
    """
    Intelligent multi-model router that selects the best AI model for each task.

    Prize Integration:
    - Google Gemini: $10K prize for multimodal capabilities
    - Anthropic Claude: core reasoning engine
    - OpenAI GPT-4: planning and routing
    """

    def __init__(self):
        self.anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        self.openai_key = os.getenv("OPENAI_API_KEY")
        self.google_key = os.getenv("GOOGLE_API_KEY")

        # Native SDK clients, created only when the matching key is present
        self.anthropic_client = AsyncAnthropic(api_key=self.anthropic_key) if self.anthropic_key else None
        self.openai_client = AsyncOpenAI(api_key=self.openai_key) if self.openai_key else None

        if self.google_key:
            genai.configure(api_key=self.google_key)

        # LangChain wrappers for agent integration
        self.claude_lc = ChatAnthropic(
            model=ModelType.CLAUDE_SONNET.value,
            api_key=self.anthropic_key,
            temperature=0.7,
        ) if self.anthropic_key else None

        self.gpt_lc = ChatOpenAI(
            model=ModelType.GPT4O_MINI.value,
            api_key=self.openai_key,
            temperature=0.7,
        ) if self.openai_key else None

        self.gemini_lc = ChatGoogleGenerativeAI(
            model=ModelType.GEMINI_2_FLASH.value,
            google_api_key=self.google_key,
            temperature=0.7,
        ) if self.google_key else None

        # Default routing table: task type -> preferred model
        self.routing_rules = {
            TaskType.REASONING: ModelType.CLAUDE_SONNET,
            TaskType.CODE_GEN: ModelType.CLAUDE_SONNET,
            TaskType.MULTIMODAL: ModelType.GEMINI_2_FLASH,
            TaskType.PLANNING: ModelType.GPT4O_MINI,
            TaskType.FAST_QUERY: ModelType.GEMINI_2_FLASH,
            TaskType.VISION: ModelType.GEMINI_2_FLASH,
            TaskType.AUDIO: ModelType.GEMINI_2_FLASH,
        }

        # Pricing in USD per million tokens (the experimental Gemini 2.0 Flash tier is free)
        self.model_costs = {
            ModelType.CLAUDE_SONNET: {"input": 3.0, "output": 15.0},
            ModelType.GEMINI_2_FLASH: {"input": 0.0, "output": 0.0},
            ModelType.GPT4O_MINI: {"input": 0.15, "output": 0.60},
        }

        self.usage_stats = {
            "claude": {"requests": 0, "tokens": 0, "cost": 0.0},
            "gemini": {"requests": 0, "tokens": 0, "cost": 0.0},
            "gpt4": {"requests": 0, "tokens": 0, "cost": 0.0},
        }

    def select_model(self, task_type: TaskType, prefer_cost_efficient: bool = False) -> ModelType:
        """
        Select the best model for a task.

        Args:
            task_type: Type of task to perform
            prefer_cost_efficient: Prefer cheaper models when possible

        Returns:
            Selected model type
        """
        base_model = self.routing_rules.get(task_type, ModelType.CLAUDE_SONNET)

        # Claude is the most expensive option; when the caller explicitly
        # prefers cost efficiency, downgrade Claude-routed tasks to GPT-4o-mini.
        # The remaining rules already point at the cheapest suitable model.
        if prefer_cost_efficient and base_model == ModelType.CLAUDE_SONNET:
            return ModelType.GPT4O_MINI

        return base_model
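
    # Illustrative selections under the default rules above:
    #   select_model(TaskType.VISION)                                -> GEMINI_2_FLASH
    #   select_model(TaskType.REASONING)                             -> CLAUDE_SONNET
    #   select_model(TaskType.REASONING, prefer_cost_efficient=True) -> GPT4O_MINI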

    async def generate(
        self,
        prompt: str,
        task_type: TaskType = TaskType.REASONING,
        system_prompt: Optional[str] = None,
        max_tokens: int = 4000,
        temperature: float = 0.7,
        image_url: Optional[str] = None,
        audio_data: Optional[bytes] = None,
        stream: bool = False,
    ) -> Dict[str, Any]:
        """
        Generate a response using the best model for the task.

        Args:
            prompt: User prompt
            task_type: Type of task
            system_prompt: System instructions
            max_tokens: Maximum response length
            temperature: Sampling temperature (0-1)
            image_url: URL for image analysis (Gemini multimodal)
            audio_data: Audio bytes for analysis (Gemini)
            stream: Stream response tokens

        Returns:
            Dict with response, model used, tokens, and cost
        """
        model = self.select_model(task_type)

        # Only Gemini handles multimodal input here, so media overrides routing.
        if image_url or audio_data:
            model = ModelType.GEMINI_2_FLASH

        try:
            if model == ModelType.CLAUDE_SONNET:
                return await self._generate_claude(prompt, system_prompt, max_tokens, temperature, stream)
            elif model == ModelType.GEMINI_2_FLASH:
                return await self._generate_gemini(prompt, system_prompt, max_tokens, temperature, image_url, audio_data)
            else:
                return await self._generate_gpt(prompt, system_prompt, max_tokens, temperature, stream)
        except Exception:
            # Fall back to Claude if another provider fails.
            if model != ModelType.CLAUDE_SONNET:
                return await self._generate_claude(prompt, system_prompt, max_tokens, temperature, stream)
            raise

    async def _generate_claude(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: int,
        temperature: float,
        stream: bool,
    ) -> Dict[str, Any]:
        """Generate using Claude Sonnet."""
        if not self.anthropic_client:
            raise ValueError("Anthropic API key not configured")

        messages = [{"role": "user", "content": prompt}]

        response = await self.anthropic_client.messages.create(
            model=ModelType.CLAUDE_SONNET.value,
            max_tokens=max_tokens,
            temperature=temperature,
            system=system_prompt or "You are a helpful AI assistant.",
            messages=messages,
            stream=stream,
        )

        if stream:
            # Caller is responsible for iterating the event stream.
            return {"response": response, "model": "claude", "streaming": True}

        content = response.content[0].text
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens
        cost = self._calculate_cost(ModelType.CLAUDE_SONNET, input_tokens, output_tokens)

        self.usage_stats["claude"]["requests"] += 1
        self.usage_stats["claude"]["tokens"] += input_tokens + output_tokens
        self.usage_stats["claude"]["cost"] += cost

        return {
            "response": content,
            "model": "claude-sonnet-4",
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cost": cost,
            "streaming": False,
        }

    async def _generate_gemini(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: int,
        temperature: float,
        image_url: Optional[str] = None,
        audio_data: Optional[bytes] = None,
    ) -> Dict[str, Any]:
        """Generate using Gemini 2.0 Flash (multimodal support)."""
        if not self.google_key:
            raise ValueError("Google API key not configured")

        model = genai.GenerativeModel(
            ModelType.GEMINI_2_FLASH.value,
            system_instruction=system_prompt,
        )

        # Assemble multimodal content parts: media first, then the text prompt.
        content_parts = []
        if image_url:
            # Fetch the image bytes; JPEG is assumed here for simplicity.
            import httpx
            async with httpx.AsyncClient() as client:
                img_response = await client.get(image_url)
                img_data = img_response.content
            content_parts.append({"mime_type": "image/jpeg", "data": img_data})

        if audio_data:
            content_parts.append({"mime_type": "audio/wav", "data": audio_data})

        content_parts.append(prompt)

        response = await model.generate_content_async(
            content_parts,
            generation_config=genai.GenerationConfig(
                max_output_tokens=max_tokens,
                temperature=temperature,
            ),
        )

        content = response.text

        # Token counts come from usage_metadata when the SDK provides it;
        # fall back to zero otherwise.
        usage = getattr(response, "usage_metadata", None)
        input_tokens = usage.prompt_token_count if usage else 0
        output_tokens = usage.candidates_token_count if usage else 0

        self.usage_stats["gemini"]["requests"] += 1
        self.usage_stats["gemini"]["tokens"] += input_tokens + output_tokens

        return {
            "response": content,
            "model": "gemini-2.0-flash",
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cost": 0.0,  # experimental Flash tier is priced at $0 in model_costs
            "streaming": False,
            "multimodal": bool(image_url or audio_data),
        }

    async def _generate_gpt(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: int,
        temperature: float,
        stream: bool,
    ) -> Dict[str, Any]:
        """Generate using GPT-4o-mini."""
        if not self.openai_client:
            raise ValueError("OpenAI API key not configured")

        messages = [
            {"role": "system", "content": system_prompt or "You are a helpful AI assistant."},
            {"role": "user", "content": prompt},
        ]

        response = await self.openai_client.chat.completions.create(
            model=ModelType.GPT4O_MINI.value,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=stream,
        )

        if stream:
            return {"response": response, "model": "gpt-4o-mini", "streaming": True}

        content = response.choices[0].message.content
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        cost = self._calculate_cost(ModelType.GPT4O_MINI, input_tokens, output_tokens)

        self.usage_stats["gpt4"]["requests"] += 1
        self.usage_stats["gpt4"]["tokens"] += input_tokens + output_tokens
        self.usage_stats["gpt4"]["cost"] += cost

        return {
            "response": content,
            "model": "gpt-4o-mini",
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cost": cost,
            "streaming": False,
        }

    def _calculate_cost(self, model: ModelType, input_tokens: int, output_tokens: int) -> float:
        """Calculate the USD cost of an API call from per-million-token pricing."""
        costs = self.model_costs[model]
        input_cost = (input_tokens / 1_000_000) * costs["input"]
        output_cost = (output_tokens / 1_000_000) * costs["output"]
        return input_cost + output_cost
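
    # Worked example with the pricing table above: a Claude call using
    # 1,000 input and 500 output tokens costs
    #   (1_000 / 1_000_000) * 3.0 + (500 / 1_000_000) * 15.0
    #   = 0.003 + 0.0075 = $0.0105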

    def get_usage_stats(self) -> Dict[str, Any]:
        """Get usage statistics across all models."""
        total_cost = sum(stats["cost"] for stats in self.usage_stats.values())
        total_requests = sum(stats["requests"] for stats in self.usage_stats.values())

        return {
            "total_requests": total_requests,
            "total_cost": round(total_cost, 4),
            "by_model": self.usage_stats,
            "cost_breakdown": {
                "claude": round(self.usage_stats["claude"]["cost"], 4),
                "gemini": round(self.usage_stats["gemini"]["cost"], 4),
                "gpt4": round(self.usage_stats["gpt4"]["cost"], 4),
            },
        }

    def get_langchain_model(self, task_type: TaskType):
        """Get a LangChain-compatible chat model for agent integration.

        May return None when the selected provider's API key is not configured.
        """
        model = self.select_model(task_type)

        if model == ModelType.CLAUDE_SONNET:
            return self.claude_lc
        elif model == ModelType.GEMINI_2_FLASH:
            return self.gemini_lc
        elif model == ModelType.GPT4O_MINI:
            return self.gpt_lc

        return self.claude_lc
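

# Illustrative LangChain agent integration via the module-level `router`
# defined below (assumes the matching provider key is configured):
#   llm = router.get_langchain_model(TaskType.REASONING)
#   reply = llm.invoke("Explain model routing in one sentence.")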


# Global router instance shared across the application
router = MultiModelRouter()
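

# Minimal usage sketch (illustrative; requires at least one provider key in
# the environment or a .env file):
if __name__ == "__main__":
    async def _demo() -> None:
        result = await router.generate(
            "Summarize the benefits of multi-model routing in two sentences.",
            task_type=TaskType.FAST_QUERY,
        )
        print(f"[{result['model']}] {result['response']}")
        print(router.get_usage_stats())

    asyncio.run(_demo())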