# Hugging Face Spaces status banner: Running on Zero
| import argparse | |
| import spaces | |
| import sys | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| import torchaudio | |
# Put the directory ``src`` (next to this file) at the front of the import
# path so that ``sfi_utmos`` can be imported. This must run before the
# ``sfi_utmos`` import below.
project_root = Path(__file__).resolve().parent
sys.path.insert(0, str(project_root / "src"))
from sfi_utmos.model.ssl_mos import SSLMOSLightningModule
# Module-level model handle: ``None`` until ``load_model`` assigns it;
# ``predict_mos`` reads it and reports an error while it is still ``None``.
model: SSLMOSLightningModule | None = None
# Device string passed as ``map_location`` when the checkpoint is loaded.
# NOTE(review): hard-coded to CUDA; there is no CPU fallback in this file.
device = "cuda"
def load_model(checkpoint_path: str):
    """Restore the MOS predictor and publish it through the global ``model``.

    Args:
        checkpoint_path: Location of the ``.ckpt`` file handed to
            ``SSLMOSLightningModule.load_from_checkpoint``.

    The checkpoint is mapped onto the module-level ``device`` and the
    restored module is switched to evaluation mode before the confirmation
    message is printed.
    """
    global model
    restore_options = {
        "map_location": device,
        # Passed through to the module unchanged, exactly as the checkpoint
        # loader receives it.
        "pretrained_model_path": None,
    }
    model = SSLMOSLightningModule.load_from_checkpoint(
        checkpoint_path, **restore_options
    )
    model.eval()
    print(f"Model loaded from {checkpoint_path}")
def predict_mos(audio_path: str | None) -> str:
    """Predict the MOS score for an audio file and return it as text.

    The audio is loaded and preprocessed once, then scored once for each
    listener id in ``1..10``; the ten raw outputs are averaged and rescaled
    with ``2 * mean + 3``.

    Args:
        audio_path: Path of the audio file to score. ``None`` or an empty
            string (no file supplied) yields an error message instead of an
            exception.

    Returns:
        The score formatted with three decimals, or a string starting with
        ``"Error:"`` when the model is not loaded, no audio was supplied, or
        the file's sample rate is not one the model was conditioned on.
    """
    if model is None:
        return "Error: Model not loaded. Please provide a valid checkpoint path."
    if not audio_path:
        # The UI can invoke the callback without a file; fail gracefully
        # rather than crashing inside ``torchaudio.load``.
        return "Error: No audio file provided."

    # Everything below up to the listener loop is independent of the listener
    # id, so it is done once rather than once per listener.
    wav, sr = torchaudio.load(audio_path)
    if model.condition_sr and sr not in model.sr2id:
        return f"Error: Sample rate {sr} not supported by the model. Supported rates: {list(model.sr2id.keys())}"

    target_device = model.device

    # ``torchaudio.load`` yields (channels, frames). Flattening a multi-channel
    # tensor would play the channels back to back, so downmix by averaging;
    # single-channel audio is passed through untouched.
    mono = wav.mean(dim=0) if wav.size(0) > 1 else wav.view(-1)
    waves = [mono.to(target_device)]

    srs = torch.tensor(sr).view(1, -1).to(target_device)
    if model.condition_sr:
        # Replace the raw sample rate by the model's id for that rate.
        srs = torch.stack([torch.tensor(model.sr2id[sr])]).to(target_device)

    if getattr(model, "is_sfi", False):
        # NOTE(review): when ``condition_sr`` is set, ``srs`` holds the
        # sample-rate *id* here, not the rate in Hz — confirm that is what
        # ``set_sample_rate`` expects.
        model.ssl_model.set_sample_rate(srs[0].item())
        waves = torch.nn.utils.rnn.pad_sequence(waves, batch_first=True).to(
            target_device
        )
    else:
        waves = [torchaudio.functional.resample(w, sr, 16_000) for w in waves]

    ratings = []
    # Pure inference: gradients are never used, so do not track them.
    with torch.no_grad():
        for listener_id in range(1, 11):
            listener_tensor = torch.tensor(listener_id).view(-1).to(target_device)
            output = model.forward(
                waves,
                listener_tensor,
                srs,
            )
            ratings.append(output.cpu().item())

    # Presumably maps a normalized model output back onto the 1-5 MOS scale
    # (TODO confirm against the training-time normalization).
    mos_score = 2 * (sum(ratings) / len(ratings)) + 3
    return f"{mos_score:.3f}"
def main():
    """Command-line entry point: load the checkpoint, then serve the demo.

    Parses ``--checkpoint_path`` (defaulting to the hosted checkpoint URL),
    loads the model through ``load_model``, exits with status 1 if the global
    ``model`` is still unset afterwards, and finally builds and launches the
    Gradio interface that wraps ``predict_mos``.
    """
    cli = argparse.ArgumentParser(description="Run MOS prediction demo with Gradio.")
    cli.add_argument(
        "--checkpoint_path",
        type=str,
        default="https://huggingface.co/sarulab-speech/MSR-UTMOS_w2v2_fold0/resolve/main/model.ckpt",
        help="Path to the model checkpoint (.ckpt file).",
    )
    options = cli.parse_args()

    load_model(options.checkpoint_path)
    if model is None:
        print("Failed to load model. Exiting.")
        sys.exit(1)

    # Assemble the web UI: one audio upload in, one text field out.
    audio_input = gr.Audio(type="filepath", label="Upload Audio File")
    blurb = "Upload an audio file (WAV, MP3, etc.) to get its predicted Mean Opinion Score (MOS). "
    demo = gr.Interface(
        fn=predict_mos,
        inputs=audio_input,
        outputs="text",
        title="MSR-UTMOS: MOS Prediction Demo",
        description=blurb,
    )
    demo.launch()


if __name__ == "__main__":
    main()