import os
import glob
import random
import csv
import logging
import time

import numpy as np
import h5py
import librosa
import scipy.signal as signal
import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler
from transformers import Wav2Vec2Processor

import utils
from networks import Wav2Vec2ForFakeMusic
class FakeMusicCapsDataset(Dataset):
    """Real vs. generated music clips, returned as fixed-length mono waveforms."""

    def __init__(self, file_paths, labels, sr=16000, target_duration=10.0):
        self.file_paths = file_paths
        self.labels = labels
        self.sr = sr
        self.target_duration = target_duration
        self.target_samples = int(target_duration * sr)
        # Kept for the Wav2Vec2 pipeline; not used directly in __getitem__.
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
    def highpass_filter(self, y, sr, cutoff=500, order=5):
        """Butterworth high-pass filter; accepts a tensor or array, returns a NumPy array."""
        if isinstance(y, torch.Tensor):
            y = y.numpy()
        if isinstance(sr, np.ndarray):
            sr = float(np.mean(sr))
        if not isinstance(sr, (int, float)):
            raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}")
        if sr <= 0:
            raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.")
        nyquist = 0.5 * sr
        if cutoff <= 0 or cutoff >= nyquist:
            print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...")
            cutoff = max(10, min(cutoff, nyquist - 1))
        normal_cutoff = cutoff / nyquist
        b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
        return signal.lfilter(b, a, y)
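    # `augment_audio` is called in __getitem__ below for REAL clips but is not
    # defined anywhere in this section. A minimal placeholder sketch, assuming
    # light waveform augmentation is intended: random gain plus low-level
    # Gaussian noise. The ranges are illustrative assumptions, not values
    # taken from the source.
    def augment_audio(self, y, sr):
        if isinstance(y, torch.Tensor):
            y = y.numpy()
        gain = random.uniform(0.9, 1.1)          # small random gain
        noise = np.random.randn(len(y)) * 0.005  # low-level additive noise
        return (y * gain + noise).astype(np.float32)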
    def __len__(self):
        return len(self.file_paths)
    def __getitem__(self, idx):
        audio_path = self.file_paths[idx]
        label = self.labels[idx]

        waveform, sr = torchaudio.load(audio_path)
        if sr != self.sr:
            waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform)
        waveform = waveform.mean(dim=0)  # downmix to mono (squeeze(0) breaks on stereo)

        if label == 0:    # REAL: augment
            waveform = self.augment_audio(waveform, self.sr)
        elif label == 1:  # FAKE: high-pass filter
            waveform = self.highpass_filter(waveform, self.sr)
        waveform = np.asarray(waveform)  # both branches return NumPy arrays

        # Center-crop or zero-pad to a fixed window of target_samples.
        current_samples = len(waveform)
        if current_samples > self.target_samples:
            start_idx = (current_samples - self.target_samples) // 2
            waveform = waveform[start_idx:start_idx + self.target_samples]
        elif current_samples < self.target_samples:
            waveform = np.pad(waveform, (0, self.target_samples - current_samples))

        waveform = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)
        label = torch.tensor(label, dtype=torch.long)
        return waveform, label
def preprocess_audio(audio_path, target_sr=16000, target_duration=10.0):
    """Load a clip, truncate or zero-pad it to target_duration, and return a (1, N) tensor."""
    waveform, sr = librosa.load(audio_path, sr=target_sr)
    target_samples = int(target_duration * target_sr)
    current_samples = len(waveform)
    if current_samples > target_samples:
        waveform = waveform[:target_samples]
    elif current_samples < target_samples:
        waveform = np.pad(waveform, (0, target_samples - current_samples))
    return torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)
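# A minimal sketch of how preprocess_audio could feed single-clip inference.
# `model` is a hypothetical binary real/fake classifier whose forward pass
# returns (1, 2) logits; nothing here is fixed by the source.
def predict_clip(model, audio_path, device="cpu"):
    x = preprocess_audio(audio_path).to(device)  # shape: (1, target_samples)
    model.eval()
    with torch.no_grad():
        logits = model(x)
    return int(torch.argmax(logits, dim=-1).item())  # 0 = REAL, 1 = FAKE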
DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps"
SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps"  # data that includes the open set
real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True)
gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True)
open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True)
open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True)

# Label convention: 0 = REAL, 1 = FAKE (generative)
real_labels = [0] * len(real_files)
gen_labels = [1] * len(gen_files)
open_real_labels = [0] * len(open_real_files)
open_gen_labels = [1] * len(open_gen_files)
real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42)
gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42)

train_files = real_train + gen_train
train_labels = real_train_labels + gen_train_labels
val_files = real_val + gen_val
val_labels = real_val_labels + gen_val_labels

closed_test_files = real_files + gen_files
closed_test_labels = real_labels + gen_labels
open_test_files = open_real_files + open_gen_files
open_test_labels = open_real_labels + open_gen_labels
# Balance the training classes by oversampling the minority class.
ros = RandomOverSampler(sampling_strategy='auto', random_state=42)
train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels)
train_files = train_files_resampled.reshape(-1).tolist()
train_labels = train_labels_resampled
| print(f"Train Original FAKE: {len(gen_train)}") | |
| print(f"Train set (Oversampled) - REAL: {sum(1 for label in train_labels if label == 0)}, " | |
| f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}") | |
| print(f"Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}") | |