"""REST client for the Google Gemini API: text and multimodal (vision) calls,
with optional server-sent-event streaming."""

import base64
import json
import os
from io import BytesIO
from typing import Dict, Iterator, List, Optional, Union

import requests
from PIL import Image

# Gemini API Configuration
API_KEY_ENV_VAR = "GOOGLE_API_KEY"
BASE_URL = os.getenv("GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/models/")
DEFAULT_TEXT_MODEL = "gemini-2.0-flash"  # Default text model
VISION_MODEL = "gemini-2.5-flash-lite"  # Model with vision capability


def _get_api_key(provided_key: Optional[str] = None) -> Optional[str]:
    """
    Returns the provided key if valid, otherwise falls back to the environment variable.
    """
    if provided_key and provided_key.strip():
        return provided_key.strip()
    return os.getenv(API_KEY_ENV_VAR)


def _encode_image_base64_png(image: Image.Image) -> str:
    """Serialize *image* as PNG and return it base64-encoded for inline_data."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _build_vision_contents(messages: List[Dict], image: Image.Image) -> List[Dict]:
    """Build the single-part multimodal 'contents' structure: image + flattened chat text."""
    return [
        {
            "parts": [
                {"inline_data": {"mime_type": "image/png", "data": _encode_image_base64_png(image)}},
                # Add the system and user prompt text in the same image part
                {"text": "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])},
            ]
        }
    ]


def _build_text_contents(conversation_history: List[Dict]) -> List[Dict]:
    """
    Build the role-alternating 'contents' list for the text API.

    Consecutive messages with the SAME role are merged into one block
    (the Gemini API expects strictly alternating user/model turns).
    The original code only merged user+user runs; consecutive assistant
    messages are now merged too.
    """
    consolidated: List[Dict] = []
    current_block: Optional[Dict] = None
    for msg in conversation_history:
        role = "model" if msg.get("role") == "assistant" else "user"
        content = msg.get("content", "")
        if current_block and current_block["role"] == role:
            current_block["parts"][0]["text"] += "\n" + content
        else:
            if current_block:
                consolidated.append(current_block)
            current_block = {"role": role, "parts": [{"text": content}]}
    if current_block:
        consolidated.append(current_block)
    return consolidated


def _format_payload_for_gemini(messages: List[Dict], image: Optional[Image.Image] = None) -> Optional[Dict]:
    """
    Formats the payload for the Gemini API, handling both text-only and multimodal requests.

    Returns None when there is nothing to send (no non-system messages and no image).
    """
    system_instruction = None
    conversation_history = []
    for msg in messages:
        if msg.get("role") == "system":
            system_instruction = {"parts": [{"text": msg.get("content", "")}]}
        else:
            conversation_history.append(msg)

    if not conversation_history and not image:
        return None

    if image:
        # For the Vision API, the structure is a simple list of "parts"
        contents = _build_vision_contents(messages, image)
    else:
        # For the text API, maintain the alternating "role" structure
        contents = _build_text_contents(conversation_history)

    payload: Dict = {"contents": contents}
    if system_instruction and not image:
        # System instruction is not supported the same way in the Vision API
        payload["system_instruction"] = system_instruction
    return payload


def _extract_error_message(response: requests.Response) -> str:
    """Pull the human-readable error message out of an error response body."""
    try:
        error_details = response.json()
        # ValueError also covers requests' own JSONDecodeError subclass.
        return error_details.get("error", {}).get("message", response.text)
    except ValueError:
        return response.text


def _iter_sse_text(response: requests.Response) -> Iterator[str]:
    """
    Yield the text of each streamed candidate chunk from an SSE response.

    Lines that are not valid `data: <json>` payloads, and chunks whose
    candidates carry no text parts (e.g. final usage/finish chunks), are
    skipped instead of raising.
    """
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        decoded_line = raw_line.decode("utf-8")
        if not decoded_line.startswith("data: "):
            continue
        try:
            chunk = json.loads(decoded_line[6:])
        except json.JSONDecodeError:
            # Ignore lines that are not valid JSON
            continue
        candidates = chunk.get("candidates") or []
        if not candidates:
            continue
        parts = candidates[0].get("content", {}).get("parts") or []
        if parts and "text" in parts[0]:
            yield parts[0]["text"]


def call_api(
    messages: List[Dict],
    model_name: Optional[str] = None,
    image: Optional[Image.Image] = None,
    stream: bool = False,
    temperature: float = 0.7,
    max_tokens: int = 8192,
    api_key: Optional[str] = None,
) -> Union[Iterator[str], str]:
    """
    Calls the Google Gemini REST API, supporting text and multimodal inputs, with streaming.

    Returns an iterator of text chunks when stream=True, otherwise the full
    response text. Raises ValueError on auth errors (HTTP 400/401/403),
    RuntimeError on other API errors, and ConnectionError on network failure.
    """
    final_api_key = _get_api_key(api_key)
    if not final_api_key:
        error_msg = "Error: Authentication required. No API key provided and no server fallback found."
        return iter([error_msg]) if stream else error_msg

    # Choose the model based on the presence of an image
    if image:
        model_id = model_name or VISION_MODEL
    else:
        model_id = model_name or DEFAULT_TEXT_MODEL

    payload = _format_payload_for_gemini(messages, image)
    if not payload or not payload.get("contents"):
        error_msg = "Error: Conversation is empty or malformed."
        return iter([error_msg]) if stream else error_msg

    # Add generation configuration to the payload
    payload["safetySettings"] = [
        {"category": f"HARM_CATEGORY_{cat}", "threshold": "BLOCK_NONE"}
        for cat in [
            "SEXUALLY_EXPLICIT",
            "HATE_SPEECH",
            "HARASSMENT",
            "DANGEROUS_CONTENT",
        ]
    ]
    payload["generationConfig"] = {
        "temperature": temperature,
        "maxOutputTokens": max_tokens,
    }

    action = "streamGenerateContent" if stream else "generateContent"
    request_url = f"{BASE_URL}{model_id}:{action}?key={final_api_key}"
    if stream:
        # FIX: without alt=sse the streaming endpoint returns one JSON array,
        # not "data: ..." SSE lines, and the line parser would yield nothing.
        request_url += "&alt=sse"
    headers = {"Content-Type": "application/json"}

    try:
        response = requests.post(request_url, headers=headers, json=payload, stream=stream, timeout=180)

        if response.status_code != 200:
            error_msg = _extract_error_message(response)
            if response.status_code in [400, 401, 403]:
                raise ValueError(f"Gemini Auth Error: {error_msg}")
            raise RuntimeError(f"Gemini API Error ({response.status_code}): {error_msg}")

        if stream:
            return _iter_sse_text(response)

        # Logic for non-streaming
        data = response.json()
        if data.get("candidates") and data["candidates"][0].get("content", {}).get("parts"):
            return data["candidates"][0]["content"]["parts"][0]["text"]
        print(f"Gemini's response format unexpected. Full response: {data}")
        return f"[BLOCKED OR EMPTY RESPONSE]\n{data}"

    except requests.exceptions.RequestException as e:
        raise ConnectionError(f"Connection to Gemini failed: {str(e)}")