import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix
from datalib import (
    FakeMusicCapsDataset,
    closed_test_files, closed_test_labels,
    open_test_files, open_test_labels,
    val_files, val_labels,
)
from networks import Wav2Vec2ForFakeMusic
from tqdm import tqdm
import argparse
'''
Usage:
    python3 test.py --pretrain_test (--closed_test | --open_test)
    python3 test.py --finetune_test (--closed_test | --open_test)
'''
parser = argparse.ArgumentParser(description="AI Music Detection Testing with Wav2Vec 2.0")
parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--ckpt_path', type=str, default='', help='Checkpoint directory')
parser.add_argument('--pretrain_test', action="store_true", help="Test Pretrained Wav2Vec2 Model")
parser.add_argument('--finetune_test', action="store_true", help="Test Fine-Tuned Wav2Vec2 Model")
parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)")
parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)")
parser.add_argument('--output_path', type=str, default='', help='Path to save test results')
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def plot_confusion_matrix(y_true, y_pred, classes, output_path):
    """Plot a confusion matrix with per-cell counts and save it to disk."""
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    # Label only as many ticks as the matrix actually has classes.
    num_classes = cm.shape[0]
    tick_labels = classes[:num_classes]
    ax.set(xticks=np.arange(num_classes),
           yticks=np.arange(num_classes),
           xticklabels=tick_labels,
           yticklabels=tick_labels,
           ylabel='True label',
           xlabel='Predicted label')

    # Write each count in a color that contrasts with its cell background.
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)
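# Standalone usage sketch for plot_confusion_matrix (kept commented so it does
# not run on import; the labels and output file name here are illustrative):
#   plot_confusion_matrix([0, 1, 1, 0], [0, 1, 0, 0],
#                         classes=["real", "generative"],
#                         output_path="cm_demo.png")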
if args.pretrain_test:
    ckpt_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_20.pth")
    print("\n🔍 Loading Pretrained Model:", ckpt_file)
    model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device)
elif args.finetune_test:
    ckpt_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_10.pth")
    print("\n🔍 Loading Fine-Tuned Model:", ckpt_file)
    model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=False).to(device)
else:
    raise ValueError("You must specify --pretrain_test or --finetune_test")
if not os.path.exists(ckpt_file):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")

ckpt = torch.load(ckpt_file, map_location=device)

# Drop keys that only exist during Wav2Vec2 pre-training (masked_spec_embed)
# so the state dict matches the inference-time architecture.
keys_to_remove = [key for key in ckpt.keys() if "masked_spec_embed" in key]
for key in keys_to_remove:
    print(f"Removing unexpected key: {key}")
    del ckpt[key]

try:
    model.load_state_dict(ckpt, strict=False)
except RuntimeError as e:
    # Fall back to the case where the checkpoint is a fully pickled model
    # object rather than a bare state dict.
    print("Model loading error:", e)
    print("Trying to load entire model...")
    model = torch.load(ckpt_file, map_location=device)

model.to(device)
model.eval()
torch.cuda.empty_cache()
if args.closed_test:
    print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...")
    test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels)
elif args.open_test:
    print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...")
    test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels)
else:
    print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...")
    test_dataset = FakeMusicCapsDataset(val_files, val_labels)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16)
def Test(model, test_loader, device, phase="Test"):
    model.eval()
    test_loss, test_correct, test_total = 0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc=f"{phase}"):
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = inputs.squeeze(1)  # drop channel dim: (B, 1, T) -> (B, T)
            output = model(inputs)
            loss = F.cross_entropy(output, labels)

            # Weight batch loss by batch size so the final mean is exact.
            test_loss += loss.item() * inputs.size(0)
            preds = output.argmax(dim=1)
            test_correct += (preds == labels).sum().item()
            test_total += labels.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    test_loss /= test_total
    test_acc = test_correct / test_total
    test_bal_acc = balanced_accuracy_score(all_labels, all_preds)
    test_precision = precision_score(all_labels, all_preds, average="binary")
    test_recall = recall_score(all_labels, all_preds, average="binary")
    test_f1 = f1_score(all_labels, all_preds, average="binary")

    print(f"\n{phase} Test Results - Loss: {test_loss:.4f} | Accuracy: {test_acc:.3f} | "
          f"Balanced Acc: {test_bal_acc:.4f} | Precision: {test_precision:.3f} | "
          f"Recall: {test_recall:.3f} | F1: {test_f1:.3f}")

    os.makedirs(args.output_path, exist_ok=True)
    conf_matrix_path = os.path.join(args.output_path,
                                    f"confusion_matrix_{phase.replace(' ', '_')}.png")
    plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"],
                          output_path=conf_matrix_path)
print("\nEvaluating Model on Test Set...")
Test(model, test_loader, device, phase="Pretrained Model" if args.pretrain_test else "Fine-Tuned Model")
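# Example invocations (the ckpts/ and results/ directories below are
# placeholders, not defaults shipped with the repo):
#   python3 test.py --finetune_test --closed_test --ckpt_path ckpts/ --output_path results/
#   python3 test.py --pretrain_test --open_test --ckpt_path ckpts/ --output_path results/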