import os
import glob
import random
import csv
import logging
import time

import numpy as np
import scipy.signal as signal
import librosa
import h5py
import torch
import torchaudio
import utils
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
# Oversampling Lib
from imblearn.over_sampling import RandomOverSampler
class FakeMusicCapsDataset(Dataset):
    """Binary dataset of real (label 0) vs. AI-generated (label 1) music clips."""
    def __init__(self, file_paths, labels, feat_type=('mel',), sr=16000, n_mels=64,
                 target_duration=10.0, augment=True, augment_real=True):
        self.file_paths = file_paths
        self.labels = labels
        self.feat_type = feat_type
        self.sr = sr
        self.n_mels = n_mels
        self.target_duration = target_duration
        self.target_samples = int(target_duration * sr)
        self.augment = augment
        self.augment_real = augment_real
    def pre_emphasis(self, x, alpha=0.97):
        return np.append(x[0], x[1:] - alpha * x[:-1])
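    # Note: pre_emphasis implements y[n] = x[n] - alpha * x[n-1], boosting high
    # frequencies; it is defined here but not called from __getitem__ below.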
    def augment_audio(self, y, sr):
        # Each augmentation fires independently with probability 0.5.
        if random.random() < 0.5:
            rate = random.uniform(0.8, 1.2)  # time-stretch without pitch change
            y = librosa.effects.time_stretch(y=y, rate=rate)
        if random.random() < 0.5:
            n_steps = random.randint(-2, 2)  # pitch-shift in semitones
            y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=n_steps)
        if random.random() < 0.5:
            noise_level = np.random.uniform(0.001, 0.005)  # additive Gaussian noise
            y = y + np.random.normal(0, noise_level, y.shape)
        if random.random() < 0.5:
            gain = np.random.uniform(0.9, 1.1)  # small random gain change
            y = y * gain
        return y
    def __len__(self):
        return len(self.file_paths)
    def __getitem__(self, idx):
        """
        Load one audio file, apply class-conditional preprocessing, and
        return a (log-mel spectrogram, label) pair.
        """
        audio_path = self.file_paths[idx]
        label = self.labels[idx]
        waveform, _ = librosa.load(audio_path, sr=self.sr, mono=True)
        if label == 0 and self.augment_real:
            # Real audio: optional augmentation only.
            waveform = self.augment_audio(waveform, self.sr)
        if label == 1:
            # Generated audio: high-pass filtering, then optional augmentation.
            waveform = self.highpass_filter(waveform, self.sr)
            if self.augment:
                waveform = self.augment_audio(waveform, self.sr)
        # Center-crop or zero-pad to a fixed length of target_duration seconds.
        current_samples = waveform.shape[0]
        if current_samples > self.target_samples:
            start_idx = (current_samples - self.target_samples) // 2
            waveform = waveform[start_idx:start_idx + self.target_samples]
        elif current_samples < self.target_samples:
            waveform = np.pad(waveform, (0, self.target_samples - current_samples), mode='constant')
        mel_spec = librosa.feature.melspectrogram(
            y=waveform, sr=self.sr, n_mels=self.n_mels, n_fft=1024, hop_length=256
        )
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
        log_mel_spec = np.expand_dims(log_mel_spec, axis=0)  # (1, n_mels, frames)
        mel_tensor = torch.tensor(log_mel_spec, dtype=torch.float)
        label_tensor = torch.tensor(label, dtype=torch.long)
        return mel_tensor, label_tensor
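    # Note: with sr=16000, target_duration=10.0, and hop_length=256, the mel
    # tensor returned above has 1 + 160000 // 256 = 626 frames, i.e. shape
    # (1, 64, 626).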
    def extract_feature(self, waveform, feat):
        """Extract the requested feature ('mel', 'stft', or 'cqt') from a waveform."""
        try:
            if feat == 'mel':
                mel_spec = librosa.feature.melspectrogram(y=waveform, sr=self.sr, n_mels=self.n_mels, n_fft=1024, hop_length=256)
                log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
                return torch.tensor(log_mel_spec, dtype=torch.float).unsqueeze(0)
            elif feat == 'stft':
                stft = librosa.stft(waveform, n_fft=512, hop_length=128, window="hann")
                log_stft = np.log(np.abs(stft) + 1e-3)
                return torch.tensor(log_stft, dtype=torch.float).unsqueeze(0)
            elif feat == 'cqt':
                cqt = librosa.cqt(waveform, sr=self.sr, hop_length=128, bins_per_octave=24)
                log_cqt = np.log(np.abs(cqt) + 1e-3)
                return torch.tensor(log_cqt, dtype=torch.float).unsqueeze(0)
            else:
                raise ValueError(f"[ERROR] Unsupported feature type: {feat}")
        except Exception as e:
            print(f"[ERROR] Feature extraction failed for {feat}: {e}")
            return None
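    # Note: extract_feature honors the 'mel'/'stft'/'cqt' options implied by
    # feat_type, but __getitem__ above computes the mel spectrogram directly
    # and never calls it; wiring feat_type into __getitem__ is left open here.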
    def highpass_filter(self, y, sr, cutoff=1000, order=5):
        """Apply an order-`order` Butterworth high-pass filter at `cutoff` Hz."""
        if isinstance(sr, np.ndarray):
            sr = np.mean(sr)
        if not isinstance(sr, (int, float)):
            raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}")
        if sr <= 0:
            raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.")
        nyquist = 0.5 * sr
        if cutoff <= 0 or cutoff >= nyquist:
            print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...")
            cutoff = max(10, min(cutoff, nyquist - 1))
        normal_cutoff = cutoff / nyquist
        b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
        y_filtered = signal.lfilter(b, a, y)
        return y_filtered
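# Sanity-check sketch (illustrative only, not part of the original training
# flow; `_b`, `_a`, `_w`, `_h` are hypothetical throwaway names): inspect the
# frequency response of the order-5 Butterworth high-pass used above at
# sr=16000. |_h| should be near zero below ~1 kHz and near one well above it.
_b, _a = signal.butter(5, 1000 / (0.5 * 16000), btype='high')
_w, _h = signal.freqz(_b, _a, worN=2048, fs=16000)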
def preprocess_audio(audio_path, sr=16000, n_mels=64, target_duration=10.0):
    """Load a clip, center-crop/pad to target_duration, and return a log-mel tensor."""
    try:
        waveform, _ = librosa.load(audio_path, sr=sr, mono=True)
        target_samples = int(target_duration * sr)
        if len(waveform) > target_samples:
            start_idx = (len(waveform) - target_samples) // 2
            waveform = waveform[start_idx:start_idx + target_samples]
        elif len(waveform) < target_samples:
            waveform = np.pad(waveform, (0, target_samples - len(waveform)), mode='constant')
        mel_spec = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=n_mels, n_fft=1024, hop_length=256)
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
        return torch.tensor(log_mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    except Exception as e:
        print(f"[ERROR] Preprocessing failed: {audio_path} | error: {e}")
        return None
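# Usage sketch for preprocess_audio (the path below is a hypothetical
# placeholder, not a file shipped with the original repo):
#   x = preprocess_audio("/path/to/clip.wav")
#   # -> tensor of shape (1, 1, 64, 626), or None on failure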
DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps"
SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps"  # data containing the open-set samples
real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True)
gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True)
open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True)
open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True)
real_labels = [0] * len(real_files)
gen_labels = [1] * len(gen_files)
open_real_labels = [0] * len(open_real_files)
open_gen_labels = [1] * len(open_gen_files)
real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42)
gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42)
train_files = real_train + gen_train
train_labels = real_train_labels + gen_train_labels
val_files = real_val + gen_val
val_labels = real_val_labels + gen_val_labels
closed_test_files = real_files + gen_files
closed_test_labels = real_labels + gen_labels
open_test_files = open_real_files + open_gen_files
open_test_labels = open_real_labels + open_gen_labels
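# Note: the closed test set reuses every FakeMusicCaps file, including files
# also present in the train/validation splits above, while the open test set
# additionally covers SunoCaps; keep that overlap in mind when reading scores.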
# Balance the class distribution by duplicating minority-class file paths.
ros = RandomOverSampler(sampling_strategy='auto', random_state=42)
train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels)
train_files = train_files_resampled.reshape(-1).tolist()
train_labels = train_labels_resampled
print(f"type(train_labels_resampled): {type(train_labels_resampled)}")
| print(f"Train Org Fake: {len(gen_val)}") | |
| print(f"Train set (Oversampled) - Real: {sum(1 for label in train_labels if label == 0)}, " | |
| f"Fake: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}") | |
| print(f"Validation set - Real: {len(real_val)}, Fake: {len(gen_val)}, Total: {len(val_files)}") | |
| print(f"Closed Test set - Real: {len(real_files)}, Fake: {len(gen_files)}, Total: {len(closed_test_files)}") | |
| print(f"Open Test set - Real: {len(open_real_files)}, Fake: {len(open_gen_files)}, Total: {len(open_test_files)}") |