|
|
import base64 |
|
|
import json |
|
|
import os |
|
|
from io import BytesIO |
|
|
from typing import Dict, Iterator, List, Optional, Union |
|
|
|
|
|
import requests |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
# Environment variable consulted when no API key is passed explicitly.
API_KEY_ENV_VAR = "GOOGLE_API_KEY"


# REST endpoint prefix; overridable via GEMINI_BASE_URL (e.g. for proxies).
# Model name and action are appended as "<model>:<generateContent|streamGenerateContent>".
BASE_URL = os.getenv("GEMINI_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/models/")


# Fallback model for text-only requests when the caller names none.
DEFAULT_TEXT_MODEL = "gemini-2.0-flash"


# Fallback model for requests that include an image.
VISION_MODEL = "gemini-2.5-flash-lite"
|
|
|
|
|
|
|
|
def _get_api_key(provided_key: Optional[str] = None) -> Optional[str]: |
|
|
""" |
|
|
Returns the provided key if valid, otherwise falls back to the environment variable. |
|
|
""" |
|
|
if provided_key and provided_key.strip(): |
|
|
return provided_key.strip() |
|
|
return os.getenv(API_KEY_ENV_VAR) |
|
|
|
|
|
|
|
|
def _format_payload_for_gemini(messages: List[Dict], image: Optional[Image.Image] = None) -> Optional[Dict]: |
|
|
""" |
|
|
Formats the payload for the Gemini API, handling both text-only and multimodal requests. |
|
|
""" |
|
|
system_instruction = None |
|
|
conversation_history = [] |
|
|
|
|
|
for msg in messages: |
|
|
if msg.get("role") == "system": |
|
|
system_instruction = {"parts": [{"text": msg.get("content", "")}]} |
|
|
else: |
|
|
conversation_history.append(msg) |
|
|
|
|
|
if not conversation_history and not image: |
|
|
return None |
|
|
|
|
|
|
|
|
if image: |
|
|
contents = [] |
|
|
|
|
|
buffered = BytesIO() |
|
|
image.save(buffered, format="PNG") |
|
|
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") |
|
|
contents.append( |
|
|
{ |
|
|
"parts": [ |
|
|
{"inline_data": {"mime_type": "image/png", "data": img_base64}}, |
|
|
|
|
|
{"text": "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])}, |
|
|
] |
|
|
} |
|
|
) |
|
|
|
|
|
else: |
|
|
consolidated_contents = [] |
|
|
current_block = None |
|
|
for msg in conversation_history: |
|
|
role = "model" if msg.get("role") == "assistant" else "user" |
|
|
content = msg.get("content", "") |
|
|
|
|
|
if current_block and current_block["role"] == "user" and role == "user": |
|
|
current_block["parts"][0]["text"] += "\n" + content |
|
|
else: |
|
|
if current_block: |
|
|
consolidated_contents.append(current_block) |
|
|
current_block = {"role": role, "parts": [{"text": content}]} |
|
|
if current_block: |
|
|
consolidated_contents.append(current_block) |
|
|
contents = consolidated_contents |
|
|
|
|
|
payload = {"contents": contents} |
|
|
if system_instruction and not image: |
|
|
payload["system_instruction"] = system_instruction |
|
|
|
|
|
return payload |
|
|
|
|
|
|
|
|
def call_api(
    messages: List[Dict],
    model_name: Optional[str] = None,
    image: Optional["Image.Image"] = None,
    stream: bool = False,
    temperature: float = 0.7,
    max_tokens: int = 8192,
    api_key: Optional[str] = None,
) -> Union[Iterator[str], str]:
    """Call the Google Gemini REST API, supporting text and multimodal inputs.

    Args:
        messages: OpenAI-style chat messages ({"role", "content"} dicts).
        model_name: Explicit model id; defaults to VISION_MODEL when an image
            is supplied, else DEFAULT_TEXT_MODEL.
        image: Optional PIL image for multimodal requests.
        stream: If True, return an iterator of text chunks (SSE); otherwise
            return the full response text as a string.
        temperature: Sampling temperature forwarded to generationConfig.
        max_tokens: maxOutputTokens forwarded to generationConfig.
        api_key: Explicit key; falls back to the GOOGLE_API_KEY env var.

    Returns:
        Response text, or an iterator of chunks when streaming. A missing key
        or empty conversation is reported as an error string (or one-item
        iterator) rather than raised, preserving the existing caller contract.

    Raises:
        ValueError: on 400/401/403 responses (auth-style errors).
        RuntimeError: on other non-200 responses.
        ConnectionError: when the HTTP request itself fails.
    """
    final_api_key = _get_api_key(api_key)
    if not final_api_key:
        error_msg = "Error: Authentication required. No API key provided and no server fallback found."
        return iter([error_msg]) if stream else error_msg

    model_id = model_name or (VISION_MODEL if image else DEFAULT_TEXT_MODEL)

    payload = _format_payload_for_gemini(messages, image)
    if not payload or not payload.get("contents"):
        error_msg = "Error: Conversation is empty or malformed."
        return iter([error_msg]) if stream else error_msg

    payload["safetySettings"] = [
        {"category": f"HARM_CATEGORY_{cat}", "threshold": "BLOCK_NONE"}
        for cat in [
            "SEXUALLY_EXPLICIT",
            "HATE_SPEECH",
            "HARASSMENT",
            "DANGEROUS_CONTENT",
        ]
    ]
    payload["generationConfig"] = {
        "temperature": temperature,
        "maxOutputTokens": max_tokens,
    }

    method = "streamGenerateContent" if stream else "generateContent"
    # FIX: the streaming parser below expects SSE ("data: " prefixed lines),
    # which the API only emits with alt=sse. Without it, streamGenerateContent
    # returns a plain JSON array and the generator yielded nothing.
    query = "?alt=sse" if stream else ""
    request_url = f"{BASE_URL}{model_id}:{method}{query}"
    # FIX: send the key via the documented x-goog-api-key header instead of a
    # URL query parameter, so it doesn't leak into proxy/server access logs.
    headers = {"Content-Type": "application/json", "x-goog-api-key": final_api_key}

    try:
        response = requests.post(request_url, headers=headers, json=payload, stream=stream, timeout=180)

        if response.status_code != 200:
            try:
                error_details = response.json()
                error_msg = error_details.get("error", {}).get("message", response.text)
            except json.JSONDecodeError:
                error_msg = response.text

            if response.status_code in (400, 401, 403):
                raise ValueError(f"Gemini Auth Error: {error_msg}")
            raise RuntimeError(f"Gemini API Error ({response.status_code}): {error_msg}")

        if stream:

            def stream_generator():
                # Each SSE event is "data: <json chunk>"; non-JSON keep-alive
                # lines are skipped rather than aborting the stream.
                for line in response.iter_lines():
                    if not line:
                        continue
                    decoded_line = line.decode("utf-8")
                    if decoded_line.startswith("data: "):
                        try:
                            chunk = json.loads(decoded_line[6:])
                            if chunk.get("candidates"):
                                yield chunk["candidates"][0]["content"]["parts"][0]["text"]
                        except json.JSONDecodeError:
                            continue

            return stream_generator()

        data = response.json()
        if data.get("candidates") and data["candidates"][0].get("content", {}).get("parts"):
            return data["candidates"][0]["content"]["parts"][0]["text"]
        # Blocked prompt or empty candidate list: surface the raw response.
        print(f"Gemini's response format unexpected. Full response: {data}")
        return f"[BLOCKED OR EMPTY RESPONSE]\n{data}"

    except requests.exceptions.RequestException as e:
        raise ConnectionError(f"Connection to Gemini failed: {str(e)}")
|
|
|