| """ | |
| AI Client Manager - Unified interface for multiple LLM providers | |
| Supports Anthropic Claude, Google Gemini, and OpenAI | |
| """ | |
import os
from enum import Enum
from typing import Any, Optional

import anthropic
import google.generativeai as genai
from openai import OpenAI

from src.config import config


class Provider(str, Enum):
    """Supported AI providers"""

    ANTHROPIC = "anthropic"
    GOOGLE = "google"
    OPENAI = "openai"


class AIClient:
    """Unified AI client supporting multiple providers"""

    def __init__(self):
        """Initialize all available clients"""
        self.anthropic_client: Optional[anthropic.Anthropic] = None
        self.google_client: Optional[Any] = None
        self.openai_client: Optional[OpenAI] = None
        self._initialize_clients()

    def _initialize_clients(self):
        """Initialize available AI clients based on API keys"""
        # Anthropic Claude - read directly from environment
        anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        if anthropic_key:
            try:
                self.anthropic_client = anthropic.Anthropic(api_key=anthropic_key)
                print(f"✅ Anthropic client initialized (key: {anthropic_key[:10]}...)")
            except Exception as e:
                print(f"❌ Failed to initialize Anthropic client: {e}")
        else:
            print("⚠️ ANTHROPIC_API_KEY not found in environment")

        # Google Gemini - read directly from environment
        google_key = os.getenv("GOOGLE_API_KEY")
        if google_key:
            try:
                genai.configure(api_key=google_key)
                self.google_client = genai
                print(f"✅ Google client initialized (key: {google_key[:10]}...)")
            except Exception as e:
                print(f"❌ Failed to initialize Google client: {e}")
        else:
            print("⚠️ GOOGLE_API_KEY not found in environment")

        # OpenAI - read directly from environment
        openai_key = os.getenv("OPENAI_API_KEY")
        if openai_key:
            try:
                self.openai_client = OpenAI(api_key=openai_key)
                print(f"✅ OpenAI client initialized (key: {openai_key[:10]}...)")
            except Exception as e:
                print(f"❌ Failed to initialize OpenAI client: {e}")
        else:
            print("⚠️ OPENAI_API_KEY not found in environment")

    def generate(
        self,
        prompt: str,
        provider: Optional[Provider] = None,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 2000,
        system_prompt: Optional[str] = None,
    ) -> str:
        """
        Generate text using the specified provider

        Args:
            prompt: User prompt
            provider: AI provider to use (defaults to config)
            model: Model name (defaults to config)
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            system_prompt: System prompt for context

        Returns:
            Generated text
        """
        # Use defaults from config if not specified
        if provider is None:
            provider = Provider(config.model.primary_provider)
        if model is None:
            model = config.model.primary_model

        # Route to the appropriate provider
        if provider == Provider.ANTHROPIC:
            return self._generate_anthropic(prompt, model, temperature, max_tokens, system_prompt)
        elif provider == Provider.GOOGLE:
            return self._generate_google(prompt, model, temperature, max_tokens, system_prompt)
        elif provider == Provider.OPENAI:
            return self._generate_openai(prompt, model, temperature, max_tokens, system_prompt)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def _generate_anthropic(
        self,
        prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str],
    ) -> str:
        """Generate using Anthropic Claude"""
        if not self.anthropic_client:
            # Try to re-initialize from environment
            anthropic_key = os.getenv("ANTHROPIC_API_KEY")
            if anthropic_key:
                print(f"Re-initializing Anthropic client with key: {anthropic_key[:10]}...")
                try:
                    self.anthropic_client = anthropic.Anthropic(api_key=anthropic_key)
                    print("✅ Anthropic client re-initialized successfully")
                except Exception as e:
                    raise RuntimeError(f"Failed to re-initialize Anthropic client: {e}")
            else:
                raise RuntimeError(
                    "Anthropic client not initialized. Set ANTHROPIC_API_KEY in HF Spaces Secrets."
                )

        messages = [{"role": "user", "content": prompt}]
        kwargs = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if system_prompt:
            kwargs["system"] = system_prompt

        try:
            response = self.anthropic_client.messages.create(**kwargs)
            return response.content[0].text
        except Exception as e:
            raise RuntimeError(f"Anthropic API error: {e}")

    def _generate_google(
        self,
        prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str],
    ) -> str:
        """Generate using Google Gemini"""
        if not self.google_client:
            raise RuntimeError("Google client not initialized")

        try:
            gemini_model = self.google_client.GenerativeModel(model)
            generation_config = {
                "temperature": temperature,
                "max_output_tokens": max_tokens,
            }

            # Combine system prompt and user prompt
            full_prompt = prompt
            if system_prompt:
                full_prompt = f"{system_prompt}\n\n{prompt}"

            response = gemini_model.generate_content(
                full_prompt,
                generation_config=generation_config,
            )
            return response.text
        except Exception as e:
            raise RuntimeError(f"Google API error: {e}")

    def _generate_openai(
        self,
        prompt: str,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str],
    ) -> str:
        """Generate using OpenAI"""
        if not self.openai_client:
            raise RuntimeError("OpenAI client not initialized")

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        try:
            response = self.openai_client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content
        except Exception as e:
            raise RuntimeError(f"OpenAI API error: {e}")

    def generate_with_memory(
        self,
        prompt: str,
        context: str,
        provider: Optional[Provider] = None,
        model: Optional[str] = None,
    ) -> str:
        """
        Generate with long context using the memory model (Gemini 2.0)

        Args:
            prompt: User prompt
            context: Long context/memory to include
            provider: Provider to use (defaults to memory provider)
            model: Model to use (defaults to memory model)

        Returns:
            Generated text
        """
        # Use the memory provider by default
        if provider is None:
            provider = Provider(config.model.memory_provider)
        if model is None:
            model = config.model.memory_model

        # Combine context and prompt
        full_prompt = f"""# Campaign Context
{context}

# Current Query
{prompt}

Please answer based on the campaign context provided."""

        return self.generate(
            prompt=full_prompt,
            provider=provider,
            model=model,
            temperature=config.model.balanced_temp,
            max_tokens=config.model.max_tokens_memory,
        )

    def generate_creative(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """Generate creative content (characters, stories, etc.)"""
        return self.generate(
            prompt=prompt,
            temperature=config.model.creative_temp,
            max_tokens=config.model.max_tokens_generation,
            system_prompt=system_prompt,
        )

    def generate_precise(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        """Generate precise content (rules, stats, etc.)"""
        return self.generate(
            prompt=prompt,
            temperature=config.model.precise_temp,
            max_tokens=config.model.max_tokens_generation,
            system_prompt=system_prompt,
        )

    def is_available(self, provider: Provider) -> bool:
        """Check if a provider is available"""
        if provider == Provider.ANTHROPIC:
            return self.anthropic_client is not None
        elif provider == Provider.GOOGLE:
            return self.google_client is not None
        elif provider == Provider.OPENAI:
            return self.openai_client is not None
        return False


# Global client instance
_client: Optional[AIClient] = None


def get_ai_client() -> AIClient:
    """Get or create the global AI client instance"""
    global _client
    if _client is None:
        _client = AIClient()
    return _client


# Convenience functions
def generate_text(
    prompt: str,
    temperature: float = 0.7,
    max_tokens: int = 2000,
    system_prompt: Optional[str] = None,
) -> str:
    """Quick text generation"""
    client = get_ai_client()
    return client.generate(
        prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        system_prompt=system_prompt,
    )


def generate_creative_text(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Quick creative text generation"""
    client = get_ai_client()
    return client.generate_creative(prompt, system_prompt=system_prompt)


def generate_precise_text(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Quick precise text generation"""
    client = get_ai_client()
    return client.generate_precise(prompt, system_prompt=system_prompt)
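

# Minimal smoke test when run as a script (a sketch; assumes the project root
# is on PYTHONPATH so `from src.config import config` above resolves, and that
# any provider keys you want to exercise are set in the environment). It only
# reports which providers initialized; it makes no API calls.
if __name__ == "__main__":
    client = get_ai_client()
    for p in Provider:
        status = "available" if client.is_available(p) else "unavailable"
        print(f"{p.value}: {status}")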