import librosa import librosa.display import numpy as np import matplotlib.pyplot as plt import scipy.signal as signal import torch import torch.nn as nn import soundfile as sf from networks import audiocnn, AudioCNNWithViTDecoder, AudioCNNWithViTDecoderAndCrossAttention def highpass_filter(y, sr, cutoff=500, order=5): """High-pass filter to remove low frequencies below `cutoff` Hz.""" nyquist = 0.5 * sr normal_cutoff = cutoff / nyquist b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) y_filtered = signal.lfilter(b, a, y) return y_filtered def plot_combined_visualization(y_original, y_filtered, sr, save_path="combined_visualization.png"): """Plot waveform comparison and spectrograms in a single figure.""" fig, axes = plt.subplots(3, 1, figsize=(12, 12)) # 1️⃣ Waveform Comparison time = np.linspace(0, len(y_original) / sr, len(y_original)) axes[0].plot(time, y_original, label='Original', alpha=0.7) axes[0].plot(time, y_filtered, label='High-pass Filtered', alpha=0.7, linestyle='dashed') axes[0].set_xlabel("Time (s)") axes[0].set_ylabel("Amplitude") axes[0].set_title("Waveform Comparison (Original vs High-pass Filtered)") axes[0].legend() # 2️⃣ Spectrogram - Original S_orig = librosa.amplitude_to_db(np.abs(librosa.stft(y_original)), ref=np.max) img = librosa.display.specshow(S_orig, sr=sr, x_axis='time', y_axis='log', ax=axes[1]) axes[1].set_title("Original Spectrogram") fig.colorbar(img, ax=axes[1], format="%+2.0f dB") # 3️⃣ Spectrogram - High-pass Filtered S_filt = librosa.amplitude_to_db(np.abs(librosa.stft(y_filtered)), ref=np.max) img = librosa.display.specshow(S_filt, sr=sr, x_axis='time', y_axis='log', ax=axes[2]) axes[2].set_title("High-pass Filtered Spectrogram") fig.colorbar(img, ax=axes[2], format="%+2.0f dB") plt.tight_layout() plt.savefig(save_path, dpi=300) plt.show() def load_model(checkpoint_path, model_class, device): """Load a trained model from checkpoint.""" model = model_class() model.load_state_dict(torch.load(checkpoint_path, map_location=device)) model.to(device) model.eval() return model def predict_audio(model, audio_tensor, device): """Make predictions using a trained model.""" with torch.no_grad(): audio_tensor = audio_tensor.unsqueeze(0).to(device) # Add batch dimension output = model(audio_tensor) prediction = torch.argmax(output, dim=1).cpu().numpy()[0] return prediction # Load audio audio_path = "/data/kym/AI Music Detection/audio/FakeMusicCaps/real/musiccaps/_RrA-0lfIiU.wav" # Replace with actual file path y, sr = librosa.load(audio_path, sr=None) y_filtered = highpass_filter(y, sr, cutoff=500) # Convert audio to tensor audio_tensor = torch.tensor(librosa.feature.melspectrogram(y=y, sr=sr), dtype=torch.float).unsqueeze(0) audio_tensor_filtered = torch.tensor(librosa.feature.melspectrogram(y=y_filtered, sr=sr), dtype=torch.float).unsqueeze(0) # Load models device = torch.device("cuda" if torch.cuda.is_available() else "cpu") original_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/pretraining/best_model_audiocnn.pth", audiocnn, device) highpass_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/500hz_Add_crossattn_decoder/best_model_AudioCNNWithViTDecoderAndCrossAttention.pth", AudioCNNWithViTDecoderAndCrossAttention, device) # Predict original_pred = predict_audio(original_model, audio_tensor, device) highpass_pred = predict_audio(highpass_model, audio_tensor_filtered, device) print(f"Original Model Prediction: {original_pred}") print(f"High-pass Filter Model Prediction: {highpass_pred}") # Generate combined visualization (all plots in one image) plot_combined_visualization(y, y_filtered, sr, save_path="/data/kym/AI Music Detection/AudioCNN/hf_vis/rawvs500.png")