#!/usr/bin/env python3
"""
Step Audio R1 vLLM Gradio Interface
"""
import base64
import html
import io
import json
import os
import re

import gradio as gr
import httpx
from pydub import AudioSegment
# Base URL of the vLLM OpenAI-compatible API (override with env var API_BASE_URL).
API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:9999/v1")
# Default model name; may be overridden at startup via the --model CLI flag.
MODEL_NAME = os.getenv("MODEL_NAME", "Step-Audio-R1")
def escape_html(text):
    """Escape HTML special characters to prevent XSS.

    Non-string inputs are returned unchanged so callers can pass
    arbitrary message content through without pre-checking the type.
    """
    if not isinstance(text, str):
        return text
    # The original hand-rolled chain replaced each character with itself
    # (the entities had been stripped), making it a no-op. html.escape
    # handles &, <, >, " and ' (quote=True is the default).
    return html.escape(text)
def process_audio(audio_path):
    """Convert an audio file to wav and split it into chunks of at most 25 s.

    Returns a list of base64-encoded wav strings; an empty list when the
    path is missing/invalid or decoding fails.
    """
    if not audio_path or not os.path.exists(audio_path):
        return []
    try:
        # pydub dispatches on format automatically (mp3, wav, ...) when
        # ffmpeg is available.
        segment = AudioSegment.from_file(audio_path)
        max_ms = 25000  # 25-second chunk length
        if len(segment) > max_ms:
            pieces = [segment[start:start + max_ms]
                      for start in range(0, len(segment), max_ms)]
        else:
            pieces = [segment]
        # Encode every chunk as an in-memory wav, then base64.
        encoded_chunks = []
        for piece in pieces:
            wav_buf = io.BytesIO()
            piece.export(wav_buf, format="wav")
            encoded_chunks.append(base64.b64encode(wav_buf.getvalue()).decode())
        return encoded_chunks
    except Exception as exc:
        print(f"[DEBUG] Audio processing error: {exc}")
        return []
def format_messages(system, history, user_text, audio_data_list=None):
    """Build the OpenAI-style message list for the chat completions API.

    Args:
        system: optional system prompt string.
        history: Gradio Chatbot history (dicts or ChatMessage-like objects).
        user_text: current user text; may be empty.
        audio_data_list: base64 wav chunks for the current user audio.

    Returns:
        List of {"role", "content"} dicts. Multimodal content is a list of
        {"type": "text"} / {"type": "input_audio"} parts with text first.

    NOTE(review): the thinking-block HTML literals and regex below were
    reconstructed after being stripped from the source; they must match the
    markup emitted by chat()'s rendering. Requires `re` (imported at file top).
    """
    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    # Walk the chat history, normalizing each entry.
    for item in history or []:
        role = item.get("role") if isinstance(item, dict) else getattr(item, "role", None)
        content = item.get("content") if isinstance(item, dict) else getattr(item, "content", None)
        if not role or content is None:
            continue
        # Assistant turns rendered with a thinking block keep only the final
        # response text; turns still mid-thinking are dropped entirely.
        if role == "assistant" and isinstance(content, str) and '<div class="thinking-block">' in content:
            pattern = r'<div class="thinking-block">.*?</div>\s*</div>\s*'
            remaining_content = re.sub(pattern, '', content, flags=re.DOTALL).strip()
            if remaining_content and not remaining_content.startswith('<'):
                content = remaining_content
            else:
                # Still in thinking phase or no response yet — skip this turn.
                continue
        # Gradio audio component payload: {"component": "audio", "value": {"path": ...}}
        is_audio = isinstance(content, dict) and content.get("component") == "audio"
        if is_audio:
            audio_path = content["value"]["path"]
            if audio_path and os.path.exists(audio_path):
                try:
                    item_audio_data_list = process_audio(audio_path)
                    new_content = [
                        {"type": "input_audio",
                         "input_audio": {"data": audio_data, "format": "wav"}}
                        for audio_data in item_audio_data_list
                    ]
                    messages.append({"role": role, "content": new_content})
                except Exception as e:
                    print(f"[ERROR] Failed to process history audio: {e}")
        elif isinstance(content, str):
            messages.append({"role": role, "content": content})
        elif isinstance(content, list):
            # Reorder list content so text precedes audio, then everything else.
            text_items = []
            audio_items = []
            other_items = []
            for c in content:
                is_c_audio = isinstance(c, dict) and c.get('component') == "audio"
                if is_c_audio:
                    audio_path = c["value"]["path"]
                    if audio_path and os.path.exists(audio_path):
                        try:
                            for audio_data in process_audio(audio_path):
                                audio_items.append({
                                    "type": "input_audio",
                                    "input_audio": {"data": audio_data, "format": "wav"}
                                })
                        except Exception as e:
                            print(f"[ERROR] Failed to process history audio in list: {e}")
                elif isinstance(c, str):
                    text_items.append({"type": "text", "text": c})
                elif isinstance(c, dict):
                    # Distinguish between text and audio parts; keep the rest.
                    if c.get("type") == "text":
                        text_items.append(c)
                    elif c.get("type") == "input_audio":
                        audio_items.append(c)
                    else:
                        other_items.append(c)
            safe_content = text_items + audio_items + other_items
            if safe_content:
                messages.append({"role": role, "content": safe_content})
    # Current user turn: text part first, then audio parts.
    if user_text and audio_data_list:
        content = [{"type": "text", "text": user_text}]
        content += [
            {"type": "input_audio", "input_audio": {"data": d, "format": "wav"}}
            for d in audio_data_list
        ]
        messages.append({"role": "user", "content": content})
    elif user_text:
        messages.append({"role": "user", "content": user_text})
    elif audio_data_list:
        content = [
            {"type": "input_audio", "input_audio": {"data": d, "format": "wav"}}
            for d in audio_data_list
        ]
        messages.append({"role": "user", "content": content})
    return messages
def chat(system_prompt, user_text, audio_file, history, max_tokens, temperature, top_p, show_thinking=True, model_name=None):
"""Chat function"""
# If model is not specified, use global configuration
if model_name is None:
model_name = MODEL_NAME
if not user_text and not audio_file:
yield history or []
return
# Ensure history is a list and formatted correctly
history = history or []
clean_history = []
for item in history:
if isinstance(item, dict) and 'role' in item and 'content' in item:
clean_history.append(item)
elif hasattr(item, "role") and hasattr(item, "content"):
# Keep ChatMessage object
clean_history.append(item)
history = clean_history
# Process audio
audio_data_list = []
if audio_file:
audio_data_list = process_audio(audio_file)
messages = format_messages(system_prompt, history, user_text, audio_data_list)
if not messages:
yield history or []
return
# Debug: Print message format
debug_messages = []
for msg in messages:
if isinstance(msg, dict) and isinstance(msg.get("content"), list):
new_content = []
for item in msg["content"]:
if isinstance(item, dict) and item.get("type") == "input_audio":
item_copy = item.copy()
if "input_audio" in item_copy:
audio_info = item_copy["input_audio"].copy()
if "data" in audio_info:
audio_info["data"] = f"[BASE64_AUDIO_DATA_LEN_{len(audio_info['data'])}]"
item_copy["input_audio"] = audio_info
new_content.append(item_copy)
else:
new_content.append(item)
msg_copy = msg.copy()
msg_copy["content"] = new_content
debug_messages.append(msg_copy)
else:
debug_messages.append(msg)
print(f"[DEBUG] Messages to API: {json.dumps(debug_messages, ensure_ascii=False, indent=2)}")
# Update history with user message immediately (text first, then audio)
if user_text and audio_file:
# 1. Add text message first
history.append({"role": "user", "content": user_text})
# 2. Add audio message second
history.append({"role": "user", "content": gr.Audio(audio_file)})
elif user_text:
# Text only
history.append({"role": "user", "content": user_text})
elif audio_file:
# Audio only
history.append({"role": "user", "content": gr.Audio(audio_file)})
# Add thinking placeholder
if show_thinking:
history.append({
"role": "assistant",
"content": (
'\n'
'\n'
'
Processing your request...
\n'
'
'
)
})
yield history
else:
history.append({
"role": "assistant",
"content": "⏳ Generating response..."
})
yield history
try:
# 禁用代理以访问内网 API
with httpx.Client(base_url=API_BASE_URL, timeout=120, proxies={}) as client:
response = client.post("/chat/completions", json={
"model": model_name,
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"stream": True,
"repetition_penalty": 1.0,
"stop_token_ids": [151665]
})
if response.status_code != 200:
error_msg = f"❌ API Error {response.status_code}"
if response.status_code == 404:
error_msg += " - vLLM service not ready"
elif response.status_code == 400:
error_msg += " - Bad request"
elif response.status_code == 500:
error_msg += " - Model error"
# Update the last message with error
history[-1]["content"] = error_msg
yield history
return
# Process streaming response
buffer = ""
is_thinking = True
for line in response.iter_lines():
if not line:
continue
# Ensure line is string format
if isinstance(line, bytes):
line = line.decode('utf-8')
else:
line = str(line)
if line.startswith('data: '):
data_str = line[6:]
if data_str.strip() == '[DONE]':
break
try:
data = json.loads(data_str)
if 'choices' in data and len(data['choices']) > 0:
delta = data['choices'][0].get('delta', {})
if 'content' in delta:
content = delta['content']
buffer += content
if is_thinking:
if "" in buffer:
is_thinking = False
parts = buffer.split("", 1)
think_content = parts[0]
response_content = parts[1]
if think_content.startswith(""):
think_content = think_content[len(""):].strip()
if show_thinking:
# Format thinking with custom styled block (escape HTML for safety)
escaped_think = escape_html(think_content)
formatted_content = (
f'\n'
f'\n'
f'
{escaped_think}
\n'
f'
\n\n'
f'{response_content}'
)
history[-1]["content"] = formatted_content
else:
# Don't show thinking, replace with response message directly
history[-1]["content"] = response_content
else:
# Update thinking message with collapsible format (only if showing)
if show_thinking:
current_think = buffer
if current_think.startswith(""):
current_think = current_think[len(""):].strip()
escaped_think = escape_html(current_think)
formatted_content = (
f'\n'
f'\n'
f'
{escaped_think}
\n'
f'
'
)
history[-1]["content"] = formatted_content
else:
# Already split, update the combined message
parts = buffer.split("", 1)
think_content = parts[0]
response_content = parts[1]
if think_content.startswith(""):
think_content = think_content[len(""):].strip()
if show_thinking:
# Update with formatted thinking + response
escaped_think = escape_html(think_content)
formatted_content = (
f'\n'
f'\n'
f'
{escaped_think}
\n'
f'
\n\n'
f'{response_content}'
)
history[-1]["content"] = formatted_content
else:
# Only show response
history[-1]["content"] = response_content
yield history
except json.JSONDecodeError:
continue
except httpx.ConnectError:
history[-1]["content"] = "❌ Cannot connect to vLLM API"
yield history
except Exception as e:
history[-1]["content"] = f"❌ Error: {str(e)}"
yield history
# Custom CSS for better UI.
# Injected via gr.Blocks(css=...). Defines the .app-header banner styling and
# the .thinking-block / .thinking-header / .thinking-content classes used to
# render the model's reasoning in the chat transcript.
custom_css = """
/* 全局样式 */
.gradio-container {
max-width: 100% !important;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
}
/* 标题样式 */
.app-header {
text-align: center;
padding: 2.5rem 1.5rem;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
position: relative;
overflow: hidden;
border-radius: 16px;
margin-bottom: 1.5rem;
box-shadow: 0 8px 24px rgba(102, 126, 234, 0.35);
}
/* 标题背景装饰 */
.app-header::before {
content: '';
position: absolute;
top: -50%;
right: -50%;
width: 200%;
height: 200%;
background: radial-gradient(circle, rgba(255, 255, 255, 0.1) 0%, transparent 70%);
animation: rotate 20s linear infinite;
}
@keyframes rotate {
from { transform: rotate(0deg); }
to { transform: rotate(360deg); }
}
.app-header h1 {
margin: 0;
font-size: 2.8rem;
font-weight: 700;
color: white !important;
text-shadow: 0 3px 6px rgba(0, 0, 0, 0.25);
letter-spacing: 1px;
position: relative;
z-index: 1;
}
.app-header p {
color: rgba(255, 255, 255, 0.95) !important;
text-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
position: relative;
z-index: 1;
line-height: 1.5;
}
/* 聊天框样式 */
.chatbot-container {
border-radius: 12px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
overflow: hidden;
}
/* 思考过程样式 - 模仿Claude/ChatGPT的风格 */
.thinking-block {
background: linear-gradient(135deg, #f5f7fa 0%, #eef2f7 100%);
border-left: 4px solid #667eea;
padding: 16px 20px;
margin: 12px 0;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
.thinking-header {
display: flex;
align-items: center;
font-weight: 600;
color: #667eea;
margin-bottom: 10px;
font-size: 0.95rem;
}
.thinking-content {
background: #ffffff;
padding: 12px 16px;
border-radius: 6px;
font-family: 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas, 'Courier New', monospace;
font-size: 0.9rem;
line-height: 1.6;
color: #374151;
white-space: pre-wrap;
word-wrap: break-word;
border: 1px solid #e5e7eb;
}
/* 回复分隔线 */
.response-divider {
border: none;
height: 2px;
background: linear-gradient(to right, transparent, #e5e7eb, transparent);
margin: 20px 0;
}
/* 按钮样式 */
.primary-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
transition: all 0.3s ease !important;
}
.primary-btn:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
}
/* 左侧面板样式 */
.left-panel {
background: #f9fafb;
border-radius: 12px;
padding: 1rem;
height: 100%;
}
/* 输入框样式 */
.input-box textarea {
border-radius: 8px !important;
border: 2px solid #e5e7eb !important;
transition: border-color 0.3s ease !important;
}
.input-box textarea:focus {
border-color: #667eea !important;
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
}
/* 输入区域标题 */
h3 {
color: #374151;
font-size: 1.1rem;
margin: 1rem 0 0.5rem 0;
}
/* 聊天消息样式优化 */
.message-wrap {
padding: 1rem !important;
}
.message {
padding: 1rem !important;
border-radius: 12px !important;
line-height: 1.6 !important;
}
/* 用户消息 */
.message.user {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
color: white !important;
}
/* 助手消息 */
.message.bot {
background: #f9fafb !important;
border: 1px solid #e5e7eb !important;
}
/* 左侧面板整体样式 */
.left-column {
background: linear-gradient(to bottom, #ffffff 0%, #f9fafb 100%);
border-radius: 12px;
padding: 1rem;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
}
/* 按钮容器样式 */
.button-row {
margin-top: 1rem;
gap: 0.5rem;
}
/* 滚动条美化 */
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: #f1f1f1;
border-radius: 4px;
}
::-webkit-scrollbar-thumb {
background: #888;
border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
background: #555;
}
"""
# Gradio Interface
with gr.Blocks(title="Step Audio R1", css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header banner, styled by the .app-header rules in custom_css.
    # NOTE(review): the original header markup was stripped from the source;
    # this is a reconstruction matching the CSS classes — confirm wording.
    gr.HTML("""
    <div class="app-header">
        <h1>🎙️ Step Audio R1</h1>
        <p>Audio understanding and reasoning, served by vLLM</p>
    </div>
    """)
    with gr.Row():
        # Left Panel - Input Area
        with gr.Column(scale=1, min_width=350):
            # Generation configuration
            with gr.Accordion("⚙️ Configuration", open=False):
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    lines=2,
                    value="You are a voice assistant with extensive experience in audio processing.",
                    placeholder="Enter system prompt...",
                    elem_classes=["input-box"]
                )
                max_tokens = gr.Slider(
                    1, 7192,
                    value=6400,
                    label="Max Tokens",
                    info="Maximum tokens to generate"
                )
                temperature = gr.Slider(
                    0.0, 2.0,
                    value=0.7,
                    label="Temperature",
                    info="Higher = more random"
                )
                top_p = gr.Slider(
                    0.0, 1.0,
                    value=0.9,
                    label="Top P",
                    info="Nucleus sampling"
                )
                show_thinking = gr.Checkbox(
                    label="💭 Show Thinking Process",
                    value=True,
                    info="Display reasoning steps"
                )
            # Input Area
            gr.Markdown("### 📝 Your Input")
            user_text = gr.Textbox(
                label="Text Message",
                lines=4,
                placeholder="Type your message here...",
                elem_classes=["input-box"],
                show_label=False
            )
            audio_file = gr.Audio(
                label="🎤 Audio Input",
                type="filepath",
                sources=["microphone", "upload"],
                show_label=True
            )
            # Buttons
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", scale=1, size="lg")
                submit_btn = gr.Button(
                    "🚀 Send",
                    variant="primary",
                    scale=2,
                    size="lg",
                    elem_classes=["primary-btn"]
                )
            # Usage Guide at bottom
            with gr.Accordion("📖 Quick Guide", open=False):
                gr.Markdown("""
                **Usage:**
                - Type text, upload audio, or both
                - Audio > 25s auto-splits
                - Toggle thinking process display
                **Tips:**
                - Thinking shown in blue gradient block
                - History auto-cleaned for API
                - Adjust params in Configuration
                """)
        # Right Panel - Conversation Area
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="💬 Conversation",
                height=700,
                type="messages",
                elem_classes=["chatbot-container"],
                show_label=True,
                avatar_images=(None, None),
                # NOTE(review): bubble_full_width is deprecated in recent
                # Gradio releases — confirm against the pinned version.
                bubble_full_width=False
            )
    # Wire the send button to the streaming chat generator.
    submit_btn.click(
        fn=chat,
        inputs=[system_prompt, user_text, audio_file, chatbot, max_tokens, temperature, top_p, show_thinking],
        outputs=[chatbot]
    )
    # Clear resets the transcript, the text box, and the audio input.
    clear_btn.click(
        fn=lambda: ([], "", None),
        outputs=[chatbot, user_text, audio_file]
    )
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=6008)
    parser.add_argument("--model", default=MODEL_NAME)
    args = parser.parse_args()
    # Clear proxy env vars so the internal API is reached directly.
    # (os is already imported at the top of the file; the redundant
    # in-block import was removed.)
    os.environ.update({
        'http_proxy': '', 'https_proxy': '', 'all_proxy': '',
        'HTTP_PROXY': '', 'HTTPS_PROXY': '', 'ALL_PROXY': ''
    })
    # Override the module-level model name from the CLI, if given.
    if args.model:
        MODEL_NAME = args.model
    print(f"启动Gradio: http://{args.host}:{args.port}")
    print(f"API地址: {API_BASE_URL}")
    print(f"模型: {MODEL_NAME}")
    demo.launch(server_name=args.host, server_port=args.port, share=False)