Step-Audio-R1 / app.py
moevis's picture
Update app.py
b5da221 verified
raw
history blame
10.2 kB
#!/usr/bin/env python3
"""
Step Audio R1 vLLM Gradio Interface
"""
import base64
import json
import os
import gradio as gr
import httpx
# Base URL of the OpenAI-compatible vLLM server; override via the API_BASE_URL environment variable.
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:9999/v1")
# Model name sent with every request; override via the MODEL_NAME environment variable.
MODEL_NAME = os.environ.get("MODEL_NAME", "Step-Audio-R1")
def encode_audio(audio_path):
    """Read an audio file and return its bytes as a base64-encoded ASCII string.

    Args:
        audio_path: Path to the audio file. Falsy values (``None``, ``""``) are accepted.

    Returns:
        The base64 text, or ``None`` when no path is given, the path is not an
        existing regular file, or the file cannot be read (the error is printed).
    """
    if not audio_path or not os.path.isfile(audio_path):
        return None
    try:
        with open(audio_path, "rb") as f:
            raw = f.read()
    except OSError as e:
        # Best-effort: callers treat None as "no usable audio".
        print(f"[DEBUG] Audio error: {e}")
        return None
    # base64 output is pure ASCII.
    return base64.b64encode(raw).decode("ascii")
def format_messages(system, history, user_text, audio_data=None, audio_format="wav"):
    """Build an OpenAI-style chat ``messages`` list for the vLLM chat-completions API.

    Args:
        system: Optional system prompt; added first when non-empty.
        history: Prior turns, as dicts with ``role``/``content`` keys or objects
            exposing ``role``/``content`` attributes (e.g. ``gr.ChatMessage``).
            ``None`` is treated as empty. Only ``role`` and ``content`` are
            forwarded. Turns whose content is neither a string nor a list with
            ``text``/``input_audio`` parts are skipped.
        user_text: Text of the current user turn (may be empty).
        audio_data: Base64-encoded audio for the current turn, or ``None``.
        audio_format: Format label sent with the audio (e.g. ``"wav"``, ``"mp3"``).

    Returns:
        A list of message dicts. The current turn is a plain string when only
        text is given, or a list of content parts (audio first, then text) when
        audio is present. No current turn is added if both are empty.
    """
    messages = []
    if system:
        messages.append({"role": "system", "content": system})

    for item in history or []:
        if isinstance(item, dict) and "role" in item and "content" in item:
            role, content = item["role"], item["content"]
        elif hasattr(item, "role") and hasattr(item, "content"):
            role, content = item.role, item.content
        else:
            continue
        # Gradio returns non-text turns (e.g. an earlier audio upload) as file
        # tuples/dicts or as typed part lists; the chat-completions API rejects
        # those, so only text and API-shaped parts are forwarded.
        # NOTE(review): earlier audio is therefore not re-sent to the model.
        if isinstance(content, list):
            content = [
                part for part in content
                if isinstance(part, dict) and part.get("type") in ("text", "input_audio")
            ]
            if not content:
                continue
        elif not isinstance(content, str):
            continue
        # Rebuild the dict so Gradio-only keys ("metadata", "options", ...) are dropped.
        messages.append({"role": role, "content": content})

    # Current user turn.
    if audio_data:
        content = [
            {
                "type": "input_audio",
                "input_audio": {
                    "data": audio_data,
                    "format": audio_format,
                },
            }
        ]
        if user_text:
            content.append({"type": "text", "text": user_text})
        messages.append({"role": "user", "content": content})
    elif user_text:
        messages.append({"role": "user", "content": user_text})
    return messages
# Short hints appended to the status message for common HTTP error codes.
_STATUS_HINTS = {
    404: " - vLLM service not ready",
    400: " - Bad request",
    500: " - Model error",
}


def _redact_audio_for_log(messages):
    """Return a shallow copy of *messages* with base64 audio replaced by a length marker.

    Only used for debug printing, so multi-megabyte audio payloads are not dumped to stdout.
    """
    redacted = []
    for msg in messages:
        if isinstance(msg, dict) and isinstance(msg.get("content"), list):
            parts = []
            for part in msg["content"]:
                if isinstance(part, dict) and isinstance(part.get("input_audio"), dict):
                    audio = dict(part["input_audio"])
                    audio["data"] = f"<{len(str(audio.get('data', '')))} base64 chars>"
                    part = {**part, "input_audio": audio}
                parts.append(part)
            msg = {**msg, "content": parts}
        redacted.append(msg)
    return redacted


def _read_sse_text(lines):
    """Collect answer text and streamed reasoning text from OpenAI-style SSE lines.

    Returns ``(content, reasoning)`` as two strings. Non-``data:`` lines and
    payloads that are not valid JSON are skipped; ``[DONE]`` ends the stream.
    """
    content_parts = []
    reasoning_parts = []
    for line in lines:
        if isinstance(line, bytes):
            line = line.decode("utf-8")
        if not line.startswith("data:"):
            continue
        data_str = line[len("data:"):].strip()
        if data_str == "[DONE]":
            break
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            continue
        choices = data.get("choices") if isinstance(data, dict) else None
        if not choices:
            continue
        # "delta" and its fields may be explicit nulls; never let None reach "".join().
        delta = choices[0].get("delta") or {}
        if delta.get("content"):
            content_parts.append(delta["content"])
        # Servers running a reasoning parser stream the thinking separately.
        # NOTE(review): field name varies by server/version; both spellings are accepted.
        reasoning = delta.get("reasoning_content") or delta.get("reasoning")
        if reasoning:
            reasoning_parts.append(reasoning)
    return "".join(content_parts), "".join(reasoning_parts)


def _split_thinking(text):
    """Split model output into ``(thinking, answer)`` around the first ``</think>`` tag.

    Without a closing tag, thinking is ``""`` and the answer is the stripped text.
    A leading ``<think>`` tag on the thinking part is removed.
    """
    if "</think>" not in text:
        return "", text.strip()
    thinking, answer = text.split("</think>", 1)
    thinking = thinking.strip()
    if thinking.startswith("<think>"):
        thinking = thinking[len("<think>"):].strip()
    return thinking, answer.strip()


def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature, top_p, model_name=None):
    """Send one user turn (text and/or audio) to the vLLM chat API and update the chat history.

    Args:
        system_prompt: System prompt placed before the history (omitted when empty).
        user_text: Text typed by the user; may be empty when audio is given.
        audio_file: Path to the recorded/uploaded audio file, or ``None``.
        history: Current chatbot history (dicts or ``gr.ChatMessage`` objects);
            ``None`` is treated as empty.
        max_tokens: Maximum number of tokens to generate (coerced to ``int``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        model_name: Model to request; defaults to the module-level ``MODEL_NAME``,
            read at call time.

    Returns:
        ``(history, status)``: the history list and a status string that is empty
        on success and an error/warning message otherwise. On request or API
        errors the history is returned without the new turn.
    """
    if model_name is None:
        model_name = MODEL_NAME
    if not user_text and not audio_file:
        return history or [], "Please enter text or upload audio"

    # Keep only well-formed entries (role/content dicts or ChatMessage-like objects).
    history = [
        item for item in (history or [])
        if (isinstance(item, dict) and "role" in item and "content" in item)
        or (hasattr(item, "role") and hasattr(item, "content"))
    ]

    audio_data = None
    audio_format = "wav"
    if audio_file:
        audio_data = encode_audio(audio_file)
        if not audio_data:
            # Fail fast instead of silently sending the request without the attached audio.
            return history, "❌ Failed to read audio file"
        if audio_file.lower().endswith(".mp3"):
            audio_format = "mp3"

    messages = format_messages(system_prompt, history, user_text, audio_data, audio_format)
    if not messages:
        return history, "Invalid input"

    # default=str: debug output must never be what breaks the request.
    print(
        "[DEBUG] Messages to API: "
        f"{json.dumps(_redact_audio_for_log(messages), ensure_ascii=False, indent=2, default=str)}"
    )

    try:
        payload = {
            "model": model_name,
            "messages": messages,
            # Gradio sliders may deliver floats; the API expects an integer token count.
            "max_tokens": int(max_tokens) if max_tokens is not None else None,
            "temperature": temperature,
            "top_p": top_p,
            "stream": True,
            "repetition_penalty": 1.07,
            "stop_token_ids": [151665],
        }
        with httpx.Client(base_url=API_BASE_URL, timeout=120) as client:
            # client.post() buffers the whole SSE body before returning.
            response = client.post("/chat/completions", json=payload)
            if response.status_code != 200:
                error_msg = f"❌ API Error {response.status_code}"
                error_msg += _STATUS_HINTS.get(response.status_code, "")
                detail = response.text.strip()
                if detail:
                    # Surface the server's explanation; it is otherwise lost.
                    error_msg += f": {detail[:300]}"
                return history, error_msg
            content, streamed_reasoning = _read_sse_text(response.iter_lines())

        # Record the user turn only after a successful response.
        if audio_file:
            # Audio is shown as its own message so Gradio renders a player.
            history.append({"role": "user", "content": gr.Audio(audio_file)})
            if user_text:
                history.append({"role": "user", "content": user_text})
        else:
            history.append({"role": "user", "content": user_text})

        tag_thinking, answer = _split_thinking(content)
        thinking = "\n\n".join(t for t in (streamed_reasoning.strip(), tag_thinking) if t)
        if thinking:
            # metadata["title"] makes Gradio render this as a collapsible block.
            history.append(gr.ChatMessage(
                role="assistant",
                content=thinking,
                metadata={"title": "⏳ Thinking Process"},
            ))
        if answer:
            history.append({"role": "assistant", "content": answer})
        return history, "" if (thinking or answer) else "⚠️ Empty response from model"
    except httpx.ConnectError:
        return history, "❌ Cannot connect to vLLM API"
    except httpx.TimeoutException:
        return history, "❌ vLLM API request timed out"
    except Exception as e:
        # Last-resort UI boundary: show the error in the status box instead of crashing.
        return history, f"❌ Error: {e}"
# Gradio interface
with gr.Blocks(title="Step Audio R1") as demo:
    gr.Markdown("# Step Audio R1 Chat")
    with gr.Row():
        # Left column: generation settings and the status line.
        with gr.Column(scale=1):
            with gr.Accordion("Configuration", open=True):
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    lines=2,
                    value="You are an audio analysis expert",
                )
                max_tokens = gr.Slider(1, 8192, value=1024, label="Max Tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
                top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")
            status = gr.Textbox(label="Status", interactive=False)
        # Right column: chat history plus the user's text/audio input.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat History", height=450)
            user_text = gr.Textbox(label="Input", lines=2, placeholder="Enter message...")
            audio_file = gr.Audio(label="Audio", type="filepath", sources=["microphone", "upload"])
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary", scale=2)
                clear_btn = gr.Button("Clear", scale=1)

    # chat() is bound without a model_name argument, so it falls back to the
    # module-level MODEL_NAME at request time (the --model CLI flag rebinds it).
    submit_btn.click(
        fn=chat,
        inputs=[system_prompt, user_text, audio_file, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, status],
    )
    # Reset the conversation, both inputs, and any stale status/error message.
    clear_btn.click(
        fn=lambda: ([], "", None, ""),
        outputs=[chatbot, user_text, audio_file, status],
    )
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Step Audio R1 vLLM Gradio interface")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--model", default=MODEL_NAME)
    parser.add_argument(
        "--api-base",
        default=API_BASE_URL,
        help="Base URL of the OpenAI-compatible vLLM server (default: $API_BASE_URL)",
    )
    args = parser.parse_args()

    # Rebind the module-level settings; chat() reads both globals at call time.
    if args.model:
        MODEL_NAME = args.model
    if args.api_base:
        API_BASE_URL = args.api_base

    print(f"启动Gradio: http://{args.host}:{args.port}")
    print(f"API地址: {API_BASE_URL}")
    print(f"模型: {MODEL_NAME}")
    demo.launch(server_name=args.host, server_port=args.port, share=False)