"""Gemini service for audio analysis via the Gemini API.
This module provides a service for interacting with Gemini's API,
supporting audio inputs, file uploads, streaming, and structured outputs.
"""
import os
import base64
import json
import mimetypes
from typing import Optional, List, Dict, Any, Tuple, Union, Generator
from pathlib import Path
import requests
prompt = """Analyze the input audio to generate detailed caption and lyrics.
lyrics need contain structured tags for chorus, verse, bridge, etc.
**Output Format:**
```json
{
"caption": <str>,
"lyrics": "[Intro] <str>, [Verse] <str>..."
}
```
"""
class GeminiService:
"""Service for handling Gemini API audio interactions."""
SUPPORTED_AUDIO_TYPES = [
"audio/wav",
"audio/mp3",
"audio/aiff",
"audio/aac",
"audio/ogg",
"audio/flac",
]
def __init__(
self,
api_key: str,
base_url: str = "https://generativelanguage.googleapis.com",
model_name: str = "gemini-3-flash",
):
self.api_key = api_key
self.base_url = base_url.rstrip("/")
self.model_name = model_name
def _get_endpoint(
self, endpoint_type: str = "generate", model_name: Optional[str] = None
) -> str:
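        """Return the full request URL for the given endpoint type."""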
model = model_name or self.model_name
if endpoint_type == "generate":
return f"{self.base_url}/v1beta/models/{model}:generateContent"
elif endpoint_type == "stream":
return (
f"{self.base_url}/v1beta/models/{model}:streamGenerateContent?alt=sse"
)
elif endpoint_type == "upload":
return f"{self.base_url}/upload/v1beta/files"
elif endpoint_type == "files":
return f"{self.base_url}/v1beta/files"
else:
raise ValueError(f"Unknown endpoint type: {endpoint_type}")
def _get_headers(self) -> Dict[str, str]:
return {"x-goog-api-key": self.api_key, "Content-Type": "application/json"}
def _encode_file_to_base64(self, file_path: str) -> str:
with open(file_path, "rb") as f:
return base64.b64encode(f.read()).decode("utf-8")
def _get_mime_type(self, file_path: str) -> str:
mime_type, _ = mimetypes.guess_type(file_path)
if not mime_type:
ext = Path(file_path).suffix.lower()
mime_map = {
".mp3": "audio/mp3",
".wav": "audio/wav",
".aac": "audio/aac",
".ogg": "audio/ogg",
".flac": "audio/flac",
".aiff": "audio/aiff",
}
mime_type = mime_map.get(ext, "application/octet-stream")
return mime_type
def upload_file(
self,
file_path: str,
display_name: Optional[str] = None,
mime_type: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
"""Upload a file using the Files API.
Use this for files larger than 20MB or when you want to reuse files
across multiple requests.
"""
try:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
if mime_type is None:
mime_type = self._get_mime_type(file_path)
num_bytes = os.path.getsize(file_path)
if display_name is None:
display_name = Path(file_path).stem
init_headers = {
"x-goog-api-key": self.api_key,
"X-Goog-Upload-Protocol": "resumable",
"X-Goog-Upload-Command": "start",
"X-Goog-Upload-Header-Content-Length": str(num_bytes),
"X-Goog-Upload-Header-Content-Type": mime_type,
"Content-Type": "application/json",
}
init_data = {"file": {"display_name": display_name}}
upload_endpoint = self._get_endpoint("upload")
response = requests.post(
upload_endpoint, headers=init_headers, json=init_data
)
if response.status_code != 200:
print(
f"Failed to initiate upload: {response.status_code} - {response.text}"
)
return None
upload_url = response.headers.get("x-goog-upload-url")
if not upload_url:
print("No upload URL in response headers")
return None
with open(file_path, "rb") as f:
file_data = f.read()
upload_headers = {
"Content-Length": str(num_bytes),
"X-Goog-Upload-Offset": "0",
"X-Goog-Upload-Command": "upload, finalize",
}
response = requests.post(upload_url, headers=upload_headers, data=file_data)
if response.status_code == 200:
file_info = response.json()
print(
f"File uploaded successfully: {file_info.get('file', {}).get('uri')}"
)
return file_info.get("file")
else:
print(
f"Failed to upload file data: {response.status_code} - {response.text}"
)
return None
except Exception as e:
print(f"Error uploading file: {e}")
return None
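    # Example usage (a sketch with hypothetical paths; requires a real API key):
    #     service = GeminiService(api_key=os.environ["GEMINI_API_KEY"])
    #     file_info = service.upload_file("/data/song.wav")
    #     if file_info:
    #         print(file_info["uri"], file_info["mimeType"])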
def get_file_info(self, file_name: str) -> Optional[Dict[str, Any]]:
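        """Fetch metadata for an uploaded file by its file ID."""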
try:
files_endpoint = self._get_endpoint("files")
url = f"{files_endpoint}/{file_name}"
response = requests.get(url, headers=self._get_headers())
if response.status_code == 200:
return response.json()
return None
except Exception as e:
print(f"Error getting file info: {e}")
return None
def delete_file(self, file_name: str) -> bool:
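        """Delete an uploaded file by its file ID; return True on HTTP 204."""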
try:
files_endpoint = self._get_endpoint("files")
url = f"{files_endpoint}/{file_name}"
response = requests.delete(url, headers=self._get_headers())
return response.status_code == 204
except Exception as e:
print(f"Error deleting file: {e}")
return False
def generate_content(
self,
prompt: Union[str, List[Dict[str, Any]]],
audio: Optional[List[str]] = None,
file_uris: Optional[List[Dict[str, str]]] = None,
system_instruction: Optional[str] = None,
generation_config: Optional[Dict[str, Any]] = None,
history: Optional[List[Dict[str, Any]]] = None,
model_name: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
"""Generate content with audio support.
Args:
prompt: Text prompt or list of content parts
audio: List of audio file paths (for inline data)
file_uris: List of file URIs from uploaded files, each dict with 'mime_type' and 'file_uri'
system_instruction: Optional system instruction to guide behavior
generation_config: Optional generation configuration (temperature, topP, etc.)
history: Optional conversation history for multi-turn chat
model_name: Optional model name to use for this request (overrides default)
"""
try:
body = self._build_request_body(
prompt=prompt,
audio=audio,
file_uris=file_uris,
system_instruction=system_instruction,
generation_config=generation_config,
history=history,
)
endpoint = self._get_endpoint("generate", model_name)
response = requests.post(
endpoint, headers=self._get_headers(), json=body, timeout=300
)
if response.status_code == 200:
return response.json()
else:
print(
f"API call failed: {response.status_code} - {response.text[:500]}"
)
return None
except Exception as e:
print(f"Error generating content: {e}")
return None
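    # Example usage (hypothetical values; assumes `service` is a GeminiService):
    #     resp = service.generate_content(
    #         prompt="Describe this clip.",
    #         audio=["/data/clip.mp3"],
    #         generation_config={"temperature": 0.2},
    #     )
    #     text = service.extract_text(resp) if resp else None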
def stream_generate_content(
self,
prompt: Union[str, List[Dict[str, Any]]],
audio: Optional[List[str]] = None,
file_uris: Optional[List[Dict[str, str]]] = None,
system_instruction: Optional[str] = None,
generation_config: Optional[Dict[str, Any]] = None,
history: Optional[List[Dict[str, Any]]] = None,
model_name: Optional[str] = None,
) -> Generator[Dict[str, Any], None, None]:
"""Stream generate content with server-sent events."""
try:
body = self._build_request_body(
prompt=prompt,
audio=audio,
file_uris=file_uris,
system_instruction=system_instruction,
generation_config=generation_config,
history=history,
)
endpoint = self._get_endpoint("stream", model_name)
response = requests.post(
endpoint,
headers=self._get_headers(),
json=body,
stream=True,
timeout=300,
)
if response.status_code == 200:
for line in response.iter_lines():
if line:
line = line.decode("utf-8")
if line.startswith("data: "):
data = line[6:]
try:
yield json.loads(data)
except json.JSONDecodeError:
continue
else:
print(
f"Streaming failed: {response.status_code} - {response.text[:500]}"
)
except Exception as e:
print(f"Error streaming content: {e}")
def _build_request_body(
self,
prompt: Union[str, List[Dict[str, Any]]],
audio: Optional[List[str]] = None,
file_uris: Optional[List[Dict[str, str]]] = None,
system_instruction: Optional[str] = None,
generation_config: Optional[Dict[str, Any]] = None,
history: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
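        """Assemble the generateContent request body from prompt, media, history, and config."""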
body = {}
if system_instruction:
body["system_instruction"] = {"parts": [{"text": system_instruction}]}
contents = []
if history:
contents.extend(history)
parts = []
if isinstance(prompt, str):
parts.append({"text": prompt})
elif isinstance(prompt, list):
parts.extend(prompt)
if audio:
for audio_path in audio:
mime_type = self._get_mime_type(audio_path)
audio_data = self._encode_file_to_base64(audio_path)
parts.append(
{"inline_data": {"mime_type": mime_type, "data": audio_data}}
)
if file_uris:
for file_ref in file_uris:
parts.append(
{
"file_data": {
"mime_type": file_ref.get("mime_type"),
"file_uri": file_ref.get("file_uri"),
}
}
)
contents.append({"role": "user", "parts": parts})
body["contents"] = contents
if generation_config:
body["generationConfig"] = generation_config
return body
def extract_text(self, response: Dict[str, Any]) -> Optional[str]:
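        """Join the text parts of the first candidate, or return None if absent."""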
try:
if response.get("candidates"):
candidate = response["candidates"][0]
if candidate.get("content", {}).get("parts"):
parts = candidate["content"]["parts"]
texts = [part.get("text", "") for part in parts if "text" in part]
return " ".join(texts)
except Exception as e:
print(f"Error extracting text: {e}")
return None
def analyze_audio(
self,
audio_path: str,
prompt: str,
model_name: Optional[str] = "gemini-3-pro-preview",
use_upload: bool = False,
**kwargs,
) -> Optional[str]:
"""Analyze audio with a text prompt.
Args:
audio_path: Path to the audio file
prompt: Text prompt for analysis
model_name: Model name for analysis
use_upload: Whether to use File API (recommended for audio > 20MB)
"""
generation_config = {
"thinkingConfig": {
"thinkingLevel": "HIGH",
},
"responseMimeType": "application/json",
}
if use_upload:
file_info = self.upload_file(audio_path)
if file_info:
file_uris = [
{
"mime_type": file_info.get("mimeType"),
"file_uri": file_info.get("uri"),
}
]
response = self.generate_content(
prompt=prompt,
file_uris=file_uris,
model_name=model_name,
generation_config=generation_config,
**kwargs,
)
if response is None:
return None
data = self.extract_text(response)
if data is None:
return None
result = data.replace("```json", "").replace("```", "")
return result
return None
else:
response = self.generate_content(
prompt=prompt,
audio=[audio_path],
model_name=model_name,
generation_config=generation_config,
**kwargs,
)
if response is None:
return None
data = self.extract_text(response)
if data is None:
return None
result = data.replace("```json", "").replace("```", "")
return result
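    # Example usage (hypothetical path; `prompt` is the module-level captioning prompt):
    #     raw = service.analyze_audio("/data/song.flac", prompt, use_upload=True)
    #     parsed = json.loads(raw) if raw else None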
def transcribe_audio(
self,
audio_path: str,
use_upload: bool = False,
model_name: Optional[str] = "gemini-3-pro-preview",
**kwargs,
) -> Optional[str]:
"""Transcribe audio to text."""
        # analyze_audio already returns the extracted (fence-stripped) text,
        # so no further extraction step is needed here.
        return self.analyze_audio(
            audio_path=audio_path,
            prompt="Please provide a complete transcription of this audio.",
            use_upload=use_upload,
            model_name=model_name,
            **kwargs,
        )
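    # Example usage (hypothetical path):
    #     transcript = service.transcribe_audio("/data/interview.wav", use_upload=True)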
_gemini_service = None
def get_gemini_service(
api_key: Optional[str] = None,
base_url: str = "https://generativelanguage.googleapis.com",
model_name: str = "gemini-3-flash",
) -> GeminiService:
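    """Return a shared GeminiService instance, creating it on first use."""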
global _gemini_service
    if not api_key:
        # Fall back to the environment variable named in the error message below.
        api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise ValueError(
            "API key must be provided or set in the GEMINI_API_KEY environment variable"
        )
if _gemini_service is None:
_gemini_service = GeminiService(api_key, base_url, model_name)
return _gemini_service
AUDIO_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".aac", ".aiff"}
def extract_json_from_text(text: str) -> Optional[str]:
    """Return the substring from the first "{" to the last "}", or None if absent."""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end == -1 or end < start:
        return None
    return text[start:end + 1]
def analysis_audio_by_gemini(
    api_key: str, base_url: str, audio_path: str, duration=None, max_retry: int = 3
):
    """Analyze one audio file with Gemini, retrying up to max_retry times."""
    global _gemini_service
    if _gemini_service is None:
        _gemini_service = get_gemini_service(api_key, base_url)
    last_result = None
    for _ in range(max_retry):
        result = _gemini_service.analyze_audio(
            audio_path, prompt, model_name="gemini-3-pro-preview"
        )
        if result is None:
            continue
        last_result = result
        try:
            return json.loads(result)
        except json.JSONDecodeError:
            # Fall back to extracting the JSON object substring; callers
            # re-parse string results.
            json_result = extract_json_from_text(result)
            if json_result is not None:
                return json_result
    raise Exception(f"Unable to parse JSON from the model response: {last_result}")
def analysis_audio_to_files(
api_key: str,
base_url: str,
audio_path: str,
output_dir: str,
):
"""Analyze audio and save lyrics and caption as separate txt files.
Output files:
{stem}.lyrics.txt - lyrics
{stem}.caption.txt - caption
Args:
api_key: Gemini API key
base_url: Gemini API base URL
audio_path: Path to the audio file
output_dir: Directory to save output txt files
"""
json_result = analysis_audio_by_gemini(api_key, base_url, audio_path)
if isinstance(json_result, str):
json_result = json.loads(json_result)
stem = Path(audio_path).stem
os.makedirs(output_dir, exist_ok=True)
lyrics = json_result.get("lyrics", "")
caption = json_result.get("caption", "")
lyrics_path = os.path.join(output_dir, f"{stem}.lyrics.txt")
with open(lyrics_path, "w", encoding="utf-8") as f:
f.write(lyrics)
caption_path = os.path.join(output_dir, f"{stem}.caption.txt")
with open(caption_path, "w", encoding="utf-8") as f:
f.write(caption)
return lyrics_path, caption_path
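# Example usage (hypothetical values):
#     lyrics_txt, caption_txt = analysis_audio_to_files(
#         api_key=os.environ["GEMINI_API_KEY"],
#         base_url="https://generativelanguage.googleapis.com",
#         audio_path="/data/song.mp3",
#         output_dir="/data/labels",
#     )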
def process_folder(
input_dir: str,
output_dir: str,
api_key: str,
base_url: str = "https://generativelanguage.googleapis.com",
) -> List[str]:
"""Analyze all audio files in a folder, saving lyrics and caption txt files.
Args:
input_dir: Directory containing audio files
output_dir: Directory to save output txt files
api_key: Gemini API key
base_url: Gemini API base URL
Returns:
List of output file paths
"""
input_path = Path(input_dir)
if not input_path.is_dir():
raise NotADirectoryError(f"Input directory not found: {input_dir}")
os.makedirs(output_dir, exist_ok=True)
audio_files = sorted(
f for f in input_path.iterdir()
if f.is_file() and f.suffix.lower() in AUDIO_EXTENSIONS
)
if not audio_files:
print(f"No audio files found in {input_dir}")
return []
output_paths = []
for i, audio_file in enumerate(audio_files, 1):
print(f"[{i}/{len(audio_files)}] {audio_file.name}")
try:
lyrics_path, caption_path = analysis_audio_to_files(
api_key=api_key,
base_url=base_url,
audio_path=str(audio_file),
output_dir=output_dir,
)
output_paths.extend([lyrics_path, caption_path])
print(f" -> {Path(lyrics_path).name}, {Path(caption_path).name}")
except Exception as e:
print(f" Error: {e}")
print(f"Done: {len(output_paths) // 2}/{len(audio_files)} files processed")
return output_paths
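# Minimal CLI sketch (not part of the original interface): process a folder of
# audio files using the GEMINI_API_KEY environment variable.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate caption/lyrics txt files for a folder of audio with Gemini"
    )
    parser.add_argument("input_dir", help="Directory containing audio files")
    parser.add_argument("output_dir", help="Directory for .lyrics.txt / .caption.txt outputs")
    args = parser.parse_args()

    key = os.environ.get("GEMINI_API_KEY")
    if not key:
        raise SystemExit("Set GEMINI_API_KEY before running")
    process_folder(args.input_dir, args.output_dir, api_key=key)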