Spaces:
Runtime error
Runtime error
| """Benchmark suite for voice model evaluation.""" | |
| import torch | |
| import json | |
| from pathlib import Path | |
| from datetime import datetime | |
| from typing import Dict, Any, List, Optional, Callable | |
| import logging | |
| from .metrics import MetricCalculator | |
| logger = logging.getLogger(__name__) | |
| class BenchmarkSuite: | |
| """ | |
| Comprehensive benchmark suite for voice models. | |
| Evaluates models on multiple metrics and persists results. | |
| """ | |
| def __init__(self, output_dir: str = "results"): | |
| """ | |
| Initialize benchmark suite. | |
| Args: | |
| output_dir: Directory to save benchmark results | |
| """ | |
| self.output_dir = Path(output_dir) | |
| self.output_dir.mkdir(parents=True, exist_ok=True) | |
| self.metric_calculator = MetricCalculator() | |
| self.results_history = [] | |
| logger.info(f"Initialized BenchmarkSuite with output_dir={output_dir}") | |
| def run_benchmark( | |
| self, | |
| model_fn: Callable, | |
| test_data: List[Dict[str, Any]], | |
| model_name: str = "model", | |
| checkpoint_path: Optional[str] = None | |
| ) -> Dict[str, Any]: | |
| """ | |
| Run complete benchmark on a model. | |
| Args: | |
| model_fn: Model inference function | |
| test_data: List of test samples with audio and transcriptions | |
| model_name: Name identifier for the model | |
| checkpoint_path: Path to model checkpoint | |
| Returns: | |
| Dictionary containing all benchmark results | |
| """ | |
| logger.info(f"Running benchmark for {model_name} on {len(test_data)} samples") | |
| start_time = datetime.now() | |
| # Collect predictions and references | |
| predictions = [] | |
| references = [] | |
| audio_pairs = [] | |
| latencies = [] | |
| for sample in test_data: | |
| input_audio = sample['audio'] | |
| reference_text = sample.get('transcription', '') | |
| reference_audio = sample.get('reference_audio', input_audio) | |
| # Measure inference latency | |
| import time | |
| start = time.perf_counter() | |
| output = model_fn(input_audio) | |
| end = time.perf_counter() | |
| latencies.append((end - start) * 1000) | |
| # Extract prediction | |
| if isinstance(output, dict): | |
| pred_text = output.get('transcription', '') | |
| pred_audio = output.get('audio', input_audio) | |
| else: | |
| pred_text = '' | |
| pred_audio = output if isinstance(output, torch.Tensor) else input_audio | |
| predictions.append(pred_text) | |
| references.append(reference_text) | |
| audio_pairs.append((pred_audio, reference_audio)) | |
| # Compute metrics | |
| results = self.compute_metrics( | |
| predictions=predictions, | |
| references=references, | |
| audio_pairs=audio_pairs | |
| ) | |
| # Add latency metrics | |
| results['inference_time_ms'] = sum(latencies) / len(latencies) if latencies else 0.0 | |
| results['samples_per_second'] = len(test_data) / (sum(latencies) / 1000) if latencies else 0.0 | |
| # Add metadata | |
| results['timestamp'] = start_time.isoformat() | |
| results['model_name'] = model_name | |
| results['model_checkpoint'] = checkpoint_path | |
| results['num_samples'] = len(test_data) | |
| # Save results | |
| self._save_results(results, model_name) | |
| self.results_history.append(results) | |
| logger.info(f"Benchmark complete. WER: {results.get('word_error_rate', 'N/A'):.4f}") | |
| return results | |
| def compute_metrics( | |
| self, | |
| predictions: List[str], | |
| references: List[str], | |
| audio_pairs: Optional[List[tuple]] = None | |
| ) -> Dict[str, float]: | |
| """ | |
| Compute all metrics for predictions. | |
| Args: | |
| predictions: List of predicted transcriptions | |
| references: List of reference transcriptions | |
| audio_pairs: Optional list of (generated, reference) audio pairs | |
| Returns: | |
| Dictionary of metric names and values | |
| """ | |
| metrics = {} | |
| # Text-based metrics | |
| if predictions and references: | |
| try: | |
| metrics['word_error_rate'] = self.metric_calculator.compute_word_error_rate( | |
| predictions, references | |
| ) | |
| except Exception as e: | |
| logger.warning(f"Failed to compute WER: {e}") | |
| metrics['word_error_rate'] = float('nan') | |
| try: | |
| metrics['character_error_rate'] = self.metric_calculator.compute_character_error_rate( | |
| predictions, references | |
| ) | |
| except Exception as e: | |
| logger.warning(f"Failed to compute CER: {e}") | |
| metrics['character_error_rate'] = float('nan') | |
| # Audio-based metrics | |
| if audio_pairs: | |
| mcd_scores = [] | |
| pesq_scores = [] | |
| for gen_audio, ref_audio in audio_pairs: | |
| if isinstance(gen_audio, torch.Tensor) and isinstance(ref_audio, torch.Tensor): | |
| try: | |
| mcd = self.metric_calculator.compute_mel_cepstral_distortion( | |
| gen_audio, ref_audio | |
| ) | |
| mcd_scores.append(mcd) | |
| except Exception as e: | |
| logger.warning(f"Failed to compute MCD: {e}") | |
| try: | |
| pesq = self.metric_calculator.compute_perceptual_quality( | |
| gen_audio, ref_audio | |
| ) | |
| pesq_scores.append(pesq) | |
| except Exception as e: | |
| logger.warning(f"Failed to compute PESQ: {e}") | |
| if mcd_scores: | |
| metrics['mel_cepstral_distortion'] = sum(mcd_scores) / len(mcd_scores) | |
| if pesq_scores: | |
| metrics['perceptual_evaluation_speech_quality'] = sum(pesq_scores) / len(pesq_scores) | |
| return metrics | |
| def _save_results(self, results: Dict[str, Any], model_name: str) -> None: | |
| """ | |
| Save benchmark results to file. | |
| Args: | |
| results: Results dictionary | |
| model_name: Model identifier | |
| """ | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | |
| filename = f"benchmark_{model_name}_{timestamp}.json" | |
| filepath = self.output_dir / filename | |
| # Convert any non-serializable values | |
| serializable_results = {} | |
| for key, value in results.items(): | |
| if isinstance(value, (int, float, str, bool, type(None))): | |
| serializable_results[key] = value | |
| elif isinstance(value, datetime): | |
| serializable_results[key] = value.isoformat() | |
| else: | |
| serializable_results[key] = str(value) | |
| with open(filepath, 'w') as f: | |
| json.dump(serializable_results, f, indent=2) | |
| logger.info(f"Results saved to {filepath}") | |
| def load_results(self, filepath: str) -> Dict[str, Any]: | |
| """ | |
| Load benchmark results from file. | |
| Args: | |
| filepath: Path to results file | |
| Returns: | |
| Results dictionary | |
| """ | |
| with open(filepath, 'r') as f: | |
| results = json.load(f) | |
| return results | |
| def get_latest_results(self, model_name: Optional[str] = None) -> Optional[Dict[str, Any]]: | |
| """ | |
| Get the most recent benchmark results. | |
| Args: | |
| model_name: Optional model name filter | |
| Returns: | |
| Latest results dictionary or None | |
| """ | |
| if not self.results_history: | |
| return None | |
| if model_name: | |
| filtered = [r for r in self.results_history if r.get('model_name') == model_name] | |
| return filtered[-1] if filtered else None | |
| return self.results_history[-1] | |