import os
import gc
import json
import logging
import tempfile
from datetime import datetime, timedelta
from pathlib import Path
from dataclasses import dataclass

import gradio as gr
import whisper
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import soundfile as sf
import humanize
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
MAX_FILE_SIZE = 25 * 1024 * 1024  # 25MB
MAX_AUDIO_DURATION = 600  # 10 minutes
MIN_SAMPLE_RATE = 16000  # 16kHz
SUPPORTED_FORMATS = {'.wav', '.mp3', '.m4a'}

# Model configuration
MODEL_CONFIG = {
    "path": "gpt2",
    "description": "Efficient open-source model for analysis",
    "memory_required": "~1GB"  # gpt2-base weights are roughly 500MB
}
@dataclass
class VCStyle:
    name: str
    note_format: dict
    key_interests: list
    custom_sections: list
    insight_preferences: dict
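# A minimal sketch of constructing a profile (all field values here are
# illustrative, not defaults from the app):
#
#   style = VCStyle(
#       name="Jane Doe",
#       note_format={"style": "Bullet Points"},
#       key_interests=["Market", "Team"],
#       custom_sections=[],
#       insight_preferences={},
#   )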
class AudioValidator:
    """Validates an uploaded audio file against size, format, and quality limits."""

    def validate_audio_file(self, file):
        stats = {
            'file_size': None,
            'duration': None,
            'sample_rate': None,
            'format': None
        }
        try:
            if file is None:
                logger.debug("No file was uploaded.")
                return False, "No file was uploaded.", stats

            # Check file size
            file_size = len(file.read())
            file.seek(0)  # Reset file pointer
            stats['file_size'] = humanize.naturalsize(file_size)
            logger.info(f"File size: {stats['file_size']}")
            if file_size > MAX_FILE_SIZE:
                logger.warning(f"File size exceeds limit: {stats['file_size']}")
                return False, f"File size ({stats['file_size']}) exceeds limit", stats

            # Check file extension
            file_extension = Path(file.name).suffix.lower()
            stats['format'] = file_extension
            logger.info(f"File format: {file_extension}")
            if file_extension not in SUPPORTED_FORMATS:
                logger.warning(f"Unsupported format: {file_extension}")
                return False, f"Unsupported format {file_extension}", stats

            # Copy to a temporary file so soundfile can inspect it
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
                tmp_file.write(file.read())
                tmp_file_path = tmp_file.name
            file.seek(0)  # Reset again so callers can re-read the stream
            logger.debug(f"Temporary file created at {tmp_file_path}")

            try:
                # Check audio properties. NOTE: libsndfile may not decode every
                # listed format (e.g. .m4a); failures land in the except branch below.
                y, sr = sf.read(tmp_file_path)
                duration = len(y) / sr
                stats.update({
                    'duration': str(timedelta(seconds=int(duration))),
                    'sample_rate': f"{sr / 1000:.1f}kHz"
                })
                logger.info(f"Audio duration: {stats['duration']}, Sample rate: {stats['sample_rate']}")
                if duration > MAX_AUDIO_DURATION:
                    logger.warning(f"Duration exceeds limit: {stats['duration']}")
                    return False, f"Duration ({stats['duration']}) exceeds limit", stats
                if sr < MIN_SAMPLE_RATE:
                    logger.warning(f"Sample rate too low: {stats['sample_rate']}")
                    return False, f"Sample rate too low ({stats['sample_rate']})", stats
                return True, "Audio file is valid", stats
            finally:
                os.unlink(tmp_file_path)
                logger.debug(f"Temporary file {tmp_file_path} deleted")
        except Exception as e:
            logger.exception("Validation error:")
            return False, str(e), stats
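# A minimal sketch of exercising the validator on its own, assuming a local
# file named sample.wav exists (the name is hypothetical):
#
#   validator = AudioValidator()
#   with open("sample.wav", "rb") as f:
#       ok, message, stats = validator.validate_audio_file(f)
#   print(ok, message, stats)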
class AudioProcessor:
    def __init__(self, model):
        self.model = model
        self.validator = AudioValidator()

    def process_audio(self, audio_file):
        start_time = datetime.now()
        stats = {
            'status': 'processing',
            'start_time': start_time.isoformat(),  # stored as a string so stats stay JSON-serializable
            'file_info': None,
            'processing_time': None,
            'error': None
        }
        try:
            # Validate file
            logger.debug("Starting audio file validation.")
            is_valid, message, file_stats = self.validator.validate_audio_file(audio_file)
            stats['file_info'] = file_stats
            if not is_valid:
                stats['status'] = 'failed'
                stats['error'] = message
                logger.error(f"Audio validation failed: {message}")
                return None, stats

            # Process audio
            audio_file.seek(0)  # defensive reset in case the stream was left mid-read
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_stats['format']) as tmp_file:
                tmp_file.write(audio_file.read())
                tmp_file_path = tmp_file.name
            logger.debug(f"Temporary file for processing created at {tmp_file_path}")

            try:
                logger.info("Starting transcription with Whisper model.")
                result = self.model.transcribe(
                    tmp_file_path,
                    language="en",
                    task="transcribe",
                    fp16=torch.cuda.is_available()
                )
                stats['status'] = 'success'
                stats['processing_time'] = str(datetime.now() - start_time)
                logger.info(f"Transcription successful. Processing time: {stats['processing_time']}")
                return result["text"], stats
            finally:
                os.unlink(tmp_file_path)
                logger.debug(f"Temporary file {tmp_file_path} deleted after processing")
        except Exception as e:
            logger.exception("Processing error:")
            stats['status'] = 'failed'
            stats['error'] = str(e)
            return None, stats
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.debug("Cleared CUDA cache")
            gc.collect()
            logger.debug("Garbage collection complete")
# Module-level caches so each model loads once per process instead of once per
# request. (Whisper already caches its downloaded weights in ~/.cache/whisper,
# so pickling the model object to disk buys nothing and is fragile.)
_WHISPER_MODEL = None

def load_whisper():
    global _WHISPER_MODEL
    try:
        if _WHISPER_MODEL is not None:
            return _WHISPER_MODEL
        logger.info("Loading Whisper model.")
        _WHISPER_MODEL = whisper.load_model("base")
        logger.info("Whisper model loaded and cached in memory.")
        return _WHISPER_MODEL
    except Exception as e:
        logger.error(f"Whisper model loading error: {str(e)}")
        return None
_LLM_PIPELINE = None

def load_llm():
    global _LLM_PIPELINE
    try:
        if _LLM_PIPELINE is not None:
            return _LLM_PIPELINE
        logger.info("Loading LLM model.")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG["path"])
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_CONFIG["path"],
            device_map="auto",  # requires the `accelerate` package
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True
        )
        logger.info("Initializing text generation pipeline.")
        _LLM_PIPELINE = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            do_sample=True,  # temperature/top_p are ignored without sampling
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15,
            return_full_text=False,  # return only the completion, not the prompt
            batch_size=1
        )
        return _LLM_PIPELINE
    except Exception as e:
        logger.error(f"LLM loading error: {str(e)}")
        return None
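# The returned pipeline is called with a plain string prompt; with
# return_full_text=False it yields only the completion. A sketch:
#
#   generator = load_llm()
#   if generator is not None:
#       out = generator("Summarize: revenue grew 40% quarter over quarter.")
#       print(out[0]['generated_text'])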
class ContentAnalyzer:
    def __init__(self, generator):
        self.generator = generator

    def analyze_text(self, text, vc_style):
        try:
            logger.info("Creating analysis prompt.")
            prompt = self._create_analysis_prompt(text, vc_style)
            logger.debug(f"Prompt created: {prompt}")
            response = self._generate_response(prompt)
            logger.info("Analysis response generated.")
            return self._parse_response(response)
        except Exception as e:
            logger.exception("Analysis error:")
            return None

    def _create_analysis_prompt(self, text, vc_style):
        interests = ', '.join(vc_style.key_interests)
        return f"""Analyze this startup pitch focusing on {interests}:

{text}

Provide structured insights for:
1. Key Points
2. Metrics
3. Risks
4. Questions"""

    def _generate_response(self, prompt):
        try:
            logger.info("Generating response using LLM.")
            response = self.generator(prompt)
            logger.debug(f"Generated response: {response}")
            return response[0]['generated_text']
        except Exception as e:
            logger.exception("Generation error:")
            return ""

    def _parse_response(self, response):
        try:
            logger.info("Parsing generated response.")
            sections = response.split('\n\n')
            parsed = {}
            current_section = "general"
            for section in sections:
                if section.strip().endswith(':'):
                    current_section = section.strip()[:-1].lower()
                    parsed[current_section] = []
                else:
                    if current_section in parsed:
                        parsed[current_section].append(section.strip())
                    else:
                        parsed[current_section] = [section.strip()]
            logger.debug(f"Parsed response: {parsed}")
            return parsed
        except Exception as e:
            logger.exception("Parsing error:")
            return {"error": "Failed to parse response"}
def process_audio_file(audio_file, vc_name, note_style, interests):
    logger.info("Processing audio file.")
    whisper_model = load_whisper()
    llm = load_llm()
    if not whisper_model or not llm:
        logger.error("Failed to load models.")
        # Always return three values so the caller can unpack consistently
        return None, None, {'status': 'failed', 'error': 'Failed to load models. Please try again.'}
    audio_processor = AudioProcessor(whisper_model)
    analyzer = ContentAnalyzer(llm)
    transcription, stats = audio_processor.process_audio(audio_file)
    if transcription and stats['status'] == 'success':
        logger.info("Transcription successful, starting analysis.")
        vc_style = VCStyle(
            name=vc_name,
            note_format={"style": note_style},
            key_interests=interests,
            custom_sections=[],
            insight_preferences={}
        )
        analysis = analyzer.analyze_text(transcription, vc_style)
        return transcription, analysis, stats
    else:
        logger.error(f"Audio processing failed: {stats['error']}")
        return None, None, stats
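# End-to-end smoke test outside Gradio (pitch.mp3 is a hypothetical file):
#
#   with open("pitch.mp3", "rb") as f:
#       transcript, analysis, stats = process_audio_file(
#           f, "Jane Doe", "Bullet Points", ["Market", "Team"]
#       )
#   print(json.dumps(stats, indent=2))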
# Gradio Interface
def main_interface(audio_file, vc_name, note_style, interests):
    logger.info("Starting main interface process.")
    if audio_file is None:
        return "", "", {'status': 'failed', 'error': 'No file was uploaded.'}
    # gr.Audio(type="filepath") hands us a path string; open it so the rest
    # of the pipeline can keep working with a file-like object.
    with open(audio_file, "rb") as f:
        transcription, analysis, stats = process_audio_file(f, vc_name, note_style, interests)
    if transcription:
        logger.info("Interface processing completed successfully.")
        return transcription, json.dumps(analysis, indent=2), stats
    else:
        logger.error("Interface processing failed.")
        return "", "", stats

iface = gr.Interface(
    fn=main_interface,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File (WAV, MP3, M4A)"),
        gr.Textbox(label="Your Name"),
        gr.Dropdown(choices=["Bullet Points", "Paragraphs", "Q&A"], label="Note Style"),
        gr.CheckboxGroup(choices=["Product", "Market", "Team", "Financials", "Technology"], label="Focus Areas")
    ],
    outputs=[
        gr.Textbox(label="Transcript"),
        gr.Textbox(label="Analysis"),
        gr.JSON(label="Processing Stats")
    ],
    title="VC Call Assistant",
    description="Upload an audio file, and get a transcript along with analysis tailored to your focus areas.",
    theme=gr.themes.Default()
)

if __name__ == "__main__":
    logger.info("Launching Gradio interface.")
    iface.launch()
# requirements.txt
# gradio
# openai-whisper  # provides the `whisper` module; the PyPI package named `whisper` is unrelated
# torch
# transformers
# accelerate  # needed for device_map="auto" in from_pretrained
# numpy
# soundfile
# humanize
# huggingface_hub
# SentencePiece  # required by some models in transformers
# typing-extensions
#
# Whisper also needs the ffmpeg *binary* on PATH; on Hugging Face Spaces,
# list `ffmpeg` in packages.txt.