import joblib
from transformers import AutoFeatureExtractor, Wav2Vec2Model
import torch
import librosa
import numpy as np
from sklearn.linear_model import LogisticRegression
import gradio as gr
from scipy.stats import mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
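
# Pipeline overview: features from a truncated wav2vec2-xls-r-2b (first 9
# transformer layers) are mean-pooled per 30 s segment, scored by a pre-trained
# logistic regression classifier, thresholded at the EER operating point, and
# majority-voted into a final Real/Fake decision.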

# truncate the SSL model from the 10th layer on, since we only need the
# first 9 transformer layers
class CustomWav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config):
        super().__init__(config)
        self.encoder.layers = self.encoder.layers[:9]

truncated_model = CustomWav2Vec2Model.from_pretrained("facebook/wav2vec2-xls-r-2b")
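# note: from_pretrained still loads the full checkpoint; the weights of the
# dropped layers (10th onward) are simply reported as unused, which is
# expected and harmless here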

# wrapper that calls the SSL model for feature extraction
class HuggingFaceFeatureExtractor:
    def __init__(self, model, feature_extractor_name):
        self.device = device
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_name)
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, audio, sr):
        inputs = self.feature_extractor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        # hidden_states[0] is the embedding output, so index 9 is the output
        # of the 9th (last remaining) transformer layer
        return outputs.hidden_states[9]

FEATURE_EXTRACTOR = HuggingFaceFeatureExtractor(truncated_model, "facebook/wav2vec2-xls-r-2b")
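
# Rough sanity check (assumed figures for xls-r-2b: 1920-dim hidden states,
# ~20 ms frame stride): a 1 s, 16 kHz input should come back as a tensor of
# shape (1, 49, 1920), and a 30 s segment as roughly (1, 1499, 1920):
#   feats = FEATURE_EXTRACTOR(np.zeros(16000, dtype=np.float32), 16000)
#   print(feats.shape)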

# load our best logistic regression classifier
classifier = joblib.load('logreg_margin_pruning_ALL_best.joblib')

# split the audio into fixed-duration segments and return them
def segment_audio(audio, sr, segment_duration):
    segment_samples = int(segment_duration * sr)
    total_samples = len(audio)
    segments = [audio[i:i + segment_samples] for i in range(0, total_samples, segment_samples)]
    return segments
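
# Example: at sr=16000 and segment_duration=30, segment_samples is 480000,
# so a 75 s clip yields segments of 480000, 480000 and 240000 samples
# (the trailing segment is simply shorter; no padding is applied here)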

# classification using the EER threshold
def classify_with_eer_threshold(probabilities, eer_thresh):
    return (probabilities >= eer_thresh).astype(int)
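
# Example with hypothetical values: classify_with_eer_threshold(np.array([0.3, 0.95]), 0.5)
# returns array([0, 1]); probabilities at or above the threshold map to label 1 ("Real")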

def process_audio(input_data, segment_duration=30):
    # resample to 16 kHz, since xls-r-2b was trained on 16 kHz audio
    audio, sr = librosa.load(input_data, sr=16000)
    # keep only the first channel: xls-r-2b expects mono input
    if len(audio.shape) > 1:
        audio = audio[0]
    print('loaded file')
    # segment the audio into 30 s chunks so xls-r-2b does not run out of memory
    segments = segment_audio(audio, sr, segment_duration)
    final_features = []
    print('segments')
    # extract features from each 30 s segment and mean-pool them over time
    for segment in segments:
        features = FEATURE_EXTRACTOR(segment, sr)
        features_avg = torch.mean(features, dim=1).cpu().numpy()
        final_features.append(features_avg)
    print('features extracted')
    inference_prob = []
    for feature in final_features:
        # flatten to a single feature vector for the classifier
        feature = feature.reshape(1, -1)
        # probability of the positive class (label 1, i.e. "Real")
        print(classifier.classes_)
        probability = classifier.predict_proba(feature)[:, 1]
        inference_prob.append(probability[0])
    print('classifier predicted')
    eer_threshold = 0.9999999996754046
    # per-segment prediction based on the probability score and EER threshold
    y_pred_inference = classify_with_eer_threshold(np.array(inference_prob), eer_threshold)
    print('inference done for segments')
    # FINAL PREDICTION: majority vote over the per-segment labels
    mode_result = mode(y_pred_inference, keepdims=True)
    final_prediction = mode_result.mode[0] if mode_result.mode.size > 0 else 0
    print('majority voting done')
    # confidence score: proportion of segments agreeing with the majority prediction
    confidence_score = np.mean(y_pred_inference == final_prediction) if len(y_pred_inference) > 0 else 1.0
    confidence_percentage = confidence_score * 100
    return {
        "Final classification": "Real" if final_prediction == 1 else "Fake",
        "Confidence": round(confidence_percentage, 2)
    }
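
# Standalone usage sketch (the file name "sample.wav" is a hypothetical
# example, not shipped with this script):
#   result = process_audio("sample.wav")
#   print(result)  # e.g. {'Final classification': 'Fake', 'Confidence': 100.0}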

def gradio_interface(audio):
    if audio:
        return process_audio(audio)
    else:
        return "Please upload an audio file."

interface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Audio(type="filepath", label="Upload Audio")],
    outputs="text",
    title="SOL2 Audio Deepfake Detection Demo",
    description="Upload an audio file to check whether it is AI-generated.",
)
interface.launch(share=True)