# Gemini REST API client helpers.
# (Provenance: commit 5d46e12, "formatting fixes", author elismasilva.)
import base64
import json
import os
from io import BytesIO
from typing import Dict, Iterator, List, Optional, Union
import requests
from PIL import Image
# Gemini API Configuration
API_KEY_ENV_VAR = "GOOGLE_API_KEY"  # env var consulted when no explicit key is passed
# Endpoint root; overridable (e.g. to point at a proxy) via GEMINI_BASE_URL.
BASE_URL = os.getenv("GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/models/")
DEFAULT_TEXT_MODEL = "gemini-2.0-flash" # Default text model
VISION_MODEL = "gemini-2.5-flash-lite" # Model with vision capability
def _get_api_key(provided_key: Optional[str] = None) -> Optional[str]:
"""
Returns the provided key if valid, otherwise falls back to the environment variable.
"""
if provided_key and provided_key.strip():
return provided_key.strip()
return os.getenv(API_KEY_ENV_VAR)
def _format_payload_for_gemini(messages: List[Dict], image: Optional[Image.Image] = None) -> Optional[Dict]:
"""
Formats the payload for the Gemini API, handling both text-only and multimodal requests.
"""
system_instruction = None
conversation_history = []
for msg in messages:
if msg.get("role") == "system":
system_instruction = {"parts": [{"text": msg.get("content", "")}]}
else:
conversation_history.append(msg)
if not conversation_history and not image:
return None
# For the Vision API, the structure is a simple list of "parts"
if image:
contents = []
# Prepare the image part
buffered = BytesIO()
image.save(buffered, format="PNG")
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
contents.append(
{
"parts": [
{"inline_data": {"mime_type": "image/png", "data": img_base64}},
# Add the system and user prompt text in the same image part
{"text": "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])},
]
}
)
# For the text API, maintain the alternating "role" structure
else:
consolidated_contents = []
current_block = None
for msg in conversation_history:
role = "model" if msg.get("role") == "assistant" else "user"
content = msg.get("content", "")
if current_block and current_block["role"] == "user" and role == "user":
current_block["parts"][0]["text"] += "\n" + content
else:
if current_block:
consolidated_contents.append(current_block)
current_block = {"role": role, "parts": [{"text": content}]}
if current_block:
consolidated_contents.append(current_block)
contents = consolidated_contents
payload = {"contents": contents}
if system_instruction and not image: # System instruction is not supported the same way in the Vision API
payload["system_instruction"] = system_instruction
return payload
def call_api(
    messages: List[Dict],
    model_name: Optional[str] = None,
    image: Optional[Image.Image] = None,
    stream: bool = False,
    temperature: float = 0.7,
    max_tokens: int = 8192,
    api_key: Optional[str] = None,
) -> Union[Iterator[str], str]:
    """Call the Google Gemini REST API, supporting text and multimodal inputs.

    Args:
        messages: chat messages ({"role": ..., "content": ...}).
        model_name: explicit model id; when None, picks a vision model if an
            image is attached, else the default text model.
        image: optional PIL image for a multimodal request.
        stream: if True, return an iterator of text chunks (SSE); else a str.
        temperature: sampling temperature passed to generationConfig.
        max_tokens: maxOutputTokens passed to generationConfig.
        api_key: explicit key; falls back to the GOOGLE_API_KEY env var.

    Returns:
        Generated text, or an iterator of chunks when streaming. Missing-key
        and empty-conversation cases are returned as error strings (keeps
        caller UI flows alive rather than raising).

    Raises:
        ValueError: on HTTP 400/401/403 (bad key or malformed request).
        RuntimeError: on any other non-200 response.
        ConnectionError: when the HTTP request itself fails.
    """
    final_api_key = _get_api_key(api_key)
    if not final_api_key:
        error_msg = "Error: Authentication required. No API key provided and no server fallback found."
        return iter([error_msg]) if stream else error_msg

    # Choose the model based on the presence of an image.
    model_id = model_name or (VISION_MODEL if image else DEFAULT_TEXT_MODEL)

    payload = _format_payload_for_gemini(messages, image)
    if not payload or not payload.get("contents"):
        error_msg = "Error: Conversation is empty or malformed."
        return iter([error_msg]) if stream else error_msg

    payload["safetySettings"] = [
        {"category": f"HARM_CATEGORY_{cat}", "threshold": "BLOCK_NONE"}
        for cat in [
            "SEXUALLY_EXPLICIT",
            "HATE_SPEECH",
            "HARASSMENT",
            "DANGEROUS_CONTENT",
        ]
    ]
    payload["generationConfig"] = {
        "temperature": temperature,
        "maxOutputTokens": max_tokens,
    }

    method = "streamGenerateContent" if stream else "generateContent"
    # BUGFIX: without alt=sse, streamGenerateContent returns a single JSON
    # array rather than "data: {...}" SSE lines, so the parser below never
    # yielded anything. alt=sse requests the SSE framing it expects.
    query = "?alt=sse" if stream else ""
    request_url = f"{BASE_URL}{model_id}:{method}{query}"
    # Send the key via header instead of the query string so it does not
    # leak into proxy/server access logs.
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": final_api_key,
    }

    try:
        response = requests.post(request_url, headers=headers, json=payload, stream=stream, timeout=180)
    except requests.exceptions.RequestException as e:
        raise ConnectionError(f"Connection to Gemini failed: {str(e)}")

    if response.status_code != 200:
        try:
            error_details = response.json()
            error_msg = error_details.get("error", {}).get("message", response.text)
        except json.JSONDecodeError:
            error_msg = response.text
        if response.status_code in [400, 401, 403]:
            raise ValueError(f"Gemini Auth Error: {error_msg}")
        raise RuntimeError(f"Gemini API Error ({response.status_code}): {error_msg}")

    if stream:

        def stream_generator():
            # SSE framing: each event arrives as a line "data: {...json...}".
            for line in response.iter_lines():
                if not line:
                    continue
                decoded_line = line.decode("utf-8")
                if not decoded_line.startswith("data: "):
                    continue
                try:
                    chunk = json.loads(decoded_line[6:])
                except json.JSONDecodeError:
                    # Ignore keep-alives / non-JSON lines.
                    continue
                # Yield every text part of the first candidate; guard each
                # level so a partial/final metadata chunk cannot KeyError.
                for candidate in chunk.get("candidates", [])[:1]:
                    for part in candidate.get("content", {}).get("parts", []):
                        if "text" in part:
                            yield part["text"]

        return stream_generator()

    # Non-streaming: extract the text of the first candidate.
    data = response.json()
    candidates = data.get("candidates") or []
    if candidates and candidates[0].get("content", {}).get("parts"):
        # Join all text parts — responses may be split across several parts.
        return "".join(part.get("text", "") for part in candidates[0]["content"]["parts"])
    print(f"Gemini's response format unexpected. Full response: {data}")
    return f"[BLOCKED OR EMPTY RESPONSE]\n{data}"