Spaces:
Build error
Build error
import json
import logging
import os
from typing import Any, Dict, Iterator, List, Optional, Union

import requests
# Module-level logger; GroqClient uses it to report failed API calls and
# stream chunks that cannot be decoded as JSON.
logger = logging.getLogger(__name__)
class GroqClient:
    """Direct implementation of Groq API client.

    Talks to Groq's OpenAI-compatible REST endpoints with ``requests`` rather
    than an SDK. Every request carries a bearer-token ``Authorization`` header
    built from the API key resolved in ``__init__``.
    """

    def __init__(self,
                 model: str = "llama3-70b-8192",
                 temperature: float = 0.7,
                 max_tokens: int = 2048,
                 groq_api_key: Optional[str] = None,
                 timeout: Optional[float] = 60.0):
        """Initialize the Groq client with model parameters.

        Args:
            model: Model identifier sent as ``model`` in completion requests.
            temperature: Sampling temperature sent with every completion.
            max_tokens: Value sent as ``max_tokens`` with every completion.
            groq_api_key: Explicit API key. When omitted or empty, the key is
                read from ``GROQ_API_KEY_FALLBACK`` and then ``GROQ_API_KEY``.
            timeout: Seconds passed as ``timeout`` to every HTTP call so a
                stalled connection cannot block forever; ``None`` disables it.

        Raises:
            ValueError: If no non-empty API key can be found.
        """
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        # Chain with `or` (not a nested `.get()` default) so a variable that is
        # set but empty falls through to the next source instead of yielding "".
        # NOTE(review): GROQ_API_KEY_FALLBACK wins over GROQ_API_KEY, which looks
        # inverted given its name; original precedence kept deliberately — confirm.
        self.api_key = (
            groq_api_key
            or os.environ.get("GROQ_API_KEY_FALLBACK")
            or os.environ.get("GROQ_API_KEY")
        )
        if not self.api_key:
            raise ValueError("Groq API key not found. Please provide it or set GROQ_API_KEY environment variable.")

        self.base_url = "https://api.groq.com/openai/v1"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def list_models(self) -> Dict[str, Any]:
        """List available models from Groq.

        Returns:
            The parsed JSON body of ``GET {base_url}/models``.

        Raises:
            requests.exceptions.RequestException: On network failure, timeout,
                or a non-2xx status (``HTTPError`` via ``raise_for_status``).
        """
        url = f"{self.base_url}/models"
        response = requests.get(url, headers=self.headers, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def generate(self,
                 messages: List[Dict[str, str]],
                 stream: bool = False) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]:
        """Request a chat completion for ``messages``.

        Args:
            messages: Chat messages sent verbatim as the ``messages`` field of
                the request body.
            stream: When true, request a streamed completion and return a
                generator of chunks instead of one parsed body.

        Returns:
            The parsed JSON response body when ``stream`` is false; otherwise a
            generator yielding each decoded ``data:`` chunk (see
            ``_handle_streaming_response``).

        Raises:
            requests.exceptions.RequestException: On network failure, timeout,
                or a non-2xx status. The error, plus the API's JSON error body
                when one can be parsed, is logged before re-raising.
        """
        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": stream,
        }
        try:
            response = requests.post(url, headers=self.headers, json=payload,
                                     stream=stream, timeout=self.timeout)
            response.raise_for_status()
            if stream:
                return self._handle_streaming_response(response)
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error("Error calling Groq API: %s", e)
            error_response = getattr(e, "response", None)
            if error_response is not None:
                try:
                    logger.error("API error details: %s", error_response.json())
                except ValueError:
                    # Error body is not JSON; fall back to the bare status code.
                    logger.error("API error status code: %s", error_response.status_code)
            raise

    def _handle_streaming_response(self, response: "requests.Response") -> Iterator[Dict[str, Any]]:
        """Yield decoded JSON chunks from a server-sent-events response.

        Only lines starting with ``data: `` are considered; blank and other
        lines are skipped. Iteration stops at the ``[DONE]`` sentinel. Payloads
        that are not valid JSON are logged and skipped. The response is closed
        when the generator finishes, fails, or is closed early by the caller.
        """
        prefix = 'data: '
        try:
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode('utf-8')
                if not line.startswith(prefix):
                    continue
                data = line[len(prefix):]
                if data.strip() == '[DONE]':
                    break
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    logger.error("Failed to decode JSON: %s", data)
                    continue
                # Yield outside the try so errors thrown into the generator by
                # the consumer are not mistaken for decode failures.
                yield chunk
        finally:
            # A stream=True response holds its connection until closed; release
            # it even if the consumer stops iterating early.
            response.close()

    def __call__(self, prompt: str, **kwargs) -> str:
        """Make the client callable with a prompt for compatibility.

        Sends ``prompt`` as a single user message (non-streaming) and returns
        the text of the first choice. ``kwargs`` are accepted and ignored.
        """
        messages = [{"role": "user", "content": prompt}]
        response = self.generate(messages)
        return response['choices'][0]['message']['content']