# Hacked together using the code from https://github.com/nikhilsinghmus/image2reverb
import os, types
import numpy as np
import gradio as gr
import soundfile as sf
import scipy.signal
import librosa.display
from PIL import Image
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from pytorch_lightning import Trainer
from image2reverb.model import Image2Reverb
from image2reverb.stft import STFT
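
# The patched test_step() below stashes its results in these module-level
# variables so that predict() can read them back after trainer.test() returns.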
predicted_ir = None
predicted_spectrogram = None
predicted_depthmap = None

def test_step(self, batch, batch_idx):
    spec, label, paths = batch
    examples = [os.path.splitext(os.path.basename(s))[0] for _, s in zip(*paths)]
    f, img = self.enc.forward(label)
    # Concatenate Gaussian noise channels so the latent input matches the
    # generator's expected size; this randomness is why repeated runs differ.
    shape = (
        f.shape[0],
        (self._latent_dimension - f.shape[1]) if f.shape[1] < self._latent_dimension else f.shape[1],
        f.shape[2],
        f.shape[3]
    )
    z = torch.cat((f, torch.randn(shape, device=model.device)), 1)
    fake_spec = self.g(z)
    stft = STFT()
    y_f = [stft.inverse(s.squeeze()) for s in fake_spec]
    # TODO: bit hacky
    global predicted_ir, predicted_spectrogram, predicted_depthmap
    predicted_ir = y_f[0]
    s = fake_spec.squeeze().cpu().numpy()
    # Map the generator output from [-1, 1] back to a linear-magnitude spectrogram.
    predicted_spectrogram = np.exp((((s + 1) * 0.5) * 19.5) - 17.5) - 1e-8
    img = (img + 1) * 0.5
    # The last channel of the encoder input is the predicted depth map.
    predicted_depthmap = img.cpu().squeeze().permute(1, 2, 0)[:, :, -1].squeeze().numpy()
    return {"test_audio": y_f, "test_examples": examples}

def test_epoch_end(self, outputs):
    if not self.test_callback:
        return
    examples = []
    audio = []
    for output in outputs:
        for i in range(len(output["test_examples"])):
            audio.append(output["test_audio"][i])
            examples.append(output["test_examples"][i])
    self.test_callback(examples, audio)

checkpoint_path = "./checkpoints/image2reverb_f22.ckpt"
encoder_path = None
depthmodel_path = "./checkpoints/mono_odom_640x192"
constant_depth = None
latent_dimension = 512

model = Image2Reverb(encoder_path, depthmodel_path)
m = torch.load(checkpoint_path, map_location=model.device)
model.load_state_dict(m["state_dict"])
model.test_step = types.MethodType(test_step, model)
model.test_epoch_end = types.MethodType(test_epoch_end, model)
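
# The two hooks above are bound onto the model instance so that trainer.test()
# (called from predict() below) runs this custom inference path rather than the
# model's built-in test_step/test_epoch_end.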

# Resize to 224x224 and scale RGB values to [-1, 1] before feeding the encoder.
image_transforms = transforms.Compose([
    transforms.Resize([224, 224], transforms.functional.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

class Image2ReverbDemoDataset(Dataset):
    def __init__(self, image):
        self.image = Image.fromarray(image)
        self.stft = STFT()

    def __getitem__(self, index):
        img_tensor = image_transforms(self.image.convert("RGB"))
        # The first element is a dummy ground-truth signal; test_step only uses the image.
        return torch.zeros(1, int(5.94 * 22050)), img_tensor, ("", "")

    def __len__(self):
        return 1

    def name(self):
        return "Image2ReverbDemo"

def convolve(audio, reverb):
    # convolve audio with reverb
    wet_audio = np.concatenate((audio, np.zeros(reverb.shape)))
    wet_audio = scipy.signal.oaconvolve(wet_audio, reverb, "full")[:len(wet_audio)]
    # normalize audio to roughly -1 dB peak and remove DC offset
    wet_audio /= np.max(np.abs(wet_audio))
    wet_audio -= np.mean(wet_audio)
    wet_audio *= 0.9
    return wet_audio
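
# For reference, a rough sketch of how convolve() behaves (the variable names
# here are only illustrative): the dry signal is zero-padded by the length of
# the impulse response so the reverb tail is not truncated, and the result is
# peak-normalized to roughly 0.9, e.g.
#
#   wet = convolve(dry_signal, impulse_response)
#   assert len(wet) == len(dry_signal) + len(impulse_response)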

def predict(image, audio):
    # image = numpy (height, width, channels)
    # audio = tuple (sample_rate, frames) or (sample_rate, (frames, channels))
    test_set = Image2ReverbDemoDataset(image)
    test_loader = torch.utils.data.DataLoader(test_set, num_workers=0, batch_size=1)

    trainer = Trainer(limit_test_batches=1)
    trainer.test(model, test_loader, verbose=True)

    # depthmap output
    depthmap_fig = plt.figure()
    plt.imshow(predicted_depthmap)
    plt.close()

    # spectrogram output
    spectrogram_fig = plt.figure()
    librosa.display.specshow(predicted_spectrogram, sr=22050, x_axis="time", y_axis="hz")
    plt.close()

    # plot the IR as a waveform
    waveform_fig = plt.figure()
    librosa.display.waveshow(predicted_ir, sr=22050, alpha=0.5)
    plt.close()

    # output audio as 16-bit signed integer
    ir = (22050, (predicted_ir * 32767).astype(np.int16))

    sample_rate, original_audio = audio

    # incoming audio is 16-bit signed integer, convert to float and normalize
    original_audio = original_audio.astype(np.float32) / 32768.0
    original_audio /= np.max(np.abs(original_audio))

    # resample reverb to sample_rate first, also normalize
    reverb = predicted_ir.copy()
    reverb = scipy.signal.resample_poly(reverb, up=sample_rate, down=22050)
    reverb /= np.max(np.abs(reverb))

    # stereo? convolve each channel separately
    if len(original_audio.shape) > 1:
        wet_left = convolve(original_audio[:, 0], reverb)
        wet_right = convolve(original_audio[:, 1], reverb)
        wet_audio = np.concatenate([wet_left[:, None], wet_right[:, None]], axis=1)
    else:
        wet_audio = convolve(original_audio, reverb)

    # 50% dry-wet mix; the dry signal is scaled by 0.9 to match the wet peak level
    mixed_audio = wet_audio * 0.5
    mixed_audio[:len(original_audio), ...] += original_audio * 0.9 * 0.5

    # convert back to 16-bit signed integer
    wet_audio = (wet_audio * 32767).astype(np.int16)
    mixed_audio = (mixed_audio * 32767).astype(np.int16)

    convolved_audio_100 = (sample_rate, wet_audio)
    convolved_audio_50 = (sample_rate, mixed_audio)

    return depthmap_fig, spectrogram_fig, waveform_fig, ir, convolved_audio_100, convolved_audio_50
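
# The six return values above correspond one-to-one with the six output
# components of the Gradio interface defined below.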

title = "Image2Reverb: Cross-Modal Reverb Impulse Response Synthesis"

description = """
<b>Image2Reverb</b> predicts the acoustic reverberation of a given environment from a 2D image. <a href="https://arxiv.org/abs/2103.14201">Read the paper</a>

How to use: Choose an image of a room or other environment and an audio file.
The model predicts what the reverb of the room sounds like and applies it to the audio file.

First, the image is resized to 224×224. The monodepth model is used to predict a depthmap, which is added as an
additional channel to the image input. A ResNet-based encoder then converts the image into features, and
finally a GAN predicts the spectrogram of the reverb's impulse response.

<center><img src="file/model.jpg" width="870" height="297" alt="model architecture"></center>

The predicted impulse response is mono at 22050 Hz. It is resampled to the sampling rate of the audio
file and applied to both channels if the audio is stereo.

Generating the impulse response involves a certain amount of randomness, making it sound a little
different every time you try it.
"""

article = """
<div style='margin:20px auto;'>
<p>Based on original work by Nikhil Singh, Jeff Mentch, Jerry Ng, Matthew Beveridge, Iddo Drori.
<a href="https://web.media.mit.edu/~nsingh1/image2reverb/">Project Page</a> |
<a href="https://arxiv.org/abs/2103.14201">Paper</a> |
<a href="https://github.com/nikhilsinghmus/image2reverb">GitHub</a></p>

<pre>
@InProceedings{Singh_2021_ICCV,
    author    = {Singh, Nikhil and Mentch, Jeff and Ng, Jerry and Beveridge, Matthew and Drori, Iddo},
    title     = {Image2Reverb: Cross-Modal Reverb Impulse Response Synthesis},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {286-295}
}
</pre>

<p>🌠 Example images from <a href="https://web.media.mit.edu/~nsingh1/image2reverb/">the original project page</a>.</p>
<p>🎶 Example sound from <a href="https://freesound.org/people/ashesanddreams/sounds/610414/">Ashes and Dreams @ freesound.org</a> (CC BY 4.0 license). This is a mono 48 kHz recording that has no reverb on it.</p>
</div>
"""

audio_example = "examples/ashesanddreams.wav"
examples = [
    ["examples/input.4e2f71f6.png", audio_example],
    ["examples/input.321eef38.png", audio_example],
    ["examples/input.2238dc21.png", audio_example],
    ["examples/input.4d280b40.png", audio_example],
    ["examples/input.0c3f5013.png", audio_example],
    ["examples/input.98773b90.png", audio_example],
    ["examples/input.ac61500f.png", audio_example],
    ["examples/input.5416407f.png", audio_example],
]
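
# A minimal sketch (left as a comment so it does not run when the app starts)
# of calling predict() directly with the bundled example files instead of going
# through the Gradio UI; the exact files used are only illustrative:
#
#   img = np.array(Image.open("examples/input.4e2f71f6.png").convert("RGB"))
#   data, sr = sf.read(audio_example, dtype="int16")
#   depth_fig, spec_fig, wave_fig, ir, wet, mix = predict(img, (sr, data))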

gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Upload Image"),
        gr.Audio(label="Upload Audio", source="upload", type="numpy"),
    ],
    outputs=[
        gr.Plot(label="Depthmap"),
        gr.Plot(label="Impulse Response Spectrogram"),
        gr.Plot(label="Impulse Response Waveform"),
        gr.Audio(label="Impulse Response", type="numpy"),
        gr.Audio(label="Output Audio (100% Wet)", type="numpy"),
        gr.Audio(label="Output Audio (50% Dry, 50% Wet)", type="numpy"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()