import joblib
from transformers import AutoFeatureExtractor, Wav2Vec2Model
import torch
import librosa
import numpy as np
from sklearn.linear_model import LogisticRegression
import gradio as gr
import os
from scipy.stats import mode


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')




# truncate the SSL encoder after layer 9, since we only need the first 9 transformer layers
class CustomWav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config):
        super().__init__(config)
        self.encoder.layers = self.encoder.layers[:9]


truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")
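# Note: from_pretrained still reads the full checkpoint; weights for the dropped
# layers find no matching modules in this truncated subclass and are skipped
# (Transformers logs them as unused), so only the first 9 layers stay in memory.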


# wrapper that calls the SSL model for feature extraction
class HuggingFaceFeatureExtractor:
    def __init__(self, model, feature_extractor_name):
        self.device = device
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, audio, sr):
        inputs = self.feature_extractor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        # hidden_states[0] is the embedding output; hidden_states[9] is the
        # output of the last retained transformer layer
        return outputs.hidden_states[9]

FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")

# load our best classifier
classifier = joblib.load('logreg_margin_pruning_ALL_best.joblib')
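# Assumption: the .joblib file is a scikit-learn LogisticRegression (see the
# import above) trained on mean-pooled layer-9 features, with class 1 = "Real";
# it must sit in the working directory alongside this script.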

# split the audio into fixed-duration segments and return them
def segment_audio(audio, sr, segment_duration):
    segment_samples = int(segment_duration * sr)
    total_samples = len(audio)
    segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
    return segments
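# Illustrative example: a 70 s clip at 16 kHz with segment_duration=30 yields
# three segments of 30 s, 30 s, and 10 s; the trailing segment may be shorter.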


# classification using the EER threshold
def classify_with_eer_threshold(probabilities, eer_thresh):
    return (probabilities >= eer_thresh).astype(int)
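# Illustrative example: with eer_thresh = 0.9, the probabilities
# [0.2, 0.95, 0.91] become the per-segment labels [0, 1, 1].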


def process_audio(input_data, segment_duration=30):
    # resample to 16 kHz, since xls-r-2b was trained on 16 kHz audio
    audio, sr = librosa.load(input_data, sr=16000)

    # ensure single-channel audio (mono is what xls-r-2b expects as input)
    if len(audio.shape) > 1:
        audio = audio[0]
    print('loaded file')

    # segment the audio into 30 s chunks to avoid crashing xls-r-2b on long inputs
    segments = segment_audio(audio, sr, segment_duration)
    final_features = []
    print(f'split into {len(segments)} segments')

    # extract features from each segment, mean-pooled over the time axis
    for segment in segments:
        features = FEATURE_EXTRACTOR(segment, sr)
        features_avg = torch.mean(features, dim=1).cpu().numpy()
        final_features.append(features_avg)
    print('features extracted')
    # per-segment probability of the "real" class
    inference_prob = []
    for feature in final_features:
        # ensure a 2-D array of shape (1, n_features) for scikit-learn
        feature = feature.reshape(1, -1)
        probability = classifier.predict_proba(feature)[:, 1]
        inference_prob.append(probability[0])
    print('classifier predicted')

    # precomputed EER decision threshold
    eer_threshold = 0.9999999996754046

    # per-segment predictions based on probability score and EER threshold
    y_pred_inference = classify_with_eer_threshold(np.array(inference_prob), eer_threshold)
    print('inference done for segments')

    # FINAL PREDICTION: majority vote over the segment predictions
    mode_result = mode(y_pred_inference, keepdims=True)
    final_prediction = mode_result.mode[0] if mode_result.mode.size > 0 else 0
    print('majority voting done')

    # confidence score (proportion of segments agreeing with the majority prediction)
    confidence_score = np.mean(y_pred_inference == final_prediction) if len(y_pred_inference) > 0 else 1.0
    confidence_percentage = confidence_score * 100

    return {
        "Final classification": "Real" if final_prediction == 1 else "Fake",
        "Confidence": round(confidence_percentage, 2)
    }
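# Illustrative example (hypothetical file name): for a clip where two of three
# segments score above the threshold, majority voting yields
# {'Final classification': 'Real', 'Confidence': 66.67}.
# result = process_audio("sample.wav")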



def gradio_interface(audio):
    if audio:
        return process_audio(audio)
    else:
        return "Please upload an audio file."

interface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Audio(type="filepath", label="Upload Audio")],
    outputs="text",
    title="SOL2 Audio Deepfake detection Demo",
    description="Upload an audio file to check if it's AI generated",
)

interface.launch(share=True)