| import argparse | |
| import torch | |
| import torchaudio | |
| from pathlib import Path | |
| from spectral_ops import STFT, iSTFT | |
| from model import Renaissance | |
def load_and_preprocess_audio(input_path, device, dtype):
    """Load an audio file and prepare it for the enhancement model.

    Pipeline: downmix to mono (if multi-channel), resample to 48 kHz
    (if needed), apply a 60 Hz high-pass biquad, then move the result
    to *device* and cast to *dtype*.

    Args:
        input_path: Path of the audio file to load.
        device: Target torch device for the returned tensor.
        dtype: Target dtype for the returned tensor (e.g. torch.float16).

    Returns:
        A (1, num_samples) waveform tensor at 48 kHz on *device*.
    """
    waveform, sr = torchaudio.load(input_path)
    # Bug fix: capture the channel count BEFORE downmixing — the old code
    # read waveform.shape[0] after torch.mean, so it always printed "1".
    n_channels = waveform.shape[0]
    if n_channels > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
        print(f"Converted to mono from {n_channels} channels")
    if sr != 48000:
        print(f"Resampling from {sr} Hz to 48000 Hz")
        resampler = torchaudio.transforms.Resample(sr, 48000)
        waveform = resampler(waveform)
    # Remove sub-60 Hz rumble before enhancement.
    waveform = torchaudio.functional.highpass_biquad(
        waveform, 48000, cutoff_freq=60.0
    )
    waveform = waveform.to(device).to(dtype)
    return waveform
def normalize_audio(audio):
    """Peak-normalize *audio* into [-1, 1].

    Returns the scaled tensor together with the peak magnitude that was
    divided out, so the caller can undo the scaling later. Silent input
    (all zeros) is returned unchanged with a factor of 0.
    """
    peak = audio.abs().max()
    if peak > 0:
        return audio / peak, peak
    return audio, peak
def process_audio(model, stft, istft, input_wav, device):
    """Enhance a waveform with *model* in the STFT domain.

    The input is peak-normalized before the forward STFT, and the model
    output is rescaled back to the original peak level afterwards.

    Args:
        model: Network mapping an input STFT to an enhanced STFT.
        stft: Forward transform callable (waveform -> spectrogram).
        istft: Inverse transform callable (spectrogram -> waveform).
        input_wav: Waveform tensor to enhance.
        device: Unused here; kept for backward compatibility with callers.

    Returns:
        Enhanced waveform tensor at the input's original peak level.
    """
    input_wav_norm, norm_factor = normalize_audio(input_wav)
    with torch.no_grad():
        input_stft = stft(input_wav_norm)
        # torch.cuda.amp.autocast(...) is deprecated; torch.autocast with
        # device_type="cuda" is the supported spelling and is equivalent.
        # Mixed precision only engages when CUDA is actually available.
        with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            enhanced_stft = model(input_stft)
        enhanced_wav = istft(enhanced_stft)
    # Undo the peak normalization; skip for silent input (factor == 0).
    if norm_factor > 0:
        enhanced_wav = enhanced_wav * norm_factor
    return enhanced_wav
def main():
    """CLI entry point: parse arguments, load the model, enhance one file."""
    arg_parser = argparse.ArgumentParser(
        description="Smule Renaissance Vocal Restoration"
    )
    arg_parser.add_argument("input", type=str, help="Input audio file path")
    arg_parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="Output audio file path (default: input_enhanced.wav)",
    )
    arg_parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        required=True,
        help="Model checkpoint path",
    )
    opts = arg_parser.parse_args()

    # Default output path: sibling of the input named "<stem>_enhanced.wav".
    if opts.output is None:
        src = Path(opts.input)
        opts.output = str(src.parent / f"{src.stem}_enhanced.wav")

    # Half precision on GPU, full precision on CPU.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    dtype = torch.float16 if use_cuda else torch.float32
    if use_cuda:
        print("Using device: CUDA with FP16 precision")
    else:
        print("Using device: CPU with FP32 precision")

    print(f"Loading model from {opts.checkpoint}...")
    model = Renaissance().to(device).to(dtype)
    model.load_state_dict(torch.load(opts.checkpoint, map_location=device))
    model.eval()

    # 4096-point transforms with 50% overlap at 48 kHz.
    stft = STFT(n_fft=4096, hop_length=2048, win_length=4096)
    istft = iSTFT(n_fft=4096, hop_length=2048, win_length=4096)

    print(f"Loading audio from {opts.input}...")
    input_wav = load_and_preprocess_audio(opts.input, device, dtype)
    print(f"Audio duration: {input_wav.shape[1] / 48000:.2f} seconds")

    print("Processing audio...")
    enhanced_wav = process_audio(model, stft, istft, input_wav, device)

    print(f"Saving enhanced audio to {opts.output}...")
    # torchaudio.save expects CPU float32.
    enhanced_wav_cpu = enhanced_wav.cpu().to(torch.float32)
    torchaudio.save(opts.output, enhanced_wav_cpu, 48000)
    print("Done!")
if __name__ == "__main__":
    main()