"""FST AI Audio Detection: a Gradio demo Space running on ZeroGPU."""

import spaces
import gradio as gr
import torch
import librosa
import numpy as np
from inference import inference
from huggingface_hub import hf_hub_download
from pathlib import Path
import os
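
# HF_TOKEN is read from the environment (typically a Space secret); the
# checkpoint repo below may be gated or private, in which case the download
# fails without it.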
token = os.getenv("HF_TOKEN")

def download_models_from_hub():
    """
    Download model checkpoints from the Hugging Face Model Hub.
    """
    model_dir = Path("checkpoints")
    model_dir.mkdir(exist_ok=True)
    models = {
        "main": "EmbeddingModel_MERT_768_2class_weighted-epoch=0014-val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-val_recall=0.9989.ckpt",
        "backup": "step=003432-val_loss=0.0216-val_acc=0.9963.ckpt"
    }
    downloaded_models = {}
    for model_name, filename in models.items():
        local_path = model_dir / filename
        if not local_path.exists():
            print(f"📥 Downloading {model_name} model from Hugging Face Hub...")
            hf_hub_download(
                repo_id="mippia/FST-checkpoints",
                filename=filename,
                local_dir=str(model_dir),
                local_dir_use_symlinks=False,
                token=token
            )
            print(f"✅ {model_name} model downloaded successfully!")
        else:
            print(f"✅ {model_name} model already exists locally")
        downloaded_models[model_name] = str(local_path)
    return downloaded_models
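
# The @spaces.GPU decorator below requests a ZeroGPU slice only for the
# duration of each call. The inference() result is assumed (going by the keys
# read in detect_ai_audio; the actual schema lives in inference.py) to look
# like:
#   {"prediction": "Fake", "confidence": "93.80",
#    "fake_probability": "0.938", "real_probability": "0.062",
#    "raw_output": "..."}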
@spaces.GPU
def detect_ai_audio(audio_file):
    """
    Detect whether the uploaded audio file was generated by AI
    and format the result based on the standardized output.
    """
    if audio_file is None:
        return "<div>⚠️ Please upload an audio file.</div>"
    try:
        result = inference(audio_file)  # {'prediction': 'Fake', 'confidence': '93.80', ...}
        prediction = result.get('prediction', 'Unknown')
        confidence = result.get('confidence', '0.00')
        fake_prob = result.get('fake_probability', '0.0')
        real_prob = result.get('real_probability', '0.0')
        raw_output = result.get('raw_output', '')
        formatted_result = f"""
        <div style="text-align: center; padding: 15px; border-radius: 10px; border: 1px solid #ccc;">
            <h2>Prediction: {prediction}</h2>
            <p>Confidence: {confidence}%</p>
            <p>Fake Probability: {float(fake_prob)*100:.2f}%</p>
            <p>Real Probability: {float(real_prob)*100:.2f}%</p>
            <p>Raw Output: {raw_output}</p>
        </div>
        """
        return formatted_result
    except Exception as e:
        return f"<div>Error processing audio: {str(e)}</div>"

# Dark-mode compatible CSS
custom_css = """
.gradio-container { min-height: 100vh; }
.main-container { border-radius: 15px !important; margin: 20px auto !important; padding: 30px !important; max-width: 800px; }
h1 { text-align: center !important; font-size: 2.5em !important; font-weight: 700 !important; margin-bottom: 15px !important; }
.gradio-markdown p { text-align: center !important; font-size: 1.1em !important; margin-bottom: 20px !important; }
.upload-container { border-radius: 10px !important; padding: 15px !important; margin-bottom: 20px !important; }
.output-container { border-radius: 10px !important; padding: 15px !important; min-height: 150px !important; }
.gr-button { border-radius: 20px !important; padding: 10px 25px !important; font-weight: 600 !important; transition: all 0.2s ease !important; }
.gr-button:hover { transform: translateY(-2px) !important; }
@media (max-width: 768px) {
h1 { font-size: 2em !important; }
.main-container { margin: 10px !important; padding: 20px !important; }
}
"""
# Initialization: download checkpoints at startup so the first request is fast
print("🚀 Starting FST AI Audio Detection App...")
print("📦 Initializing models...")
models = download_models_from_hub()
if models.get("main"):
    print("✅ Main model ready for inference")
else:
    print("⚠️ Warning: Main model not available, app may not work properly")

# Gradio interface
demo = gr.Interface(
    fn=detect_ai_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio File", elem_classes=["upload-container"]),
    outputs=gr.HTML(label="Detection Result", elem_classes=["output-container"]),
    title="Fusion Segment Transformer for AI Generated Music Detection",
    description="""
    <div style="text-align: center; font-size: 1em; color: #555; margin: 20px 0;">
        <p><strong>Fusion Segment Transformer: Bi-directional attention guided fusion network for AI Generated Music Detection</strong></p>
        <p>Authors: Yumin Kim and Seonghyeon Go</p>
        <p>Submitted to ICASSP 2026. Detects AI-generated music by modeling full audio segments with content-structure fusion.</p>
        <p>⚠️ Note: On the ZeroGPU environment, processing may take ~30 seconds per audio file.</p>
    </div>
    """,
    examples=[],
    css=custom_css,
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="purple",
        neutral_hue="gray",
        font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
    ),
    elem_classes=["main-container"]
)
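
# 0.0.0.0:7860 is the address a Hugging Face Space container expects;
# show_api=False hides the auto-generated API docs, and show_error surfaces
# tracebacks in the UI.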
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False, show_error=True)
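
# A quick local smoke test without the UI (assuming this file is app.py and
# "sample.wav" is a hypothetical audio file; importing the module downloads
# the checkpoints first):
#   python -c "from app import detect_ai_audio; print(detect_ai_audio('sample.wav'))"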