import os
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import wandb
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, balanced_accuracy_score
from datalib import (
    FakeMusicCapsDataset,
    train_files, val_files, train_labels, val_labels,
    closed_test_files, closed_test_labels,
    open_test_files, open_test_labels,
    preprocess_audio,
)
from networks import CCV
from attentionmap import visualize_attention_map
from confusion_matrix import plot_confusion_matrix
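# Count trainable parameters for the model-size report printed after model construction.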
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
'''
Usage:
    python3 main.py --model_name CCV --batch_size 32 --epochs 10 --loss_type ce --oversample True

CCV: AudioCNN encoder with a cross-attention-based (ViT-style) decoder.
'''
# Argument parsing
import argparse
parser = argparse.ArgumentParser(description='AI Music Detection Training')
parser.add_argument('--gpu', type=str, default='1', help='GPU ID')
parser.add_argument('--model_name', type=str, choices=['audiocnn', 'CCV'], default='CCV', help='Model name (only CCV is currently wired up below)')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs')
parser.add_argument('--audio_duration', type=float, default=10, help='Length of the audio slice in seconds')
parser.add_argument('--patience_counter', type=int, default=5, help='Early stopping patience')
parser.add_argument('--log_dir', type=str, default='', help='TensorBoard log directory')
parser.add_argument('--ckpt_path', type=str, default='', help='Checkpoint directory')
parser.add_argument("--weight_decay", type=float, default=0.05, help="weight decay (default: 0.0)")
parser.add_argument("--loss_type", type=str, choices=["ce", "weighted_ce", "focal"], default="ce", help="Loss function type")
parser.add_argument('--inference', type=str, help='Path to a .wav file for inference')
parser.add_argument("--closed_test", action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
parser.add_argument("--open_test", action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
parser.add_argument("--oversample", type=bool, default=True, help="Apply Oversampling to balance classes") # real data oversampling
args = parser.parse_args()
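# Pin the visible GPU before any CUDA initialization so device indices stay stable.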
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
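# Fix random seeds so runs are reproducible across torch, CUDA, Python, and NumPy.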
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
random.seed(42)
np.random.seed(42)
wandb.init(project="",  # project name left blank in the source
           name=f"{args.model_name}_lr{args.learning_rate}_ep{args.epochs}_bs{args.batch_size}", config=args)
if args.model_name == 'CCV':
    model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2)
    feat_type = 'mel'
else:
    raise ValueError(f"Invalid model name: {args.model_name}")
model = model.to(device)
print(f"Using model: {args.model_name}, Parameters: {count_parameters(model)}")
print(f"weight_decay WD: {args.weight_decay}")
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
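# Note: scheduler.step() runs once per epoch, so with the default epochs=10, StepLR(step_size=20) never fires.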
if args.loss_type == "ce":
print("Using CrossEntropyLoss")
criterion = nn.CrossEntropyLoss()
elif args.loss_type == "weighted_ce":
print("Using Weighted CrossEntropyLoss")
num_real = sum(1 for label in train_labels if label == 0)
num_fake = sum(1 for label in train_labels if label == 1)
total_samples = num_real + num_fake
weight_real = total_samples / (2 * num_real)
weight_fake = total_samples / (2 * num_fake)
class_weights = torch.tensor([weight_real, weight_fake]).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
elif args.loss_type == "focal":
print("Using Focal Loss")
class FocalLoss(torch.nn.Module):
def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
criterion = FocalLoss().to(device)
if args.ckpt_path:
    os.makedirs(args.ckpt_path, exist_ok=True)
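# Build mel-feature datasets and loaders from the train/val splits provided by datalib.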
train_dataset = FakeMusicCapsDataset(train_files, train_labels, feat_type=feat_type, target_duration=args.audio_duration)
val_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type=feat_type, target_duration=args.audio_duration)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)
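# Train/validate for args.epochs, checkpointing on best validation balanced accuracy with early stopping.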
def train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args):
writer = SummaryWriter(log_dir=args.log_dir)
    best_val_bal_acc = 0.0  # balanced accuracy is maximized; track the best (highest) value seen so far
    early_stop_cnt = 0
for epoch in range(args.epochs):
print(f"\n[Epoch {epoch + 1}/{args.epochs}]")
model.train()
train_loss, train_correct, train_total = 0, 0, 0
        all_train_preds, all_train_labels = [], []
attention_maps = []
train_pbar = tqdm(train_loader, desc="Train", leave=False)
for batch_idx, (data, target) in enumerate(train_pbar):
data = data.to(device)
target = target.to(device)
output = model(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item() * data.size(0)
preds = output.argmax(dim=1)
train_correct += (preds == target).sum().item()
train_total += target.size(0)
all_train_labels.extend(target.cpu().numpy())
all_train_preds.extend(preds.cpu().numpy())
if hasattr(model, "get_attention_maps"):
attention_maps.append(model.get_attention_maps())
train_loss /= train_total
train_acc = train_correct / train_total
train_bal_acc = balanced_accuracy_score(all_train_labels, all_train_preds)
train_precision = precision_score(all_train_labels, all_train_preds, average="binary")
train_recall = recall_score(all_train_labels, all_train_preds, average="binary")
train_f1 = f1_score(all_train_labels, all_train_preds, average="binary")
wandb.log({
"Train Loss": train_loss, "Train Accuracy": train_acc,
"Train Precision": train_precision, "Train Recall": train_recall,
"Train F1 Score": train_f1, "Train B_ACC": train_bal_acc,
})
print(f"Train Epoch: {epoch+1} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f} | "
f"Train B_ACC: {train_bal_acc:.4f} | Train Prec: {train_precision:.3f} | "
f"Train Rec: {train_recall:.3f} | Train F1: {train_f1:.3f}")
model.eval()
val_loss, val_correct, val_total = 0, 0, 0
all_val_preds, all_val_labels = [], []
attention_maps = []
val_pbar = tqdm(val_loader, desc=" Val ", leave=False)
with torch.no_grad():
for data, target in val_pbar:
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion(output, target)
val_loss += loss.item() * data.size(0)
preds = output.argmax(dim=1)
val_correct += (preds == target).sum().item()
val_total += target.size(0)
all_val_labels.extend(target.cpu().numpy())
all_val_preds.extend(preds.cpu().numpy())
if hasattr(model, "get_attention_maps"):
attention_maps.append(model.get_attention_maps())
val_loss /= val_total
val_acc = val_correct / val_total
val_bal_acc = balanced_accuracy_score(all_val_labels, all_val_preds)
val_precision = precision_score(all_val_labels, all_val_preds, average="binary")
val_recall = recall_score(all_val_labels, all_val_preds, average="binary")
val_f1 = f1_score(all_val_labels, all_val_preds, average="binary")
wandb.log({
"Validation Loss": val_loss, "Validation Accuracy": val_acc,
"Validation Precision": val_precision, "Validation Recall": val_recall,
"Validation F1 Score": val_f1, "Validation B_ACC": val_bal_acc,
})
print(f"Val Epoch: {epoch+1} [{batch_idx * len(data)}/{len(val_loader.dataset)} "
f"({100. * batch_idx / len(val_loader):.0f}%)]\t"
f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f} | "
f"Val B_ACC: {val_bal_acc:.4f} | Val Prec: {val_precision:.3f} | "
f"Val Rec: {val_recall:.3f} | Val F1: {val_f1:.3f}")
        if len(attention_maps) > 0:
print(f"Visualizing Attention Map at Epoch {epoch+1}")
if isinstance(attention_maps[0], list):
attn_map_numpy = np.array([t.detach().cpu().numpy() for t in attention_maps[0]])
elif isinstance(attention_maps[0], torch.Tensor):
attn_map_numpy = attention_maps[0].detach().cpu().numpy()
else:
attn_map_numpy = np.array(attention_maps[0])
print(f"Attention Map Shape: {attn_map_numpy.shape}")
            if len(attn_map_numpy) > 0:
                fig, ax = plt.subplots(figsize=(10, 8))
                im = ax.imshow(attn_map_numpy[0], cmap='viridis', interpolation='nearest')
                ax.set_title(f"Attention Map - Epoch {epoch+1}")
                fig.colorbar(im)  # reuse the existing image instead of drawing it a second time
                plt.savefig("")  # output path left blank in the source
                plt.show()
else:
print(f"Warning: attention_maps[0] is empty! Shape={attn_map_numpy.shape}")
        if val_bal_acc > best_val_bal_acc:
best_val_bal_acc = val_bal_acc
early_stop_cnt = 0
torch.save(model.state_dict(), os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"))
print("Best model saved.")
else:
early_stop_cnt += 1
print(f'PATIENCE {early_stop_cnt}/{args.patience_counter}')
if early_stop_cnt >= args.patience_counter:
print("Early stopping triggered.")
break
scheduler.step()
plot_confusion_matrix(all_val_labels, all_val_preds, classes=["REAL", "FAKE"], writer=writer, epoch=epoch)
wandb.finish()
writer.close()
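# Single-file inference: load the best checkpoint and classify one .wav as REAL or FAKE.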
def predict(audio_path):
print(f"Loading model from {args.ckpt_path}/celoss_best_model_{args.model_name}.pth")
model.load_state_dict(torch.load(os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"), map_location=device))
model.eval()
input_tensor = preprocess_audio(audio_path).to(device)
with torch.no_grad():
output = model(input_tensor)
probabilities = F.softmax(output, dim=1)
ai_music_prob = probabilities[0, 1].item()
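    # Class index 1 is the AI-generated ("fake") class; threshold the softmax probability at 0.5.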
    if ai_music_prob > 0.5:
        print(f"FAKE MUSIC ({ai_music_prob:.2%})")
    else:
        print(f"REAL MUSIC ({1 - ai_music_prob:.2%})")
def Test(model, test_loader, criterion, device):
model.load_state_dict(torch.load(os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"), map_location=device))
model.eval()
test_loss, test_correct, test_total = 0, 0, 0
all_preds, all_labels = [], []
with torch.no_grad():
for data, target in tqdm(test_loader, desc=" Test ", leave=False):
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion(output, target)
test_loss += loss.item() * data.size(0)
preds = output.argmax(dim=1)
test_correct += (preds == target).sum().item()
test_total += target.size(0)
all_labels.extend(target.cpu().numpy())
all_preds.extend(preds.cpu().numpy())
test_loss /= test_total
test_acc = test_correct / test_total
test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
test_precision = precision_score(all_labels, all_preds, average="binary")
test_recall = recall_score(all_labels, all_preds, average="binary")
test_f1 = f1_score(all_labels, all_preds, average="binary")
print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | "
f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | "
f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}")
if __name__ == "__main__":
train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args)
if args.closed_test:
print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, feat_type=feat_type, target_duration=args.audio_duration)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)
elif args.open_test:
print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, feat_type=feat_type, target_duration=args.audio_duration)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)
else:
print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
test_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type=feat_type, target_duration=args.audio_duration)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)
print("\nEvaluating Model on Test Set...")
Test(model, test_loader, criterion, device)
if args.inference:
if not os.path.exists(args.inference):
print(f"[ERROR] No File Found: {args.inference}")
else:
predict(args.inference)