"""
Multi-Model Router - Intelligent model selection for optimal performance
Integrates Claude, Gemini, and GPT-4 with automatic routing
"""
import os
from typing import Dict, Any, List, Optional, Literal
from enum import Enum
import asyncio
from dotenv import load_dotenv
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
import google.generativeai as genai
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
# Load environment variables before initializing clients
load_dotenv()


class ModelType(Enum):
    """Available AI models"""
    CLAUDE_SONNET = "claude-sonnet-4-20250514"  # Best for reasoning, code generation
    GEMINI_2_FLASH = "gemini-2.0-flash-exp"     # Best for multimodal, speed
    GPT4O_MINI = "gpt-4o-mini"                  # Best for planning, routing decisions


class TaskType(Enum):
    """Task types for intelligent routing"""
    REASONING = "reasoning"        # Complex logic, analysis
    CODE_GEN = "code_generation"   # MCP server generation
    MULTIMODAL = "multimodal"      # Images, audio, video
    PLANNING = "planning"          # Task breakdown, routing
    FAST_QUERY = "fast_query"      # Quick responses
    VISION = "vision"              # Image analysis
    AUDIO = "audio"                # Audio processing


class MultiModelRouter:
    """
    Intelligent multi-model router that selects the best AI model for each task.

    Prize Integration:
    - Google Gemini: $10K prize for multimodal capabilities
    - Anthropic Claude: Core reasoning engine
    - OpenAI GPT-4: Planning and routing
    """

    def __init__(self):
        self.anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        self.openai_key = os.getenv("OPENAI_API_KEY")
        self.google_key = os.getenv("GOOGLE_API_KEY")

        # Initialize native async clients (only when the matching key is present)
        self.anthropic_client = AsyncAnthropic(api_key=self.anthropic_key) if self.anthropic_key else None
        self.openai_client = AsyncOpenAI(api_key=self.openai_key) if self.openai_key else None
        if self.google_key:
            genai.configure(api_key=self.google_key)

        # LangChain clients for agent integration
        self.claude_lc = ChatAnthropic(
            model=ModelType.CLAUDE_SONNET.value,
            api_key=self.anthropic_key,
            temperature=0.7
        ) if self.anthropic_key else None

        self.gpt_lc = ChatOpenAI(
            model=ModelType.GPT4O_MINI.value,
            api_key=self.openai_key,
            temperature=0.7
        ) if self.openai_key else None

        self.gemini_lc = ChatGoogleGenerativeAI(
            model=ModelType.GEMINI_2_FLASH.value,
            google_api_key=self.google_key,
            temperature=0.7
        ) if self.google_key else None

        # Routing rules: task type -> best model
        self.routing_rules = {
            TaskType.REASONING: ModelType.CLAUDE_SONNET,
            TaskType.CODE_GEN: ModelType.CLAUDE_SONNET,
            TaskType.MULTIMODAL: ModelType.GEMINI_2_FLASH,
            TaskType.PLANNING: ModelType.GPT4O_MINI,
            TaskType.FAST_QUERY: ModelType.GEMINI_2_FLASH,
            TaskType.VISION: ModelType.GEMINI_2_FLASH,
            TaskType.AUDIO: ModelType.GEMINI_2_FLASH,
        }

        # Cost tracking (USD per 1M tokens)
        self.model_costs = {
            ModelType.CLAUDE_SONNET: {"input": 3.0, "output": 15.0},
            ModelType.GEMINI_2_FLASH: {"input": 0.0, "output": 0.0},  # Free tier
            ModelType.GPT4O_MINI: {"input": 0.15, "output": 0.60},
        }

        self.usage_stats = {
            "claude": {"requests": 0, "tokens": 0, "cost": 0.0},
            "gemini": {"requests": 0, "tokens": 0, "cost": 0.0},
            "gpt4": {"requests": 0, "tokens": 0, "cost": 0.0},
        }

    def select_model(self, task_type: TaskType, prefer_cost_efficient: bool = False) -> ModelType:
        """
        Intelligently select the best model for a task.

        Args:
            task_type: Type of task to perform
            prefer_cost_efficient: Prefer cheaper models when possible

        Returns:
            Selected model type
        """
        base_model = self.routing_rules.get(task_type, ModelType.CLAUDE_SONNET)

        # In cost-efficient mode, prefer Gemini (free tier) or GPT-4o-mini
        if prefer_cost_efficient:
            if task_type in [TaskType.MULTIMODAL, TaskType.FAST_QUERY, TaskType.VISION]:
                return ModelType.GEMINI_2_FLASH
            elif task_type == TaskType.PLANNING:
                return ModelType.GPT4O_MINI

        return base_model
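
    # For example (illustrative, given the routing rules above):
    #   router.select_model(TaskType.CODE_GEN)  -> ModelType.CLAUDE_SONNET
    #   router.select_model(TaskType.VISION)    -> ModelType.GEMINI_2_FLASH
    #   router.select_model(TaskType.REASONING, prefer_cost_efficient=True)
    #       -> ModelType.CLAUDE_SONNET (reasoning has no cheaper override)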

    async def generate(
        self,
        prompt: str,
        task_type: TaskType = TaskType.REASONING,
        system_prompt: Optional[str] = None,
        max_tokens: int = 4000,
        temperature: float = 0.7,
        image_url: Optional[str] = None,
        audio_data: Optional[bytes] = None,
        stream: bool = False,
    ) -> Dict[str, Any]:
        """
        Generate a response using the best model for the task.

        Args:
            prompt: User prompt
            task_type: Type of task
            system_prompt: System instructions
            max_tokens: Maximum response length
            temperature: Creativity (0-1)
            image_url: URL for image analysis (Gemini multimodal)
            audio_data: Audio bytes for analysis (Gemini)
            stream: Stream response tokens (Claude and GPT paths only)

        Returns:
            Dict with response, model used, tokens, and cost
        """
        model = self.select_model(task_type)

        # Force Gemini for multimodal tasks
        if image_url or audio_data:
            model = ModelType.GEMINI_2_FLASH

        try:
            if model == ModelType.CLAUDE_SONNET:
                return await self._generate_claude(prompt, system_prompt, max_tokens, temperature, stream)
            elif model == ModelType.GEMINI_2_FLASH:
                # The Gemini path does not stream; `stream` is ignored here.
                return await self._generate_gemini(prompt, system_prompt, max_tokens, temperature, image_url, audio_data)
            else:  # ModelType.GPT4O_MINI
                return await self._generate_gpt(prompt, system_prompt, max_tokens, temperature, stream)
        except Exception:
            # Fall back to Claude if the primary model fails and Claude is configured
            if model != ModelType.CLAUDE_SONNET and self.anthropic_client:
                return await self._generate_claude(prompt, system_prompt, max_tokens, temperature, stream)
            raise
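
    # Example call (a minimal sketch; assumes ANTHROPIC_API_KEY is set so the
    # reasoning route is available):
    #   result = await router.generate(
    #       "Explain the CAP theorem in one paragraph.",
    #       task_type=TaskType.REASONING,
    #   )
    #   result["model"]  -> "claude-sonnet-4"
    #   result["cost"]   -> estimated USD cost of the call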

    async def _generate_claude(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: int,
        temperature: float,
        stream: bool
    ) -> Dict[str, Any]:
        """Generate using Claude Sonnet"""
        if not self.anthropic_client:
            raise ValueError("Anthropic API key not configured")

        messages = [{"role": "user", "content": prompt}]

        response = await self.anthropic_client.messages.create(
            model=ModelType.CLAUDE_SONNET.value,
            max_tokens=max_tokens,
            temperature=temperature,
            system=system_prompt or "You are a helpful AI assistant.",
            messages=messages,
            stream=stream
        )

        if stream:
            return {"response": response, "model": "claude", "streaming": True}

        content = response.content[0].text
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens
        cost = self._calculate_cost(ModelType.CLAUDE_SONNET, input_tokens, output_tokens)

        # Update stats
        self.usage_stats["claude"]["requests"] += 1
        self.usage_stats["claude"]["tokens"] += input_tokens + output_tokens
        self.usage_stats["claude"]["cost"] += cost

        return {
            "response": content,
            "model": "claude-sonnet-4",
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cost": cost,
            "streaming": False
        }

    async def _generate_gemini(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: int,
        temperature: float,
        image_url: Optional[str] = None,
        audio_data: Optional[bytes] = None
    ) -> Dict[str, Any]:
        """Generate using Gemini 2.0 Flash (multimodal support)"""
        if not self.google_key:
            raise ValueError("Google API key not configured")

        model = genai.GenerativeModel(
            ModelType.GEMINI_2_FLASH.value,
            system_instruction=system_prompt
        )

        # Build multimodal content
        content_parts = []

        if image_url:
            # httpx is imported lazily so it is only required when image
            # analysis is actually used
            import httpx
            async with httpx.AsyncClient() as client:
                img_response = await client.get(image_url)
                img_data = img_response.content
            # Use the served Content-Type when available instead of assuming JPEG
            mime_type = img_response.headers.get("content-type", "image/jpeg")
            content_parts.append({"mime_type": mime_type, "data": img_data})

        if audio_data:
            content_parts.append({"mime_type": "audio/wav", "data": audio_data})

        content_parts.append(prompt)

        response = await model.generate_content_async(
            content_parts,
            generation_config=genai.GenerationConfig(
                max_output_tokens=max_tokens,
                temperature=temperature
            )
        )
        content = response.text

        # Gemini free tier: request count only, no token/cost tracking
        self.usage_stats["gemini"]["requests"] += 1

        return {
            "response": content,
            "model": "gemini-2.0-flash",
            "input_tokens": 0,  # Token counts are not tracked on this path
            "output_tokens": 0,
            "total_tokens": 0,
            "cost": 0.0,
            "streaming": False,
            "multimodal": bool(image_url or audio_data)
        }
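
    # Multimodal sketch (assumes GOOGLE_API_KEY is set; the image URL below is
    # a hypothetical placeholder):
    #   result = await router.generate(
    #       "Describe this image.",
    #       task_type=TaskType.VISION,
    #       image_url="https://example.com/photo.jpg",
    #   )
    #   result["multimodal"]  -> True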

    async def _generate_gpt(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: int,
        temperature: float,
        stream: bool
    ) -> Dict[str, Any]:
        """Generate using GPT-4o-mini"""
        if not self.openai_client:
            raise ValueError("OpenAI API key not configured")

        messages = [
            {"role": "system", "content": system_prompt or "You are a helpful AI assistant."},
            {"role": "user", "content": prompt}
        ]

        response = await self.openai_client.chat.completions.create(
            model=ModelType.GPT4O_MINI.value,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=stream
        )

        if stream:
            return {"response": response, "model": "gpt-4o-mini", "streaming": True}

        content = response.choices[0].message.content
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        cost = self._calculate_cost(ModelType.GPT4O_MINI, input_tokens, output_tokens)

        # Update stats
        self.usage_stats["gpt4"]["requests"] += 1
        self.usage_stats["gpt4"]["tokens"] += input_tokens + output_tokens
        self.usage_stats["gpt4"]["cost"] += cost

        return {
            "response": content,
            "model": "gpt-4o-mini",
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "cost": cost,
            "streaming": False
        }

    def _calculate_cost(self, model: ModelType, input_tokens: int, output_tokens: int) -> float:
        """Calculate cost for API usage"""
        costs = self.model_costs[model]
        input_cost = (input_tokens / 1_000_000) * costs["input"]
        output_cost = (output_tokens / 1_000_000) * costs["output"]
        return input_cost + output_cost
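
    # Worked example: a Claude call with 1,000 input and 500 output tokens costs
    #   (1_000 / 1_000_000) * 3.0 + (500 / 1_000_000) * 15.0
    #   = 0.003 + 0.0075 = $0.0105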

    def get_usage_stats(self) -> Dict[str, Any]:
        """Get usage statistics across all models"""
        total_cost = sum(stats["cost"] for stats in self.usage_stats.values())
        total_requests = sum(stats["requests"] for stats in self.usage_stats.values())

        return {
            "total_requests": total_requests,
            "total_cost": round(total_cost, 4),
            "by_model": self.usage_stats,
            "cost_breakdown": {
                "claude": round(self.usage_stats["claude"]["cost"], 4),
                "gemini": round(self.usage_stats["gemini"]["cost"], 4),
                "gpt4": round(self.usage_stats["gpt4"]["cost"], 4),
            }
        }
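
    # Example return shape (values are illustrative):
    #   {
    #       "total_requests": 3,
    #       "total_cost": 0.0113,
    #       "by_model": {...},
    #       "cost_breakdown": {"claude": 0.0105, "gemini": 0.0, "gpt4": 0.0008},
    #   }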

    def get_langchain_model(self, task_type: TaskType):
        """Get a LangChain-compatible model for agent integration"""
        model = self.select_model(task_type)

        if model == ModelType.CLAUDE_SONNET:
            return self.claude_lc
        elif model == ModelType.GEMINI_2_FLASH:
            return self.gemini_lc
        elif model == ModelType.GPT4O_MINI:
            return self.gpt_lc

        # Default fallback (may be None if ANTHROPIC_API_KEY is not set)
        return self.claude_lc


# Global router instance
router = MultiModelRouter()
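

# Example usage (a minimal sketch; assumes the relevant API keys are present in
# the environment, e.g. loaded from a .env file):
if __name__ == "__main__":
    async def _demo() -> None:
        result = await router.generate(
            "Summarize the CAP theorem in two sentences.",
            task_type=TaskType.FAST_QUERY,
        )
        print(f"[{result['model']}] {result['response']}")
        print(router.get_usage_stats())

    asyncio.run(_demo())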