File size: 2,837 Bytes
c3c908f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import torch
import torch.nn.functional as F
import torchaudio
import argparse
from AI_Music_Detection.Code.model.wav2vec.wav2vec_datalib import preprocess_audio
from networks import Wav2Vec2ForFakeMusic

# Usage:
#   python inference.py --gpu 0 --model_type pretrain --inference path/to/audio.wav
parser = argparse.ArgumentParser(
    description="Wav2Vec2 AI Music Detection Inference",
    epilog="example: python inference.py --gpu 0 --model_type pretrain "
           "--inference path/to/audio.wav",
)
parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
parser.add_argument('--model_name', type=str, choices=['Wav2Vec2ForFakeMusic'], default='Wav2Vec2ForFakeMusic', help='Model name')
# NOTE(review): machine-specific default path; pass --ckpt_path explicitly on other hosts.
parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/', help='Checkpoint directory')
parser.add_argument('--model_type', type=str, choices=['pretrain', 'finetune'], required=True, help='Choose between pretrained or fine-tuned model')
parser.add_argument('--inference', type=str, help='Path to a .wav file for inference')
args = parser.parse_args()

# Restrict which GPU torch can see. CUDA_VISIBLE_DEVICES is only honoured if it
# is set before CUDA is initialised, so this must stay ahead of any CUDA call.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resolve the checkpoint file for the requested variant. argparse `choices`
# already limits --model_type, so the ValueError is only a defensive fallback.
ckpt_name = {
    'pretrain': "wav2vec2_pretrain_10.pth",
    'finetune': "wav2vec2_finetune_5.pth",
}.get(args.model_type)
if ckpt_name is None:
    raise ValueError("Invalid model type. Choose between 'pretrain' or 'finetune'.")
model_file = os.path.join(args.ckpt_path, ckpt_name)

# Fail fast, before the network is built, if the checkpoint is missing.
if not os.path.exists(model_file):
    raise FileNotFoundError(f"Model checkpoint not found: {model_file}")

# Build the (only) supported architecture with two output classes.
# freeze_feature_extractor is True only for the 'finetune' variant.
if args.model_name != 'Wav2Vec2ForFakeMusic':
    raise ValueError(f"Invalid model name: {args.model_name}")
model = Wav2Vec2ForFakeMusic(
    num_classes=2,
    freeze_feature_extractor=args.model_type == 'finetune',
)

def _prepare_model(net, ckpt_file, target_device):
    """Load weights from *ckpt_file* into *net*, move it to *target_device*, and set eval mode.

    Returns *net* for convenience.
    """
    checkpoint = torch.load(ckpt_file, map_location=target_device)
    # NOTE(review): the on-disk checkpoint layout is not visible from this script;
    # accept a whole module, a bare state_dict, or a training dict wrapping one.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    elif isinstance(checkpoint, dict):
        for key in ("model_state_dict", "state_dict"):
            nested = checkpoint.get(key)
            if isinstance(nested, dict):
                checkpoint = nested
                break
    # Checkpoints saved from nn.DataParallel prefix every key with "module.".
    prefix = "module."
    if checkpoint and all(str(k).startswith(prefix) for k in checkpoint):
        checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}
    net.load_state_dict(checkpoint)
    net.to(target_device)
    # eval() disables dropout / train-mode layers so inference is deterministic.
    net.eval()
    return net


def predict(audio_path):
    """Classify one audio file as AI-generated ("fake") or real music.

    Loads the checkpoint at module-level ``model_file`` into module-level
    ``model``, runs it on ``preprocess_audio(audio_path)``, prints diagnostics
    and a verdict, and returns the AI-music probability.

    Args:
        audio_path: Path to the audio file to classify.

    Returns:
        float: softmax probability of class index 1, which this script treats
        as "AI music".

    Raises:
        FileNotFoundError: if ``audio_path`` does not exist.
    """
    print(f"\n🔍 Loading model from {model_file}")

    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"[ERROR] Audio file not found: {audio_path}")

    # Actually load the checkpoint and switch to eval mode; without this the
    # classifier runs on whatever weights the constructor initialised, in train mode.
    _prepare_model(model, model_file, device)

    # NOTE(review): assumes preprocess_audio returns a batched tensor with batch
    # size 1, since only row 0 of the output is read below — confirm.
    input_tensor = preprocess_audio(audio_path).to(device)
    print(f"Input shape after preprocessing: {input_tensor.shape}")

    with torch.no_grad():
        output = model(input_tensor)
        print(f"Raw model output (logits): {output}")
        probabilities = F.softmax(output, dim=1)

    ai_music_prob = probabilities[0, 1].item()
    print(f"Softmax Probabilities: {probabilities}")
    print(f"AI Music Probability: {ai_music_prob:.4f}")

    if ai_music_prob > 0.5:
        print(f" FAKE MUSIC DETECTED ({ai_music_prob:.2%})")
    else:
        print(f" REAL MUSIC DETECTED ({1 - ai_music_prob:.2%})")
    return ai_music_prob



if __name__ == "__main__":
    # Without --inference there is nothing to do; report it (usage + exit 2)
    # instead of silently exiting with status 0.
    if not args.inference:
        parser.error("--inference is required: path to a .wav file to classify")
    if not os.path.exists(args.inference):
        # SystemExit with a string prints it to stderr and exits with status 1.
        raise SystemExit(f"[ERROR] No File Found: {args.inference}")
    predict(args.inference)