import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# -------------------------------
# Model loading
# -------------------------------
MODEL_ID = "caobin/llm-caobin"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # automatically maps devices; falls back to CPU if no GPU
    trust_remote_code=True
)

# -------------------------------
# Helper: clean chat history
# -------------------------------
def clean_history(history):
    """
    Convert each history message's content to a string; list-valued
    content would otherwise lead to empty answers.
    """
    cleaned = []
    for msg in history:
        content = msg["content"]
        if isinstance(content, list):
            # list -> str
            content = " ".join(str(c) for c in content)
        cleaned.append({"role": msg["role"], "content": content})
    return cleaned

# -------------------------------
# Chat function
# -------------------------------
def chat_fn(message, history):
    history = clean_history(history)
    recent_history = history[-6:]  # keep the last 3 turns (6 messages)

    # Rebuild the prompt from recent history using the model's chat markers
    full_prompt = ""
    for msg in recent_history:
        if msg["role"] == "user":
            full_prompt += f"<|user|>{msg['content']}<|assistant|>"
        elif msg["role"] == "assistant":
            full_prompt += msg["content"]

    # Current user question
    full_prompt += f"<|user|>{message}<|assistant|>"

    # tokenizer -> tensor
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    # Generate the answer
    output_ids = model.generate(
        **inputs,
        max_new_tokens=128,
        temperature=0.3,
        top_p=0.3,
        do_sample=True,
    )
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Keep only the text after the last assistant marker
    if "<|assistant|>" in output_text:
        output_text = output_text.split("<|assistant|>")[-1]
    return output_text.strip()

# -------------------------------
# Gradio UI
# -------------------------------
with gr.Blocks(title="caobin LLM Chatbot") as demo:
    gr.Markdown("# 🤖 caobin's AI assistant")
    # type="messages" so the chatbot accepts the dict-style {"role", "content"} history used below
    chatbot = gr.Chatbot(height=450, type="messages")
    msg = gr.Textbox(label="Enter your question")

    def respond(message, chat_history):
        response = chat_fn(message, chat_history)
        # Append messages in dict format
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": response})
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

# -------------------------------
# Launch
# -------------------------------
demo.launch()