#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Gradio multimodal chat interface: runs inference by calling vLLM.LLM directly inside app.py
"""
import base64
import os
import sys
import threading
import time
from typing import Optional, Tuple

import gradio as gr

# Check the command line before importing vllm so the decision is made early.
# This lets the UI preview run even when vllm is not installed.
if "--no-vllm" in sys.argv:
    os.environ["ENABLE_VLLM"] = "false"

# Determine whether vLLM mode is enabled
ENABLE_VLLM = os.getenv("ENABLE_VLLM", "true").lower() in ("true", "1", "yes")

if ENABLE_VLLM:
    try:
        from vllm import LLM, SamplingParams
    except ImportError:
        print("[WARNING] Failed to import vllm, falling back to UI preview mode")
        print("[INFO] To use vLLM, install it first: pip install vllm")
        ENABLE_VLLM = False
        LLM = None
        SamplingParams = None
else:
    LLM = None
    SamplingParams = None
    print("[INFO] Running in UI preview mode, vLLM will not be loaded")

# Default configuration, can be overridden via environment variables or the CLI
DEFAULT_MODEL_ID = os.getenv("MODEL_NAME", "stepfun-ai/Step-Audio-2-mini-Think")
DEFAULT_MODEL_PATH = os.getenv("MODEL_PATH", DEFAULT_MODEL_ID)
DEFAULT_TP = int(os.getenv("TENSOR_PARALLEL_SIZE", "4"))
DEFAULT_MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
DEFAULT_GPU_UTIL = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.9"))
DEFAULT_TOKENIZER_MODE = os.getenv("TOKENIZER_MODE", "step_audio_2")
DEFAULT_SERVED_NAME = os.getenv("SERVED_MODEL_NAME", "step-audio-2-mini-think")

# Forward reference so the annotation stays meaningful when vllm is not installed (LLM is None then)
_llm: Optional["LLM"] = None
_llm_lock = threading.Lock()

LLM_ARGS = {
    "model": DEFAULT_MODEL_PATH,
    "trust_remote_code": True,
    "tensor_parallel_size": DEFAULT_TP,
    "tokenizer_mode": DEFAULT_TOKENIZER_MODE,
    "max_model_len": DEFAULT_MAX_MODEL_LEN,
    "served_model_name": DEFAULT_SERVED_NAME,
    "gpu_memory_utilization": DEFAULT_GPU_UTIL,
}


def encode_audio_to_base64(audio_path: Optional[str]) -> Optional[dict]:
    """Encode an audio file as base64."""
    if audio_path is None:
        return None
    try:
        with open(audio_path, "rb") as audio_file:
            audio_data = audio_file.read()
        audio_base64 = base64.b64encode(audio_data).decode('utf-8')

        # Infer the format from the file extension
        ext = os.path.splitext(audio_path)[1].lower().lstrip('.')
        if not ext:
            ext = "wav"  # default format

        return {
            "data": audio_base64,
            "format": ext,
        }
    except Exception as e:
        print(f"Error encoding audio: {e}")
        return None


def format_messages(
    system_prompt: str,
    chat_history: list,
    user_text: str,
    audio_file: Optional[str]
) -> list:
    """Format the conversation as OpenAI API style messages."""
    messages = []

    # Add the system prompt
    if system_prompt and system_prompt.strip():
        messages.append({
            "role": "system",
            "content": system_prompt.strip()
        })

    # Add the conversation history
    for human, assistant in chat_history:
        if human:
            messages.append({"role": "user", "content": human})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})

    # Add the current user input
    content_parts = []

    # Text input
    if user_text and user_text.strip():
        content_parts.append({
            "type": "text",
            "text": user_text.strip()
        })

    # Audio input
    if audio_file:
        audio_data = encode_audio_to_base64(audio_file)
        if audio_data:
            content_parts.append({
                "type": "input_audio",
                "input_audio": audio_data
            })

    if content_parts:
        # If there is only a single text part, pass it as a plain string
        if len(content_parts) == 1 and content_parts[0]["type"] == "text":
            messages.append({
                "role": "user",
                "content": content_parts[0]["text"]
            })
        else:
            messages.append({
                "role": "user",
                "content": content_parts
            })

    return messages


def chat_predict(
    system_prompt: str,
    user_text: str,
    audio_file: Optional[str],
    chat_history: list,
    max_tokens: int,
    temperature: float,
    top_p: float
) -> Tuple[list, str]:
    """Run inference with the local vLLM LLM instance."""
    if not user_text and not audio_file:
        return chat_history, "⚠ Please provide text or audio input"

    # In preview mode, return a mock response
    if not ENABLE_VLLM:
        user_display = user_text if user_text else "[audio input]"
        mock_response = f"This is a mock reply. You said: {user_text[:50] if user_text else 'audio'}"
        chat_history.append((user_display, mock_response))
        return chat_history, ""

    messages = format_messages(system_prompt, chat_history, user_text, audio_file)
    if not messages:
        return chat_history, "⚠ No valid input"

    try:
        llm = _get_llm()
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )

        start_time = time.time()
        outputs = llm.chat(messages, sampling_params=sampling_params, use_tqdm=False)
        latency = time.time() - start_time
        print(f"[LLM] Inference finished in {latency:.2f}s")

        if not outputs or not outputs[0].outputs:
            return chat_history, "⚠ The model returned no output"

        assistant_message = outputs[0].outputs[0].text

        user_display = user_text if user_text else "[audio input]"
        chat_history.append((user_display, assistant_message))

        return chat_history, ""
    except Exception as e:
        import traceback
        traceback.print_exc()
        return chat_history, f"⚠ Inference failed: {e}"


def _get_llm() -> "LLM":
    """Initialize the LLM lazily as a singleton (double-checked locking)."""
    if not ENABLE_VLLM:
        raise RuntimeError("vLLM is not enabled; the model cannot be loaded")
    global _llm
    if _llm is not None:
        return _llm
    with _llm_lock:
        if _llm is not None:
            return _llm
        print(f"[LLM] Initializing with args: {LLM_ARGS}")
        _llm = LLM(**LLM_ARGS)
    return _llm


def _set_llm_args(**kwargs) -> None:
    """Update the LLM initialization arguments."""
    global LLM_ARGS, _llm
    LLM_ARGS = kwargs
    _llm = None  # force a reload with the new configuration


# Build the Gradio interface
with gr.Blocks(title="Step Audio 2 Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Step Audio R1 Demo
        """
    )

    with gr.Row():
        # Left column: configuration
        with gr.Column(scale=1):
            gr.Markdown("### Configuration")

            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="Enter a system prompt...",
                lines=4,
                value="You are an expert in audio analysis, please analyze the audio content and answer the questions accurately"
            )

            with gr.Row():
                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=1,
                    maximum=16384,
                    value=8192,
                    step=1
                )

            with gr.Row():
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.05
                )

        # Right column: conversation and inputs
        with gr.Column(scale=1):
            gr.Markdown("### Conversation")

            chatbot = gr.Chatbot(
                label="Chat history",
                height=400,
                show_copy_button=True
            )

            user_text = gr.Textbox(
                label="Text input",
                placeholder="Type your message...",
                lines=2
            )

            audio_file = gr.Audio(
                label="Audio input",
                type="filepath",
                sources=["microphone", "upload"]
            )

            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary", size="lg")
                clear_btn = gr.Button("Clear", variant="secondary")

            status_text = gr.Textbox(label="Status", interactive=False, visible=False)

    # Event bindings
    submit_btn.click(
        fn=chat_predict,
        inputs=[
            system_prompt,
            user_text,
            audio_file,
            chatbot,
            max_tokens,
            temperature,
            top_p
        ],
        outputs=[chatbot, status_text]
    )

    clear_btn.click(
        fn=lambda: ([], "", None),
        outputs=[chatbot, user_text, audio_file]
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Step Audio 2 Gradio Chat Interface")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Server host address"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Server port"
    )
    parser.add_argument(
        "--model",
        type=str,
        default=DEFAULT_MODEL_PATH,
        help="Model name or local path"
    )
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=DEFAULT_TP,
        help="Tensor parallel size"
    )
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=DEFAULT_MAX_MODEL_LEN,
        help="Maximum context length"
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=DEFAULT_GPU_UTIL,
        help="GPU memory utilization"
    )
    parser.add_argument(
        "--tokenizer-mode",
        type=str,
        default=DEFAULT_TOKENIZER_MODE,
        help="Tokenizer mode"
    )
    parser.add_argument(
        "--served-model-name",
        type=str,
        default=DEFAULT_SERVED_NAME,
        help="Externally exposed model name"
    )
    parser.add_argument(
        "--no-vllm",
        action="store_true",
        help="Disable vLLM and launch the UI preview mode only"
    )

    args = parser.parse_args()

    # --no-vllm was already handled at the top of the file; this is only a notice
    if args.no_vllm and not ENABLE_VLLM:
        print("[INFO] vLLM disabled, running in UI preview mode")

    _set_llm_args(
        model=args.model,
        trust_remote_code=True,
        tensor_parallel_size=args.tensor_parallel_size,
        tokenizer_mode=args.tokenizer_mode,
        max_model_len=args.max_model_len,
        served_model_name=args.served_model_name,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    print("==========================================")
    print("Step Audio 2 Gradio Chat")
    if ENABLE_VLLM:
        print("Mode: vLLM inference")
        print(f"Model: {args.model}")
        print(f"Tensor Parallel Size: {args.tensor_parallel_size}")
        print(f"Max Model Len: {args.max_model_len}")
        print(f"Tokenizer Mode: {args.tokenizer_mode}")
        print(f"Served Model Name: {args.served_model_name}")
    else:
        print("Mode: UI preview (no vLLM)")
    print(f"Gradio URL: http://{args.host}:{args.port}")
    print("==========================================")

    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=False
    )
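
# ---------------------------------------------------------------------------
# Example invocations (a sketch; the model path and GPU count below are
# illustrative assumptions, adjust them to your environment):
#
#   # UI preview only, no GPU and no vllm install required
#   python app.py --no-vllm --port 7860
#
#   # vLLM inference with a local checkpoint on a 4-GPU host
#   python app.py --model /path/to/Step-Audio-2-mini-Think \
#       --tensor-parallel-size 4 --max-model-len 8192 \
#       --gpu-memory-utilization 0.9
#
# The environment variables MODEL_NAME/MODEL_PATH, TENSOR_PARALLEL_SIZE,
# MAX_MODEL_LEN, GPU_MEMORY_UTILIZATION, TOKENIZER_MODE, SERVED_MODEL_NAME
# and ENABLE_VLLM provide the same defaults when the flags are omitted.
# ---------------------------------------------------------------------------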