Spaces: Running on Zero
import spaces
import gradio as gr
import torch
import librosa
import numpy as np
from inference import inference
from huggingface_hub import hf_hub_download
from pathlib import Path
import os

# Read the Hugging Face token from the environment (e.g. a Space secret).
token = os.getenv("HF_TOKEN")
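# The checkpoint repo below may be gated or private; a minimal guard so a
# missing token fails loudly up front instead of deep inside hf_hub_download
# (whether the repo actually requires auth is an assumption):
if token is None:
    print("⚠️ HF_TOKEN is not set; checkpoint downloads will fail if the repo is gated.")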
def download_models_from_hub():
    """
    Download model checkpoints from the Hugging Face Model Hub.
    """
    model_dir = Path("checkpoints")
    model_dir.mkdir(exist_ok=True)

    models = {
        "main": "EmbeddingModel_MERT_768_2class_weighted-epoch=0014-val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-val_recall=0.9989.ckpt",
        "backup": "step=003432-val_loss=0.0216-val_acc=0.9963.ckpt"
    }

    downloaded_models = {}
    for model_name, filename in models.items():
        local_path = model_dir / filename
        if not local_path.exists():
            print(f"📥 Downloading {model_name} model from Hugging Face Hub...")
            hf_hub_download(
                repo_id="mippia/FST-checkpoints",
                filename=filename,
                local_dir=str(model_dir),
                token=token
            )
            print(f"✅ {model_name} model downloaded successfully!")
        else:
            print(f"✅ {model_name} model already exists locally")
        downloaded_models[model_name] = str(local_path)
    return downloaded_models
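# Optional integrity check: a minimal sketch, assuming the .ckpt files are
# standard PyTorch Lightning checkpoints (a dict containing a 'state_dict'
# key). This helper is an addition for illustration, not part of the original app.
def verify_checkpoint(path):
    """Return True if the checkpoint loads on CPU and looks like a Lightning ckpt."""
    try:
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        return isinstance(ckpt, dict) and "state_dict" in ckpt
    except Exception as e:
        print(f"❌ Checkpoint {path} failed to load: {e}")
        return False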
@spaces.GPU  # Allocate a Zero GPU slice for the call; the `spaces` import above is otherwise unused.
def detect_ai_audio(audio_file):
    """
    Detect whether the uploaded audio file was generated by AI
    and format the result based on the standardized output.
    """
    if audio_file is None:
        return "<div>⚠️ Please upload an audio file.</div>"

    try:
        result = inference(audio_file)  # e.g. {'prediction': 'Fake', 'confidence': '93.80', ...}

        prediction = result.get('prediction', 'Unknown')
        confidence = result.get('confidence', '0.00')
        fake_prob = result.get('fake_probability', '0.0')
        real_prob = result.get('real_probability', '0.0')
        raw_output = result.get('raw_output', '')

        formatted_result = f"""
        <div style="text-align: center; padding: 15px; border-radius: 10px; border: 1px solid #ccc;">
            <h2>Prediction: {prediction}</h2>
            <p>Confidence: {confidence}%</p>
            <p>Fake Probability: {float(fake_prob)*100:.2f}%</p>
            <p>Real Probability: {float(real_prob)*100:.2f}%</p>
            <p>Raw Output: {raw_output}</p>
        </div>
        """
        return formatted_result
    except Exception as e:
        return f"<div>Error processing audio: {str(e)}</div>"
# Dark-mode compatible CSS
custom_css = """
.gradio-container { min-height: 100vh; }
.main-container { border-radius: 15px !important; margin: 20px auto !important; padding: 30px !important; max-width: 800px; }
h1 { text-align: center !important; font-size: 2.5em !important; font-weight: 700 !important; margin-bottom: 15px !important; }
.gradio-markdown p { text-align: center !important; font-size: 1.1em !important; margin-bottom: 20px !important; }
.upload-container { border-radius: 10px !important; padding: 15px !important; margin-bottom: 20px !important; }
.output-container { border-radius: 10px !important; padding: 15px !important; min-height: 150px !important; }
.gr-button { border-radius: 20px !important; padding: 10px 25px !important; font-weight: 600 !important; transition: all 0.2s ease !important; }
.gr-button:hover { transform: translateY(-2px) !important; }

@media (max-width: 768px) {
    h1 { font-size: 2em !important; }
    .main-container { margin: 10px !important; padding: 20px !important; }
}
"""
# Initialization
print("🚀 Starting FST AI Audio Detection App...")
print("📦 Initializing models...")

models = download_models_from_hub()
if models.get("main"):
    print("✅ Main model ready for inference")
else:
    print("⚠️ Warning: Main model not available, app may not work properly")
# Gradio interface
demo = gr.Interface(
    fn=detect_ai_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio File", elem_classes=["upload-container"]),
    outputs=gr.HTML(label="Detection Result", elem_classes=["output-container"]),
    title="Fusion Segment Transformer for AI Generated Music Detection",
    description="""
    <div style="text-align: center; font-size: 1em; color: #555; margin: 20px 0;">
        <p><strong>Fusion Segment Transformer: Bi-directional attention guided fusion network for AI Generated Music Detection</strong></p>
        <p>Authors: Yumin Kim and Seonghyeon Go</p>
        <p>Submitted to ICASSP 2026. Detects AI-generated music by modeling full audio segments with content-structure fusion.</p>
        <p>⚠️ Note: On the Zero GPU environment, processing may take ~30 seconds per audio file.</p>
    </div>
    """,
    examples=[],
    css=custom_css,
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="purple",
        neutral_hue="gray",
        font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
    ),
    elem_classes=["main-container"]
)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False, show_error=True)