import joblib
from transformers import AutoFeatureExtractor, Wav2Vec2Model
import torch
import librosa
import numpy as np
from sklearn.linear_model import LogisticRegression
import gradio as gr
import os
from scipy.stats import mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Truncate the SSL model after the 9th transformer layer, since we only need the first 9 layers
class CustomWav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config):
        super().__init__(config)
        self.encoder.layers = self.encoder.layers[:9]

truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")
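# Note on indexing: with output_hidden_states=True, the hidden_states tuple contains the
# encoder input followed by one entry per transformer layer, so with the encoder truncated
# to 9 layers, hidden_states[9] (used below) is the output of the last remaining layer.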
# Wrapper that calls the SSL model for feature extraction
class HuggingFaceFeatureExtractor:
    def __init__(self, model, feature_extractor_name):
        self.device = device
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, audio, sr):
        inputs = self.feature_extractor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        return outputs.hidden_states[9]

FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")
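# Example usage (shapes are indicative, not guaranteed): calling the extractor on a
# 16 kHz mono clip, e.g.
#   feats = FEATURE_EXTRACTOR(np.zeros(16000, dtype=np.float32), 16000)
# returns a tensor of shape (batch, num_frames, hidden_size), which process_audio()
# below averages over the time axis to get one vector per segment.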
# Load our best classifier
classifier = joblib.load('logreg_margin_pruning_ALL_best.joblib')
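# Assumption: this joblib file holds a scikit-learn LogisticRegression (hence the sklearn
# import above) trained on the time-averaged layer-9 features; only predict_proba is used
# below, so any estimator exposing that method would plug in here.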
# Split the audio into fixed-length segments and return them
def segment_audio(audio, sr, segment_duration):
    segment_samples = int(segment_duration * sr)
    total_samples = len(audio)
    segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
    return segments
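# Note: the last segment may be shorter than segment_duration; e.g. a 75 s clip
# yields segments of 30 s, 30 s and 15 s.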
# Binary decision using the EER threshold
def classify_with_eer_threshold(probabilities, eer_thresh):
    return (probabilities >= eer_thresh).astype(int)
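# Segments whose class-1 probability reaches the threshold are labelled 1, the rest 0.
# The threshold is presumably the operating point at the equal error rate found during
# evaluation; it is hard-coded inside process_audio() below.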
def process_audio(input_data, segment_duration=30):
    # Resample to 16 kHz, since xls-r-2b is trained on 16 kHz audio
    audio, sr = librosa.load(input_data, sr=16000)
    # Keep only the first channel if the audio is multi-channel (xls-r-2b expects mono;
    # librosa.load already downmixes by default, so this is just a safeguard)
    if len(audio.shape) > 1:
        audio = audio[0]
    print('loaded file')
    # Segment the audio into 30 s chunks to avoid xls-r-2b crashing on long inputs
    segments = segment_audio(audio, sr, segment_duration)
    final_features = []
    print('segments')
    # Extract the features from each 30 s segment and average them over time
    for segment in segments:
        features = FEATURE_EXTRACTOR(segment, sr)
        features_avg = torch.mean(features, dim=1).cpu().numpy()
        final_features.append(features_avg)
    print('features extracted')
    inference_prob = []
    for feature in final_features:
        # Flatten to a single feature vector (drop the batch dimension from the SSL output)
        feature = feature.reshape(1, -1)
        # Run the classifier on the segment
        print(classifier.classes_)
        probability = classifier.predict_proba(feature)[:, 1]
        inference_prob.append(probability[0])
    print('classifier predicted')
    eer_threshold = 0.9999999996754046
    # Per-segment predictions based on the probability scores and the EER threshold
    y_pred_inference = classify_with_eer_threshold(np.array(inference_prob), eer_threshold)
    print('inference done for segments')
    # FINAL PREDICTION based on majority vote over the segments
    mode_result = mode(y_pred_inference, keepdims=True)
    final_prediction = mode_result.mode[0] if mode_result.mode.size > 0 else 0
    print('majority voting done')
    # Confidence score: proportion of segments agreeing with the majority prediction
    confidence_score = np.mean(y_pred_inference == final_prediction) if len(y_pred_inference) > 0 else 1.0
    confidence_percentage = confidence_score * 100
    return {
        "Final classification": "Real" if final_prediction == 1 else "Fake",
        "Confidence": round(confidence_percentage, 2)
    }
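# Example of the returned value for a clip where 5 of 6 segments agree with the
# majority label (values are illustrative only):
#   {"Final classification": "Real", "Confidence": 83.33}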
def gradio_interface(audio):
    if audio:
        return process_audio(audio)
    else:
        return "Please upload an audio file."
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Audio(type="filepath", label="Upload Audio")],
    outputs="text",
    title="SOL2 Audio Deepfake Detection Demo",
    description="Upload an audio file to check whether it is AI-generated.",
)

interface.launch(share=True)
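# share=True exposes the demo through a temporary public Gradio link; for a
# local-only run, interface.launch() with no arguments is enough.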