saos / app.py
hugofloresgarcia's picture
Add api_name to generate_audio button for API access
da6984c
import os
import tempfile
import traceback

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import login
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
# Module-level model state. Everything starts as None; load_model() assigns
# these globals and generate_audio() reads them.
model = model_config = device = None
def load_model():
    """Load the pretrained Stable Audio Open Small model on startup.

    Picks CUDA when available (otherwise CPU) and logs in to the Hugging Face
    Hub if the ``HF_TOKEN`` environment variable is set. It then loads
    ``stabilityai/stable-audio-open-small`` with ``get_pretrained_model`` and
    moves it to the chosen device in eval mode with gradients disabled. The
    results are stored in the module-level ``model``, ``model_config`` and
    ``device`` globals.

    Returns:
        tuple: ``(model, model_config)``, the same objects assigned to the
        module globals.
    """
    global model, model_config, device

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model on device: {device}")

    # Check for HF_TOKEN environment variable (set in Space settings); the
    # model requires accepting its license, so anonymous access may fail.
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("Using HF_TOKEN for authentication")
        login(token=hf_token)
    else:
        print("Warning: HF_TOKEN not found. Model access may fail if authentication is required.")
        print("Please set HF_TOKEN as a secret in your Space settings.")

    # Download and load the pretrained model
    model, model_config = get_pretrained_model("stabilityai/stable-audio-open-small")
    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]

    model = model.to(device).eval().requires_grad_(False)
    # Half precision is only an efficiency win on GPU. On CPU many kernels
    # have no float16 implementation (or are slower), so keep float32 there.
    if device == "cuda":
        model = model.to(torch.float16)

    print(f"Model loaded successfully. Sample rate: {sample_rate}, Sample size: {sample_size}")
    return model, model_config
def generate_audio(prompt, seconds_total=11):
    """Generate 4 audio variations from a text prompt.

    Args:
        prompt: Text description of the sound to generate.
        seconds_total: Requested clip length in seconds. It is passed to the
            model as timing conditioning. Each output is also trimmed to this
            length so the returned files match the requested duration; they
            are never longer than what the model produced.

    Returns:
        A 5-tuple ``(path_1, path_2, path_3, path_4, status)``. Each
        ``path_i`` is the path of a WAV file, or ``None`` when nothing was
        generated. ``status`` is a message for the UI; on failure it
        includes the traceback.
    """
    num_variations = 4

    if model is None:
        return None, None, None, None, "Model not loaded. Please wait..."
    if not prompt or not prompt.strip():
        return None, None, None, None, "Please enter a text prompt."

    # One independent conditioning dict per batch item; ``[{...}] * 4`` would
    # make all four entries the very same dict object.
    conditioning = [
        {"prompt": prompt, "seconds_total": seconds_total}
        for _ in range(num_variations)
    ]

    try:
        output = generate_diffusion_cond(
            model,
            steps=8,
            cfg_scale=1.0,
            conditioning=conditioning,
            sample_size=model_config["sample_size"],
            sampler_type="pingpong",
            device=device,
            batch_size=num_variations,
        )
        sample_rate = model_config["sample_rate"]

        # The sampler is always asked for the full ``sample_size`` regardless
        # of seconds_total. Keep only the requested duration: at least one
        # sample, never more than what was generated.
        requested = int(round(float(seconds_total) * sample_rate))
        num_samples = max(1, min(requested, output.shape[-1]))

        # A fresh directory per call, so concurrent requests cannot overwrite
        # each other's files (fixed names in the CWD were shared globally).
        out_dir = tempfile.mkdtemp(prefix="saos_")
        audio_files = []
        for i in range(num_variations):
            # One variation, trimmed along the last (time) axis.
            audio = output[i, ..., :num_samples].to(torch.float32)

            # Peak-normalize, skipping all-zero output to avoid dividing by 0.
            peak = torch.max(torch.abs(audio))
            if peak > 0:
                audio = audio / peak
            audio = audio.clamp(-1, 1).cpu().numpy()

            # soundfile expects [samples, channels].
            audio = audio.reshape(-1, 1) if audio.ndim == 1 else audio.T

            filename = os.path.join(out_dir, f"output_variation_{i + 1}.wav")
            sf.write(filename, audio, sample_rate)
            audio_files.append(filename)

        return (*audio_files, f"Generated 4 variations for: '{prompt}'")
    except Exception as e:
        # Top-level UI boundary: log the full error and surface it in the
        # status box instead of crashing the request.
        error_msg = f"Error generating audio: {e}\n{traceback.format_exc()}"
        print(error_msg)
        return None, None, None, None, error_msg
# Load the model once at import time so it is ready before the UI starts
# serving requests. load_model() also fills the module globals; rebinding
# them here to its return value makes that data flow explicit.
print("Initializing model...")
model, model_config = load_model()
# Markdown shown above the controls.
_INTRO_MARKDOWN = """
# Stable Audio Open Small
Generate up to 4 audio variations from a text prompt.
**Model**: [stabilityai/stable-audio-open-small](https://huggingface.co/stabilityai/stable-audio-open-small)
**Note**: This model requires accepting the license agreement. Make sure to set `HF_TOKEN` as a secret in your Space settings.
Enter a text description and click Generate to create 4 different audio variations.
"""

# Markdown shown below the controls.
_TIPS_MARKDOWN = """
### Tips
- The model works best with English descriptions
- Better at generating sound effects and field recordings than music
- Each variation uses a different random seed for diversity
"""

# Build the Gradio UI: inputs on the left, status and the four generated
# clips on the right. The button is exposed over the API as "generate_audio".
with gr.Blocks(title="Stable Audio Open Small - 4 Variations") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 128 BPM tech house drum loop",
                lines=2,
            )
            seconds_input = gr.Slider(
                minimum=1, maximum=11, value=11, step=1,
                label="Duration (seconds)",
                info="Maximum 11 seconds",
            )
            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            status_output = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### Generated Audio Variations")
            # Created in order 1..4 so the layout and output order match.
            audio_output_1, audio_output_2, audio_output_3, audio_output_4 = (
                gr.Audio(label=f"Variation {n}", interactive=False)
                for n in range(1, 5)
            )

    click_inputs = [prompt_input, seconds_input]
    click_outputs = [
        audio_output_1,
        audio_output_2,
        audio_output_3,
        audio_output_4,
        status_output,
    ]
    generate_btn.click(
        fn=generate_audio,
        inputs=click_inputs,
        outputs=click_outputs,
        api_name="generate_audio",
    )

    gr.Markdown(_TIPS_MARKDOWN)
if __name__ == "__main__":
    # Start the Gradio server for `demo`, only when this file is executed as
    # a script rather than imported.
    demo.launch()