Step-Audio-R1 / app.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Gradio 多模态聊天界面:直接在 app.py 内部调用 vLLM.LLM 进行推理
"""
import base64
import os
import sys
import threading
import time
from typing import Optional, Tuple
import gradio as gr
# Check the command-line arguments before importing vllm to decide whether to enable it,
# so the UI preview can still run when vllm is not installed.
if "--no-vllm" in sys.argv:
    os.environ["ENABLE_VLLM"] = "false"
# Check whether vLLM mode is enabled
ENABLE_VLLM = os.getenv("ENABLE_VLLM", "true").lower() in ("true", "1", "yes")
if ENABLE_VLLM:
try:
from vllm import LLM, SamplingParams
except ImportError:
print("[WARNING] 无法导入 vllm,自动切换到界面预览模式")
print("[INFO] 如需使用 vLLM,请先安装: pip install vllm")
ENABLE_VLLM = False
LLM = None
SamplingParams = None
else:
LLM = None
SamplingParams = None
print("[INFO] 运行在界面预览模式,不加载 vLLM")
# Default configuration; can be overridden via environment variables or the CLI
DEFAULT_MODEL_ID = os.getenv("MODEL_NAME", "stepfun-ai/Step-Audio-2-mini-Think")
DEFAULT_MODEL_PATH = os.getenv("MODEL_PATH", DEFAULT_MODEL_ID)
DEFAULT_TP = int(os.getenv("TENSOR_PARALLEL_SIZE", "4"))
DEFAULT_MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
DEFAULT_GPU_UTIL = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.9"))
DEFAULT_TOKENIZER_MODE = os.getenv("TOKENIZER_MODE", "step_audio_2")
DEFAULT_SERVED_NAME = os.getenv("SERVED_MODEL_NAME", "step-audio-2-mini-think")
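# Example of overriding the defaults via the environment before launch
# (the model path below is purely illustrative):
#   MODEL_PATH=/models/Step-Audio-2-mini-Think TENSOR_PARALLEL_SIZE=2 python app.py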
_llm: Optional[LLM] = None
_llm_lock = threading.Lock()
LLM_ARGS = {
"model": DEFAULT_MODEL_PATH,
"trust_remote_code": True,
"tensor_parallel_size": DEFAULT_TP,
"tokenizer_mode": DEFAULT_TOKENIZER_MODE,
"max_model_len": DEFAULT_MAX_MODEL_LEN,
"served_model_name": DEFAULT_SERVED_NAME,
"gpu_memory_utilization": DEFAULT_GPU_UTIL,
}
def encode_audio_to_base64(audio_path: Optional[str]) -> Optional[dict]:
"""将音频文件编码为 base64"""
if audio_path is None:
return None
try:
with open(audio_path, "rb") as audio_file:
audio_data = audio_file.read()
audio_base64 = base64.b64encode(audio_data).decode('utf-8')
        # Infer the audio format from the file extension
ext = os.path.splitext(audio_path)[1].lower().lstrip('.')
if not ext:
ext = "wav" # 默认格式
return {
"data": audio_base64,
"format": ext
}
except Exception as e:
print(f"Error encoding audio: {e}")
return None
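# For reference, encode_audio_to_base64("question.wav") returns a dict shaped like
# {"data": "<base64 string>", "format": "wav"}; the filename here is only an example.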
def format_messages(
system_prompt: str,
chat_history: list,
user_text: str,
audio_file: Optional[str]
) -> list:
"""格式化消息为 OpenAI API 格式"""
messages = []
    # Add the system prompt
if system_prompt and system_prompt.strip():
messages.append({
"role": "system",
"content": system_prompt.strip()
})
    # Add the conversation history
for human, assistant in chat_history:
if human:
messages.append({"role": "user", "content": human})
if assistant:
messages.append({"role": "assistant", "content": assistant})
    # Add the current user input
content_parts = []
    # Add the text input
if user_text and user_text.strip():
content_parts.append({
"type": "text",
"text": user_text.strip()
})
    # Add the audio input
if audio_file:
audio_data = encode_audio_to_base64(audio_file)
if audio_data:
content_parts.append({
"type": "input_audio",
"input_audio": audio_data
})
if content_parts:
        # If there is only a single text part, pass it as a plain string
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
messages.append({
"role": "user",
"content": content_parts[0]["text"]
})
else:
messages.append({
"role": "user",
"content": content_parts
})
return messages
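# For reference, a request with both text and audio yields messages shaped roughly like:
#   [{"role": "system", "content": "..."},
#    {"role": "user", "content": [
#        {"type": "text", "text": "..."},
#        {"type": "input_audio", "input_audio": {"data": "<base64>", "format": "wav"}}]}]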
def chat_predict(
system_prompt: str,
user_text: str,
audio_file: Optional[str],
chat_history: list,
max_tokens: int,
temperature: float,
top_p: float
) -> Tuple[list, str]:
"""调用本地 vLLM LLM 完成推理"""
if not user_text and not audio_file:
return chat_history, "⚠ 请提供文本或音频输入"
# 如果是预览模式,返回模拟响应
if not ENABLE_VLLM:
user_display = user_text if user_text else "[音频输入]"
mock_response = f"这是一个模拟回复。您说: {user_text[:50] if user_text else '音频'}"
chat_history.append((user_display, mock_response))
return chat_history, ""
messages = format_messages(system_prompt, chat_history, user_text, audio_file)
if not messages:
        return chat_history, "⚠ No valid input"
try:
llm = _get_llm()
sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
start_time = time.time()
outputs = llm.chat(messages, sampling_params=sampling_params, use_tqdm=False)
        latency = time.time() - start_time
        print(f"[LLM] Inference finished in {latency:.2f}s")
if not outputs or not outputs[0].outputs:
            return chat_history, "⚠ The model returned no output"
assistant_message = outputs[0].outputs[0].text
        user_display = user_text if user_text else "[audio input]"
chat_history.append((user_display, assistant_message))
return chat_history, ""
    except Exception as e:
        import traceback
        traceback.print_exc()
        return chat_history, f"⚠ Inference failed: {e}"
def _get_llm() -> LLM:
"""单例方式初始化 LLM"""
if not ENABLE_VLLM:
        raise RuntimeError("vLLM is not enabled; the model cannot be loaded")
global _llm
if _llm is not None:
return _llm
with _llm_lock:
if _llm is not None:
return _llm
print(f"[LLM] 初始化中,参数: {LLM_ARGS}")
_llm = LLM(**LLM_ARGS)
return _llm
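# Note: the first _get_llm() call blocks while vLLM loads the model; later calls reuse the cached
# instance. The double-checked lock above keeps concurrent Gradio requests from initializing the
# engine twice.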
def _set_llm_args(**kwargs) -> None:
"""更新 LLM 初始化参数"""
global LLM_ARGS, _llm
LLM_ARGS = kwargs
    _llm = None  # force a reload with the new configuration
# Build the Gradio UI
with gr.Blocks(title="Step Audio 2 Chat", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# Step Audio R1 Demo
"""
)
with gr.Row():
        # Left column: configuration
        with gr.Column(scale=1):
            gr.Markdown("### Configuration")
            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="Enter a system prompt...",
lines=4,
value="You are an expert in audio analysis, please analyze the audio content and answer the questions accurately"
)
with gr.Row():
max_tokens = gr.Slider(
label="Max Tokens",
minimum=1,
maximum=16384,
value=8192,
step=1
)
with gr.Row():
temperature = gr.Slider(
label="Temperature",
minimum=0.0,
maximum=2.0,
value=0.7,
step=0.1
)
top_p = gr.Slider(
label="Top P",
minimum=0.0,
maximum=1.0,
value=0.9,
step=0.05
)
        # Right column: conversation and input
        with gr.Column(scale=1):
            gr.Markdown("### Conversation")
chatbot = gr.Chatbot(
label="聊天历史",
height=400,
show_copy_button=True
)
user_text = gr.Textbox(
label="文本输入",
placeholder="输入您的消息...",
lines=2
)
audio_file = gr.Audio(
label="音频输入",
type="filepath",
sources=["microphone", "upload"]
)
with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary", size="lg")
                clear_btn = gr.Button("Clear", variant="secondary")
            status_text = gr.Textbox(label="Status", interactive=False, visible=False)
    # Event bindings
submit_btn.click(
fn=chat_predict,
inputs=[
system_prompt,
user_text,
audio_file,
chatbot,
max_tokens,
temperature,
top_p
],
outputs=[chatbot, status_text]
)
clear_btn.click(
fn=lambda: ([], "", None),
outputs=[chatbot, user_text, audio_file]
)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Step Audio 2 Gradio Chat Interface")
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="服务器主机地址"
)
parser.add_argument(
"--port",
type=int,
default=7860,
help="服务器端口"
)
parser.add_argument(
"--model",
type=str,
default=DEFAULT_MODEL_PATH,
help="模型名称或本地路径"
)
parser.add_argument(
"--tensor-parallel-size",
type=int,
default=DEFAULT_TP,
help="张量并行数量"
)
parser.add_argument(
"--max-model-len",
type=int,
default=DEFAULT_MAX_MODEL_LEN,
help="最大上下文长度"
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=DEFAULT_GPU_UTIL,
help="GPU 显存利用率"
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default=DEFAULT_TOKENIZER_MODE,
help="tokenizer 模式"
)
parser.add_argument(
"--served-model-name",
type=str,
default=DEFAULT_SERVED_NAME,
help="对外暴露的模型名称"
)
parser.add_argument(
"--no-vllm",
action="store_true",
help="禁用 vLLM,仅启动界面预览模式"
)
args = parser.parse_args()
    # --no-vllm was already handled at the top of the file; this only prints a notice
    if args.no_vllm and not ENABLE_VLLM:
        print("[INFO] vLLM is disabled; running in UI preview mode")
_set_llm_args(
model=args.model,
trust_remote_code=True,
tensor_parallel_size=args.tensor_parallel_size,
tokenizer_mode=args.tokenizer_mode,
max_model_len=args.max_model_len,
served_model_name=args.served_model_name,
gpu_memory_utilization=args.gpu_memory_utilization,
)
print("==========================================")
print("Step Audio 2 Gradio Chat")
if ENABLE_VLLM:
print(f"模式: vLLM 推理模式")
print(f"模型: {args.model}")
print(f"Tensor Parallel Size: {args.tensor_parallel_size}")
print(f"Max Model Len: {args.max_model_len}")
print(f"Tokenizer Mode: {args.tokenizer_mode}")
print(f"Served Model Name: {args.served_model_name}")
else:
print(f"模式: 界面预览模式(无 vLLM)")
print(f"Gradio 地址: http://{args.host}:{args.port}")
print("==========================================")
demo.queue().launch(
server_name=args.host,
server_port=args.port,
share=False
)