# Spaces: Running on Zero  (HuggingFace Spaces page header — scrape artifact)
# Third-party dependencies.
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal as signal
import soundfile as sf
import torch
import torch.nn as nn

# Project-local model definitions.
from networks import audiocnn, AudioCNNWithViTDecoder, AudioCNNWithViTDecoderAndCrossAttention
def highpass_filter(y, sr, cutoff=500, order=5):
    """Apply a Butterworth high-pass filter, removing content below `cutoff` Hz.

    Args:
        y: 1-D audio signal.
        sr: Sample rate of `y` in Hz.
        cutoff: Cut-off frequency in Hz (default 500).
        order: Butterworth filter order (default 5).

    Returns:
        The filtered signal (numpy array, same length as `y`).
    """
    # scipy expects the critical frequency normalized to the Nyquist rate.
    normalized_cutoff = cutoff / (0.5 * sr)
    b, a = signal.butter(order, normalized_cutoff, btype='high', analog=False)
    # Single forward pass of the IIR filter (introduces phase delay;
    # signal.filtfilt would be the zero-phase alternative).
    return signal.lfilter(b, a, y)
def plot_combined_visualization(y_original, y_filtered, sr, save_path="combined_visualization.png"):
    """Render a 3-panel figure: waveform overlay plus both spectrograms.

    Panel 1 overlays the original and high-pass-filtered waveforms on a
    shared time axis; panels 2 and 3 show their log-frequency STFT
    spectrograms in dB. The figure is saved to `save_path` at 300 dpi and
    also shown interactively.

    Args:
        y_original: Unfiltered audio signal.
        y_filtered: High-pass-filtered version of the same signal.
        sr: Sample rate in Hz.
        save_path: Output image path.
    """
    fig, (ax_wave, ax_orig, ax_filt) = plt.subplots(3, 1, figsize=(12, 12))

    # Waveform overlay — both signals have the same length/time axis.
    t = np.linspace(0, len(y_original) / sr, len(y_original))
    ax_wave.plot(t, y_original, label='Original', alpha=0.7)
    ax_wave.plot(t, y_filtered, label='High-pass Filtered', alpha=0.7, linestyle='dashed')
    ax_wave.set_xlabel("Time (s)")
    ax_wave.set_ylabel("Amplitude")
    ax_wave.set_title("Waveform Comparison (Original vs High-pass Filtered)")
    ax_wave.legend()

    # Spectrogram panels share the same rendering pipeline.
    panels = (
        (ax_orig, y_original, "Original Spectrogram"),
        (ax_filt, y_filtered, "High-pass Filtered Spectrogram"),
    )
    for ax, y, title in panels:
        db = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)
        img = librosa.display.specshow(db, sr=sr, x_axis='time', y_axis='log', ax=ax)
        ax.set_title(title)
        fig.colorbar(img, ax=ax, format="%+2.0f dB")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300)
    plt.show()
def load_model(checkpoint_path, model_class, device):
    """Instantiate `model_class` and load trained weights from a checkpoint.

    Args:
        checkpoint_path: Path to a saved state-dict checkpoint (.pth).
        model_class: Zero-argument callable returning an nn.Module.
        device: torch.device to map the weights to and place the model on.

    Returns:
        The model in eval mode on `device`.
    """
    model = model_class()
    # weights_only=True restricts unpickling to tensors/primitive containers,
    # preventing arbitrary-code execution from a tampered checkpoint file.
    # Safe here because only a plain state_dict is expected.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def predict_audio(model, audio_tensor, device):
    """Run inference on a single example and return the predicted class index.

    Args:
        model: Trained classifier producing (batch, num_classes) logits.
        audio_tensor: Single-example input tensor (no batch dimension).
        device: torch.device on which to run inference.

    Returns:
        The argmax class index (numpy scalar).
    """
    batch = audio_tensor.unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        logits = model(batch)
    return torch.argmax(logits, dim=1).cpu().numpy()[0]
# --- Script entry: compare predictions on raw vs high-pass-filtered audio ---

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load one audio file at its native sample rate and build a 500 Hz
# high-pass-filtered counterpart.
audio_path = "/data/kym/AI Music Detection/audio/FakeMusicCaps/real/musiccaps/_RrA-0lfIiU.wav"  # Replace with actual file path
y, sr = librosa.load(audio_path, sr=None)
y_filtered = highpass_filter(y, sr, cutoff=500)

# Mel-spectrogram features with a leading channel dimension for the CNNs.
audio_tensor = torch.tensor(librosa.feature.melspectrogram(y=y, sr=sr), dtype=torch.float).unsqueeze(0)
audio_tensor_filtered = torch.tensor(librosa.feature.melspectrogram(y=y_filtered, sr=sr), dtype=torch.float).unsqueeze(0)

# Baseline CNN vs the cross-attention model trained on high-pass input.
original_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/pretraining/best_model_audiocnn.pth", audiocnn, device)
highpass_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/500hz_Add_crossattn_decoder/best_model_AudioCNNWithViTDecoderAndCrossAttention.pth", AudioCNNWithViTDecoderAndCrossAttention, device)

# Each model predicts on its matching feature variant.
original_pred = predict_audio(original_model, audio_tensor, device)
highpass_pred = predict_audio(highpass_model, audio_tensor_filtered, device)
print(f"Original Model Prediction: {original_pred}")
print(f"High-pass Filter Model Prediction: {highpass_pred}")

# Single combined figure: waveforms plus both spectrograms.
plot_combined_visualization(y, y_filtered, sr, save_path="/data/kym/AI Music Detection/AudioCNN/hf_vis/rawvs500.png")