"""Gemini service for audio analysis via the Gemini API. This module provides a service for interacting with Gemini's API, supporting audio inputs, file uploads, streaming, and structured outputs. """ import os import base64 import json import mimetypes from typing import Optional, List, Dict, Any, Tuple, Union, Generator from pathlib import Path import requests prompt = """Analyze the input audio to generate detailed caption and lyrics. lyrics need contain structured tags for chorus, verse, bridge, etc. **Output Format:** ```json { "caption": , "lyrics": "[Intro] , [Verse] ..." } ``` """ class GeminiService: """Service for handling Gemini API audio interactions.""" SUPPORTED_AUDIO_TYPES = [ "audio/wav", "audio/mp3", "audio/aiff", "audio/aac", "audio/ogg", "audio/flac", ] def __init__( self, api_key: str, base_url: str = "https://generativelanguage.googleapis.com", model_name: str = "gemini-3-flash", ): self.api_key = api_key self.base_url = base_url.rstrip("/") self.model_name = model_name def _get_endpoint( self, endpoint_type: str = "generate", model_name: Optional[str] = None ) -> str: model = model_name or self.model_name if endpoint_type == "generate": return f"{self.base_url}/v1beta/models/{model}:generateContent" elif endpoint_type == "stream": return ( f"{self.base_url}/v1beta/models/{model}:streamGenerateContent?alt=sse" ) elif endpoint_type == "upload": return f"{self.base_url}/upload/v1beta/files" elif endpoint_type == "files": return f"{self.base_url}/v1beta/files" else: raise ValueError(f"Unknown endpoint type: {endpoint_type}") def _get_headers(self) -> Dict[str, str]: return {"x-goog-api-key": self.api_key, "Content-Type": "application/json"} def _encode_file_to_base64(self, file_path: str) -> str: with open(file_path, "rb") as f: return base64.b64encode(f.read()).decode("utf-8") def _get_mime_type(self, file_path: str) -> str: mime_type, _ = mimetypes.guess_type(file_path) if not mime_type: ext = Path(file_path).suffix.lower() mime_map = { ".mp3": "audio/mp3", ".wav": "audio/wav", ".aac": "audio/aac", ".ogg": "audio/ogg", ".flac": "audio/flac", ".aiff": "audio/aiff", } mime_type = mime_map.get(ext, "application/octet-stream") return mime_type def upload_file( self, file_path: str, display_name: Optional[str] = None, mime_type: Optional[str] = None, ) -> Optional[Dict[str, Any]]: """Upload a file using the Files API. Use this for files larger than 20MB or when you want to reuse files across multiple requests. """ try: if not os.path.exists(file_path): raise FileNotFoundError(f"File not found: {file_path}") if mime_type is None: mime_type = self._get_mime_type(file_path) num_bytes = os.path.getsize(file_path) if display_name is None: display_name = Path(file_path).stem init_headers = { "x-goog-api-key": self.api_key, "X-Goog-Upload-Protocol": "resumable", "X-Goog-Upload-Command": "start", "X-Goog-Upload-Header-Content-Length": str(num_bytes), "X-Goog-Upload-Header-Content-Type": mime_type, "Content-Type": "application/json", } init_data = {"file": {"display_name": display_name}} upload_endpoint = self._get_endpoint("upload") response = requests.post( upload_endpoint, headers=init_headers, json=init_data ) if response.status_code != 200: print( f"Failed to initiate upload: {response.status_code} - {response.text}" ) return None upload_url = response.headers.get("x-goog-upload-url") if not upload_url: print("No upload URL in response headers") return None with open(file_path, "rb") as f: file_data = f.read() upload_headers = { "Content-Length": str(num_bytes), "X-Goog-Upload-Offset": "0", "X-Goog-Upload-Command": "upload, finalize", } response = requests.post(upload_url, headers=upload_headers, data=file_data) if response.status_code == 200: file_info = response.json() print( f"File uploaded successfully: {file_info.get('file', {}).get('uri')}" ) return file_info.get("file") else: print( f"Failed to upload file data: {response.status_code} - {response.text}" ) return None except Exception as e: print(f"Error uploading file: {e}") return None def get_file_info(self, file_name: str) -> Optional[Dict[str, Any]]: try: files_endpoint = self._get_endpoint("files") url = f"{files_endpoint}/{file_name}" response = requests.get(url, headers=self._get_headers()) if response.status_code == 200: return response.json() return None except Exception as e: print(f"Error getting file info: {e}") return None def delete_file(self, file_name: str) -> bool: try: files_endpoint = self._get_endpoint("files") url = f"{files_endpoint}/{file_name}" response = requests.delete(url, headers=self._get_headers()) return response.status_code == 204 except Exception as e: print(f"Error deleting file: {e}") return False def generate_content( self, prompt: Union[str, List[Dict[str, Any]]], audio: Optional[List[str]] = None, file_uris: Optional[List[Dict[str, str]]] = None, system_instruction: Optional[str] = None, generation_config: Optional[Dict[str, Any]] = None, history: Optional[List[Dict[str, Any]]] = None, model_name: Optional[str] = None, ) -> Optional[Dict[str, Any]]: """Generate content with audio support. Args: prompt: Text prompt or list of content parts audio: List of audio file paths (for inline data) file_uris: List of file URIs from uploaded files, each dict with 'mime_type' and 'file_uri' system_instruction: Optional system instruction to guide behavior generation_config: Optional generation configuration (temperature, topP, etc.) history: Optional conversation history for multi-turn chat model_name: Optional model name to use for this request (overrides default) """ try: body = self._build_request_body( prompt=prompt, audio=audio, file_uris=file_uris, system_instruction=system_instruction, generation_config=generation_config, history=history, ) endpoint = self._get_endpoint("generate", model_name) response = requests.post( endpoint, headers=self._get_headers(), json=body, timeout=300 ) if response.status_code == 200: return response.json() else: print( f"API call failed: {response.status_code} - {response.text[:500]}" ) return None except Exception as e: print(f"Error generating content: {e}") return None def stream_generate_content( self, prompt: Union[str, List[Dict[str, Any]]], audio: Optional[List[str]] = None, file_uris: Optional[List[Dict[str, str]]] = None, system_instruction: Optional[str] = None, generation_config: Optional[Dict[str, Any]] = None, history: Optional[List[Dict[str, Any]]] = None, model_name: Optional[str] = None, ) -> Generator[Dict[str, Any], None, None]: """Stream generate content with server-sent events.""" try: body = self._build_request_body( prompt=prompt, audio=audio, file_uris=file_uris, system_instruction=system_instruction, generation_config=generation_config, history=history, ) endpoint = self._get_endpoint("stream", model_name) response = requests.post( endpoint, headers=self._get_headers(), json=body, stream=True, timeout=300, ) if response.status_code == 200: for line in response.iter_lines(): if line: line = line.decode("utf-8") if line.startswith("data: "): data = line[6:] try: yield json.loads(data) except json.JSONDecodeError: continue else: print( f"Streaming failed: {response.status_code} - {response.text[:500]}" ) except Exception as e: print(f"Error streaming content: {e}") def _build_request_body( self, prompt: Union[str, List[Dict[str, Any]]], audio: Optional[List[str]] = None, file_uris: Optional[List[Dict[str, str]]] = None, system_instruction: Optional[str] = None, generation_config: Optional[Dict[str, Any]] = None, history: Optional[List[Dict[str, Any]]] = None, ) -> Dict[str, Any]: body = {} if system_instruction: body["system_instruction"] = {"parts": [{"text": system_instruction}]} contents = [] if history: contents.extend(history) parts = [] if isinstance(prompt, str): parts.append({"text": prompt}) elif isinstance(prompt, list): parts.extend(prompt) if audio: for audio_path in audio: mime_type = self._get_mime_type(audio_path) audio_data = self._encode_file_to_base64(audio_path) parts.append( {"inline_data": {"mime_type": mime_type, "data": audio_data}} ) if file_uris: for file_ref in file_uris: parts.append( { "file_data": { "mime_type": file_ref.get("mime_type"), "file_uri": file_ref.get("file_uri"), } } ) contents.append({"role": "user", "parts": parts}) body["contents"] = contents if generation_config: body["generationConfig"] = generation_config return body def extract_text(self, response: Dict[str, Any]) -> Optional[str]: try: if response.get("candidates"): candidate = response["candidates"][0] if candidate.get("content", {}).get("parts"): parts = candidate["content"]["parts"] texts = [part.get("text", "") for part in parts if "text" in part] return " ".join(texts) except Exception as e: print(f"Error extracting text: {e}") return None def analyze_audio( self, audio_path: str, prompt: str, model_name: Optional[str] = "gemini-3-pro-preview", use_upload: bool = False, **kwargs, ) -> Optional[str]: """Analyze audio with a text prompt. Args: audio_path: Path to the audio file prompt: Text prompt for analysis model_name: Model name for analysis use_upload: Whether to use File API (recommended for audio > 20MB) """ generation_config = { "thinkingConfig": { "thinkingLevel": "HIGH", }, "responseMimeType": "application/json", } if use_upload: file_info = self.upload_file(audio_path) if file_info: file_uris = [ { "mime_type": file_info.get("mimeType"), "file_uri": file_info.get("uri"), } ] response = self.generate_content( prompt=prompt, file_uris=file_uris, model_name=model_name, generation_config=generation_config, **kwargs, ) if response is None: return None data = self.extract_text(response) if data is None: return None result = data.replace("```json", "").replace("```", "") return result return None else: response = self.generate_content( prompt=prompt, audio=[audio_path], model_name=model_name, generation_config=generation_config, **kwargs, ) if response is None: return None data = self.extract_text(response) if data is None: return None result = data.replace("```json", "").replace("```", "") return result def transcribe_audio( self, audio_path: str, use_upload: bool = False, model_name: Optional[str] = "gemini-3-pro-preview", **kwargs, ) -> Optional[str]: """Transcribe audio to text.""" response = self.analyze_audio( audio_path=audio_path, prompt="Please provide a complete transcription of this audio.", use_upload=use_upload, model_name=model_name, **kwargs, ) return self.extract_text(response) if response else None _gemini_service = None def get_gemini_service( api_key: Optional[str] = None, base_url: str = "https://generativelanguage.googleapis.com", model_name: str = "gemini-3-flash", ) -> GeminiService: global _gemini_service if not api_key: raise ValueError( "API key must be provided or set in GEMINI_API_KEY environment variable" ) if _gemini_service is None: _gemini_service = GeminiService(api_key, base_url, model_name) return _gemini_service AUDIO_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".aac", ".aiff"} def extract_json_from_text(text: str): start = text.find("{") end = text.rfind("}") if start == -1 or end == -1: return None return text[start: end + 1] def analysis_audio_by_gemini( api_key: str, base_url: str, audio_path: str, duration=None, max_retry: int = 3 ): global _gemini_service if _gemini_service is None: _gemini_service = get_gemini_service(api_key, base_url) result = _gemini_service.analyze_audio( audio_path, prompt, model_name="gemini-3-pro-preview" ) try: json_result = json.loads(result) except: json_result = extract_json_from_text(result) if json_result is None: raise Exception(f"无法解析json, {result}") return json_result def analysis_audio_to_files( api_key: str, base_url: str, audio_path: str, output_dir: str, ): """Analyze audio and save lyrics and caption as separate txt files. Output files: {stem}.lyrics.txt - lyrics {stem}.caption.txt - caption Args: api_key: Gemini API key base_url: Gemini API base URL audio_path: Path to the audio file output_dir: Directory to save output txt files """ json_result = analysis_audio_by_gemini(api_key, base_url, audio_path) if isinstance(json_result, str): json_result = json.loads(json_result) stem = Path(audio_path).stem os.makedirs(output_dir, exist_ok=True) lyrics = json_result.get("lyrics", "") caption = json_result.get("caption", "") lyrics_path = os.path.join(output_dir, f"{stem}.lyrics.txt") with open(lyrics_path, "w", encoding="utf-8") as f: f.write(lyrics) caption_path = os.path.join(output_dir, f"{stem}.caption.txt") with open(caption_path, "w", encoding="utf-8") as f: f.write(caption) return lyrics_path, caption_path def process_folder( input_dir: str, output_dir: str, api_key: str, base_url: str = "https://generativelanguage.googleapis.com", ) -> List[str]: """Analyze all audio files in a folder, saving lyrics and caption txt files. Args: input_dir: Directory containing audio files output_dir: Directory to save output txt files api_key: Gemini API key base_url: Gemini API base URL Returns: List of output file paths """ input_path = Path(input_dir) if not input_path.is_dir(): raise NotADirectoryError(f"Input directory not found: {input_dir}") os.makedirs(output_dir, exist_ok=True) audio_files = sorted( f for f in input_path.iterdir() if f.is_file() and f.suffix.lower() in AUDIO_EXTENSIONS ) if not audio_files: print(f"No audio files found in {input_dir}") return [] output_paths = [] for i, audio_file in enumerate(audio_files, 1): print(f"[{i}/{len(audio_files)}] {audio_file.name}") try: lyrics_path, caption_path = analysis_audio_to_files( api_key=api_key, base_url=base_url, audio_path=str(audio_file), output_dir=output_dir, ) output_paths.extend([lyrics_path, caption_path]) print(f" -> {Path(lyrics_path).name}, {Path(caption_path).name}") except Exception as e: print(f" Error: {e}") print(f"Done: {len(output_paths) // 2}/{len(audio_files)} files processed") return output_paths