|
|
import os
import tempfile
import traceback

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from huggingface_hub import login
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
|
|
|
|
|
|
|
|
# Shared inference state. All three names stay None until load_model() assigns
# them; generate_audio() checks `model` to detect that loading has not happened.
model = model_config = device = None
|
|
|
|
|
def load_model():
    """Load the pretrained Stable Audio Open Small model into module globals.

    Selects CUDA when it is available (CPU otherwise), logs in to the Hugging
    Face Hub when the ``HF_TOKEN`` environment variable is set, fetches the
    ``stabilityai/stable-audio-open-small`` checkpoint and stores the result
    in the module-level ``model``, ``model_config`` and ``device``.

    Returns:
        The ``(model, model_config)`` pair that was stored in the globals.
    """
    global model, model_config, device

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model on device: {device}")

    # Authenticate only when a token is configured; without one the download
    # below is still attempted and may fail if the repo requires authentication.
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("Using HF_TOKEN for authentication")
        login(token=hf_token)
    else:
        print("Warning: HF_TOKEN not found. Model access may fail if authentication is required.")
        print("Please set HF_TOKEN as a secret in your Space settings.")

    model, model_config = get_pretrained_model("stabilityai/stable-audio-open-small")
    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]

    # Inference only: move to the target device, switch to eval mode and
    # disable gradient tracking on all parameters.
    model = model.to(device).eval().requires_grad_(False)

    # Use half precision on CUDA only. On the CPU fallback the weights keep
    # the dtype they were loaded with, because float16 CPU kernels are slow
    # and not implemented for every op in every PyTorch build.
    if device == "cuda":
        model = model.to(torch.float16)

    print(f"Model loaded successfully. Sample rate: {sample_rate}, Sample size: {sample_size}")
    return model, model_config
|
|
|
|
|
def _to_wav_array(clip):
    """Peak-normalize one generated clip and return it as a NumPy array.

    The clip is converted to float32, divided by its absolute peak (skipped
    for an all-zero clip to avoid dividing by zero), clamped to [-1, 1] and
    moved to the CPU. A 1-D result becomes a single column; a 2-D result is
    transposed -- presumably from (channels, samples) to the
    (frames, channels) layout that ``soundfile.write`` expects.
    """
    clip = clip.to(torch.float32)
    peak = clip.abs().max()
    if peak > 0:
        clip = clip / peak
    samples = clip.clamp(-1, 1).cpu().numpy()
    return samples.reshape(-1, 1) if samples.ndim == 1 else samples.T


def generate_audio(prompt, seconds_total=11):
    """Generate 4 audio variations from a text prompt.

    Args:
        prompt: Text description of the sound to generate.
        seconds_total: Duration value passed to the model as conditioning.

    Returns:
        A 5-tuple ``(path_1, path_2, path_3, path_4, status)``. On success the
        first four items are paths to WAV files and ``status`` describes the
        run. When the model is not loaded, the prompt is empty, or generation
        raises, the four paths are ``None`` and ``status`` holds the reason
        (including the traceback in the exception case).
    """
    if model is None:
        return None, None, None, None, "Model not loaded. Please wait..."

    if not prompt or not prompt.strip():
        return None, None, None, None, "Please enter a text prompt."

    # Must stay in sync with the four gr.Audio outputs wired to this function.
    num_variations = 4

    # One independent dict per batch item, so nothing downstream can mutate
    # a single shared dict for every variation at once.
    conditioning = [
        {"prompt": prompt, "seconds_total": seconds_total}
        for _ in range(num_variations)
    ]

    try:
        output = generate_diffusion_cond(
            model,
            steps=8,
            cfg_scale=1.0,
            conditioning=conditioning,
            sample_size=model_config["sample_size"],
            sampler_type="pingpong",
            device=device,
            batch_size=num_variations,
        )

        sample_rate = model_config["sample_rate"]

        # Write into a fresh directory per call. Fixed file names in the
        # working directory would let overlapping requests overwrite each
        # other's audio before it has been served.
        out_dir = tempfile.mkdtemp(prefix="stable_audio_")
        audio_files = []
        for index in range(num_variations):
            path = os.path.join(out_dir, f"output_variation_{index + 1}.wav")
            sf.write(path, _to_wav_array(output[index]), sample_rate)
            audio_files.append(path)

        return (*audio_files, f"Generated 4 variations for: '{prompt}'")

    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        error_msg = f"Error generating audio: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, None, None, None, error_msg
|
|
|
|
|
|
|
|
# Load the checkpoint once, at import time, before the UI is built.
print("Initializing model...")
load_model()
|
|
|
|
|
|
|
|
# Gradio UI: prompt and duration controls on the left, status box and the four
# generated clips on the right.
with gr.Blocks(title="Stable Audio Open Small - 4 Variations") as demo:
    gr.Markdown("""
    # Stable Audio Open Small

    Generate up to 4 audio variations from a text prompt.

    **Model**: [stabilityai/stable-audio-open-small](https://huggingface.co/stabilityai/stable-audio-open-small)

    **Note**: This model requires accepting the license agreement. Make sure to set `HF_TOKEN` as a secret in your Space settings.

    Enter a text description and click Generate to create 4 different audio variations.
    """)

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 128 BPM tech house drum loop",
                lines=2,
            )
            seconds_input = gr.Slider(
                minimum=1,
                maximum=11,
                value=11,
                step=1,
                label="Duration (seconds)",
                info="Maximum 11 seconds",
            )
            generate_btn = gr.Button("Generate", variant="primary")

        # Right column: status line followed by one player per variation.
        with gr.Column():
            status_output = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### Generated Audio Variations")
            (
                audio_output_1,
                audio_output_2,
                audio_output_3,
                audio_output_4,
            ) = (
                gr.Audio(label=f"Variation {slot}", interactive=False)
                for slot in range(1, 5)
            )

    # The output order matches the 5-tuple returned by generate_audio:
    # four audio paths, then the status message.
    generate_btn.click(
        fn=generate_audio,
        inputs=[prompt_input, seconds_input],
        outputs=[
            audio_output_1,
            audio_output_2,
            audio_output_3,
            audio_output_4,
            status_output,
        ],
        api_name="generate_audio",
    )

    gr.Markdown("""
    ### Tips
    - The model works best with English descriptions
    - Better at generating sound effects and field recordings than music
    - Each variation uses a different random seed for diversity
    """)
|
|
|
|
|
# Start the Gradio server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|