# slslslrhfem: "change color for darkmode" (8a4fa1b)
import spaces
import gradio as gr
import torch
import librosa
import numpy as np
from inference import inference
from huggingface_hub import hf_hub_download
from pathlib import Path
import os
# Hugging Face access token from the HF_TOKEN environment variable (None when
# unset); handed to hf_hub_download when fetching checkpoints.
token = os.environ.get("HF_TOKEN")
def download_models_from_hub(model_dir="checkpoints", repo_id="mippia/FST-checkpoints"):
    """
    Ensure the model checkpoints exist locally, downloading missing ones from
    the Hugging Face Model Hub.

    Args:
        model_dir: Directory holding the checkpoint files; created (with any
            missing parents) if absent. Defaults to ``"checkpoints"``.
        repo_id: Hugging Face model repository the checkpoints are fetched from.

    Returns:
        dict mapping model name ("main", "backup") to the local checkpoint path
        as a string. A model whose download fails is omitted (a warning is
        printed) instead of aborting the whole app at startup, so callers can
        check whether a given model is available.
    """
    model_dir = Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    models = {
        "main": "EmbeddingModel_MERT_768_2class_weighted-epoch=0014-val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-val_recall=0.9989.ckpt",
        "backup": "step=003432-val_loss=0.0216-val_acc=0.9963.ckpt",
    }
    downloaded_models = {}
    for model_name, filename in models.items():
        local_path = model_dir / filename
        if local_path.exists():
            print(f"✅ {model_name} model already exists locally")
            downloaded_models[model_name] = str(local_path)
            continue

        print(f"📥 Downloading {model_name} model from Hugging Face Hub...")
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=str(model_dir),  # the file lands at model_dir / filename
                # NOTE(review): deprecated and ignored by recent huggingface_hub
                # releases; kept for compatibility with older versions.
                local_dir_use_symlinks=False,
                token=token,
            )
        except Exception as err:
            # Network/auth/missing-file failures: report and skip this model so
            # one unavailable checkpoint does not crash the app at import time.
            print(f"⚠️ Failed to download {model_name} model: {err}")
            continue
        print(f"✅ {model_name} model downloaded successfully!")
        downloaded_models[model_name] = str(local_path)
    return downloaded_models
@spaces.GPU
def detect_ai_audio(audio_file):
    """
    Detect whether the uploaded audio file was generated by AI and render the
    standardized inference result as an HTML card.

    Args:
        audio_file: Filepath of the uploaded audio, or None when nothing was
            uploaded.

    Returns:
        str: HTML for the ``gr.HTML`` output -- the prediction card on success,
        an upload prompt when ``audio_file`` is None, or an error message if
        inference or result formatting fails.
    """
    # Local stdlib imports keep this function self-contained.
    import html
    import traceback

    if audio_file is None:
        return "<div>⚠️ Please upload an audio file.</div>"
    try:
        # e.g. {'prediction': 'Fake', 'confidence': '93.80', ...}
        result = inference(audio_file)
        prediction = result.get('prediction', 'Unknown')
        confidence = result.get('confidence', '0.00')
        # NOTE(review): probabilities are assumed to be fractions in [0, 1]
        # (they are multiplied by 100 for display) -- confirm in inference().
        fake_prob = float(result.get('fake_probability', '0.0'))
        real_prob = float(result.get('real_probability', '0.0'))
        raw_output = result.get('raw_output', '')

        # Theme border variable keeps the card readable in dark mode; #ccc is
        # the fallback if the variable is not defined.
        card_style = (
            "text-align: center; padding: 15px; border-radius: 10px; "
            "border: 1px solid var(--border-color-primary, #ccc);"
        )
        # Every value is escaped because gr.HTML renders the string as markup.
        formatted_result = (
            f'<div style="{card_style}">'
            f"<h2>Prediction: {html.escape(str(prediction))}</h2>"
            f"<p>Confidence: {html.escape(str(confidence))}%</p>"
            f"<p>Fake Probability: {fake_prob * 100:.2f}%</p>"
            f"<p>Real Probability: {real_prob * 100:.2f}%</p>"
            f"<p>Raw Output: {html.escape(str(raw_output))}</p>"
            "</div>"
        )
        return formatted_result
    except Exception as e:
        # UI boundary: keep the full traceback in the server log and show the
        # user a short, escaped message.
        traceback.print_exc()
        return f"<div>Error processing audio: {html.escape(str(e))}</div>"
# Dark-mode compatible CSS: these rules only set layout and shape (radius,
# padding, margins, font sizes, hover lift) and hard-code no colors, so colors
# presumably come from the active Gradio theme in both light and dark mode.
# .main-container / .upload-container / .output-container are the names given
# via elem_classes on the interface and its components.
# NOTE(review): .gr-button and .gradio-markdown rely on Gradio's built-in class
# names, which may differ between Gradio versions -- confirm they still match.
custom_css = """
.gradio-container { min-height: 100vh; }
.main-container { border-radius: 15px !important; margin: 20px auto !important; padding: 30px !important; max-width: 800px; }
h1 { text-align: center !important; font-size: 2.5em !important; font-weight: 700 !important; margin-bottom: 15px !important; }
.gradio-markdown p { text-align: center !important; font-size: 1.1em !important; margin-bottom: 20px !important; }
.upload-container { border-radius: 10px !important; padding: 15px !important; margin-bottom: 20px !important; }
.output-container { border-radius: 10px !important; padding: 15px !important; min-height: 150px !important; }
.gr-button { border-radius: 20px !important; padding: 10px 25px !important; font-weight: 600 !important; transition: all 0.2s ease !important; }
.gr-button:hover { transform: translateY(-2px) !important; }
@media (max-width: 768px) {
h1 { font-size: 2em !important; }
.main-container { margin: 10px !important; padding: 20px !important; }
}
"""
# Initialization: make sure the checkpoints are on disk at import time,
# before the Gradio interface is built.
print("🚀 Starting FST AI Audio Detection App...")
print("📦 Initializing models...")
models = download_models_from_hub()
# Check the checkpoint file itself, not just that a path was returned, so a
# missing or failed download is reported here rather than at first inference.
if models.get("main") and Path(models["main"]).exists():
    print("✅ Main model ready for inference")
else:
    print("⚠️ Warning: Main model not available, app may not work properly")
# Gradio interface
demo = gr.Interface(
    fn=detect_ai_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio File", elem_classes=["upload-container"]),
    outputs=gr.HTML(label="Detection Result", elem_classes=["output-container"]),
    title="Fusion Segment Transformer for AI Generated Music Detection",
    # The description text uses the theme's subdued text color (falling back to
    # the previous #555) rather than a fixed dark gray, which is hard to read on
    # a dark-mode background.
    description="""
<div style="text-align: center; font-size: 1em; color: var(--body-text-color-subdued, #555); margin: 20px 0;">
<p><strong>Fusion Segment Transformer: Bi-directional attention guided fusion network for AI Generated Music Detection</strong></p>
<p>Authors: Yumin Kim and Seonghyeon Go</p>
<p>Submitted to ICASSP 2026. Detects AI-generated music by modeling full audio segments with content-structure fusion.</p>
<p>⚠️ Note: On Zero GPU environment, processing may take ~30 seconds per audio file.</p>
</div>
""",
    examples=[],
    css=custom_css,
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="purple",
        neutral_hue="gray",
        font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
    ),
    # Lets the .main-container rule in custom_css apply to the interface.
    elem_classes=["main-container"],
)
if __name__ == "__main__":
    # Script entry point: serve the interface on 0.0.0.0:7860 with the same
    # show_api / show_error settings passed through to demo.launch().
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": False,
        "show_error": True,
    }
    demo.launch(**launch_options)