# modules/ai_model.py
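"""Gemma-3n model wrapper for the Travel Assistant: handles model loading, input-type detection (text/image/audio), and inference."""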
import torch
import base64
import requests
from io import BytesIO
import os
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
from utils.logger import log
from typing import Union, Tuple
class AIModel:
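    """Wrapper around the google/gemma-3n-e2b-it multimodal model used by the travel assistant."""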
def __init__(self, model_name: str = "google/gemma-3n-e2b-it"):
self.model_name = model_name
self.model = None
self.processor = None
        # Set up cache directories before loading the model
self._setup_cache_dirs()
self._initialize_model()
def _setup_cache_dirs(self):
"""设置缓存目录"""
cache_dir = "/app/.cache/huggingface"
os.makedirs(cache_dir, exist_ok=True)
        # Point the HuggingFace environment variables at the cache directory
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HF_DATASETS_CACHE"] = cache_dir
log.info(f"设置缓存目录: {cache_dir}")
    def _authenticate_hf(self):
        """Authenticate with the HuggingFace Hub using the token from the environment."""
        # The env var name intentionally matches the (misspelled) secret name configured for this deployment.
        token = os.getenv("Assitant_tocken")
        if not token:
            log.error("❌ No HuggingFace token found in environment variable 'Assitant_tocken'")
            return None
        login(token=token, add_to_git_credential=False)
        log.info("✅ HuggingFace authentication succeeded")
        return token
def _initialize_model(self):
"""初始化Gemma模型"""
try:
log.info(f"正在加载模型: {self.model_name}")
token = self._authenticate_hf()
if not token:
log.error("❌ 无法获取有效token,模型加载失败")
self.model = None
self.processor = None
return
cache_dir = "/app/.cache/huggingface"
self.model = Gemma3nForConditionalGeneration.from_pretrained(
self.model_name,
device_map="auto",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
token=token,
cache_dir=cache_dir
).eval()
self.processor = AutoProcessor.from_pretrained(
self.model_name,
trust_remote_code=True,
token=token,
cache_dir=cache_dir
)
log.info("✅ Gemma AI 模型初始化成功")
except Exception as e:
log.error(f"❌ Gemma AI 模型初始化失败: {e}", exc_info=True)
self.model = None
self.processor = None
def is_available(self) -> bool:
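        """Return True only when both the model and the processor loaded successfully."""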
return self.model is not None and self.processor is not None
def detect_input_type(self, input_data: str) -> str:
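        """Heuristically classify a raw input string as 'image', 'audio', or 'text'."""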
if not isinstance(input_data, str):
return "text"
image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
if (input_data.startswith(("http://", "https://")) and
any(input_data.lower().endswith(ext) for ext in image_extensions)):
return "image"
        elif any(input_data.lower().endswith(ext) for ext in image_extensions):
return "image"
elif input_data.startswith("data:image/"):
return "image"
audio_extensions = [".wav", ".mp3", ".m4a", ".ogg", ".flac"]
if (input_data.startswith(("http://", "https://")) and
any(input_data.lower().endswith(ext) for ext in audio_extensions)):
return "audio"
        elif any(input_data.lower().endswith(ext) for ext in audio_extensions):
return "audio"
return "text"
    def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[Image.Image, None], str]:
        """Normalize the raw input and return (input_type, image_or_None, prompt_text)."""
if input_type == "image":
try:
if raw_input.startswith("data:image/"):
header, encoded = raw_input.split(",", 1)
image_data = base64.b64decode(encoded)
image = Image.open(BytesIO(image_data)).convert("RGB")
elif raw_input.startswith(("http://", "https://")):
response = requests.get(raw_input, timeout=10)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")
else:
image = Image.open(raw_input).convert("RGB")
log.info("✅ 图片加载成功")
return input_type, image, "请描述这张图片,并基于图片内容提供旅游建议。"
except Exception as e:
log.error(f"❌ 图片加载失败: {e}")
return "text", None, f"图片加载失败,请检查路径或URL。"
elif input_type == "audio":
log.warning("⚠️ 音频处理功能暂未实现")
return "text", None, "抱歉,音频输入功能正在开发中。请使用文字描述您的需求。"
else: # text
return input_type, None, raw_input
    def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image, None], prompt: str, temperature: float = 0.5) -> str:
        """Run one generation pass for text or image input and return the decoded reply."""
        try:
            # Truncate overly long prompts
            if len(prompt) > 500:
                prompt = prompt[:500] + "..."
            # Prepare the inputs (image or plain text)
if input_type == "image" and isinstance(formatted_input, Image.Image):
image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
if image_token not in prompt:
prompt = f"{image_token}\n{prompt}"
inputs = self.processor(
text=prompt,
images=formatted_input,
return_tensors="pt"
).to(self.model.device, dtype=torch.bfloat16)
else:
inputs = self.processor(
text=prompt,
return_tensors="pt"
).to(self.model.device, dtype=torch.bfloat16)
            # Hard-truncate tokenized inputs that are still too long; item assignment updates the
            # underlying dict so the truncated tensors are actually passed to generate()
            if "input_ids" in inputs and inputs["input_ids"].shape[-1] > 512:
                log.warning(f"⚠️ Truncating overly long input: {inputs['input_ids'].shape[-1]} -> 512")
                inputs["input_ids"] = inputs["input_ids"][:, :512]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][:, :512]
with torch.inference_mode():
generation_args = {
"max_new_tokens": 512,
"pad_token_id": self.processor.tokenizer.eos_token_id,
"use_cache": True
}
                # If temperature is near zero, use greedy decoding (for deterministic tasks such as classification)
                if temperature < 1e-6:
                    log.info("▶️ Using greedy decoding (do_sample=False) for deterministic output.")
                    generation_args["do_sample"] = False
                # Otherwise, use sampling (for creative generation tasks)
                else:
                    log.info(f"▶️ Using sampling (do_sample=True), temperature={temperature}.")
                    generation_args["do_sample"] = True
                    generation_args["temperature"] = temperature
                    generation_args["top_p"] = 0.9  # top_p only matters when sampling
                # Call generate with the assembled argument dict
outputs = self.model.generate(
**inputs,
**generation_args
)
input_length = inputs.input_ids.shape[-1]
generated_tokens = outputs[0][input_length:]
decoded = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
            return decoded if decoded else "I understood your question; please give me a few more specific details."
        except RuntimeError as e:
            if "shape" in str(e):
                log.error(f"❌ Tensor shape error: {e}")
                return "There was a problem processing your input; please try simplifying your question."
            raise
        except Exception as e:
            log.error(f"❌ Model inference failed: {e}", exc_info=True)
            return "Sorry, a technical problem occurred while processing your request."
def chat_completion(self, model: str, messages: list, **kwargs) -> str:
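        """OpenAI-style chat-completion adapter: flattens `messages` into one prompt and delegates to run_inference."""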
if not self.is_available():
log.error("模型未就绪,无法执行 chat_completion")
if kwargs.get("response_format", {}).get("type") == "json_object":
return '{"error": "Model not available"}'
return "抱歉,AI 模型当前不可用。"
full_prompt = "\n".join([msg.get("content", "") for msg in messages])
temperature = kwargs.get("temperature", 0.7)
if kwargs.get("response_format", {}).get("type") == "json_object":
            # Append an instruction to the prompt forcing the model to output JSON
            full_prompt += "\n\nNote: your answer must be a strict JSON object, with no extra explanation and no code-block markers."
            # For JSON generation, use a low temperature for a more stable, deterministic structure
            temperature = 0.1
log.debug(f"▶️ 执行 chat_completion (适配器), temperature={temperature}, prompt='{full_prompt[:100]}...'")
return self.run_inference(
input_type="text",
formatted_input=None,
prompt=full_prompt,
            temperature=temperature  # pass the adjusted temperature through
)
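
    # Hypothetical usage of the adapter above (the messages follow the OpenAI chat format;
    # the model name is ignored and kept only for interface compatibility):
    #   ai.chat_completion(
    #       model="gemma-3n",
    #       messages=[{"role": "user", "content": "Recommend three museums in Paris"}],
    #       response_format={"type": "json_object"},
    #   )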
def _build_limited_prompt(self, processed_text: str, context: str = "") -> str:
"""构建长度受限的prompt - 新增辅助方法"""
# 限制输入长度
if len(processed_text) > 200:
processed_text = processed_text[:200] + "..."
if context and len(context) > 300:
context = context[:300] + "..."
# 保持你原有的prompt结构
if context:
return (
f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
f"--- 背景信息 ---\n{context}\n\n"
f"--- 用户问题 ---\n{processed_text}\n\n"
f"请提供专业、实用的旅游建议:"
)
else:
return (
f"你是一个专业的旅游助手。请用中文友好地回答用户的问题。\n\n"
f"用户问题:{processed_text}\n\n"
f"请提供专业、实用的旅游建议:"
)
def generate(self, user_input: str, context: str = "") -> str:
"""主要的生成方法 - 保持原有逻辑"""
if not self.is_available():
return "抱歉,AI 模型当前不可用,请稍后再试。"
try:
# 1. 检测输入类型
input_type = self.detect_input_type(user_input)
log.info(f"检测到输入类型: {input_type}")
# 2. 格式化输入
input_type, formatted_data, processed_text = self.format_input(input_type, user_input)
# 3. 构建prompt - 使用你的原有结构
prompt = self._build_limited_prompt(processed_text, context)
# 4. 执行推理
if input_type == "image" and formatted_data is not None:
return self.run_inference("image", formatted_data, prompt)
else:
return self.run_inference("text", processed_text, prompt)
except Exception as e:
log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True)
return "抱歉,我在思考时遇到了点麻烦,请稍后再试。"