# modules/ai_model.py
import base64
from io import BytesIO
from typing import Optional, Tuple, Union

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

from utils.logger import log

IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")
AUDIO_EXTS = (".wav", ".mp3", ".m4a", ".ogg")


class AIModel:
    def __init__(self, model_name: str = "google/gemma-3n-e2b-it"):
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the Gemma model, following the official loading recipe."""
        try:
            log.info(f"Loading model: {self.model_name}")
            self.model = Gemma3nForConditionalGeneration.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            ).eval()
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True,
            )
            log.info("✅ Gemma AI model initialized successfully")
        except Exception as e:
            log.error(f"❌ Gemma AI model initialization failed: {e}", exc_info=True)
            self.model = None
            self.processor = None

    def is_available(self) -> bool:
        """Check whether the model and processor are ready for inference."""
        return self.model is not None and self.processor is not None

    def detect_input_type(self, input_data: str) -> str:
        """Classify the input as "image", "audio", or "text".

        URLs, local paths, and base64 data URIs are recognized purely by
        prefix/extension; the extension check is case-insensitive.
        """
        if isinstance(input_data, str):
            lowered = input_data.lower()
            # base64-encoded image (data URI)
            if lowered.startswith("data:image/"):
                return "image"
            # image URL or local path
            if lowered.endswith(IMAGE_EXTS):
                return "image"
            # audio URL or local path
            if lowered.endswith(AUDIO_EXTS):
                return "audio"
        return "text"
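
    # Illustrative classification results (assuming the extension lists above;
    # the URLs and paths are placeholders, not real resources):
    #   detect_input_type("https://example.com/photo.JPG")     -> "image"
    #   detect_input_type("data:image/png;base64,iVBORw0...")  -> "image"
    #   detect_input_type("notes/voice_memo.m4a")              -> "audio"
    #   detect_input_type("Recommend some sights in Kyoto")    -> "text"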

    def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Optional[Image.Image], str]:
        """Normalize the raw input into (input_type, image_or_None, prompt_text)."""
        formatted_data = None
        processed_text = raw_input

        if input_type == "image":
            try:
                if raw_input.startswith("data:image/"):
                    # base64-encoded image (data URI)
                    _header, encoded = raw_input.split(",", 1)
                    image_data = base64.b64decode(encoded)
                    image = Image.open(BytesIO(image_data)).convert("RGB")
                elif raw_input.startswith(("http://", "https://")):
                    # image URL
                    response = requests.get(raw_input, timeout=10)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content)).convert("RGB")
                else:
                    # local image path
                    image = Image.open(raw_input).convert("RGB")
                formatted_data = image
                processed_text = (
                    "Please describe this image and offer travel suggestions "
                    "based on its content."
                )
                log.info("✅ Image loaded successfully")
            except Exception as e:
                log.error(f"❌ Failed to load image: {e}")
                # Degrade to a text turn; always return a 3-tuple so callers
                # can unpack (input_type, formatted_data, processed_text).
                return "text", None, (
                    "Failed to load the image; please check the path or URL. "
                    f"Original input: {raw_input}"
                )

        elif input_type == "audio":
            # Audio handling is not implemented yet; fall back to a text reply.
            log.warning("⚠️ Audio input is not implemented yet")
            processed_text = (
                "Sorry, audio input is still under development. "
                "Please describe your request in text."
            )

        elif input_type == "text":
            # Plain text is used as-is.
            formatted_data = None
            processed_text = raw_input

        return input_type, formatted_data, processed_text

    def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str) -> str:
        """Run model inference on the formatted input."""
        try:
            if input_type == "image" and isinstance(formatted_input, Image.Image):
                # Prepend the image placeholder token if the prompt lacks one.
                # Not every tokenizer exposes `image_token`, so fall back to the
                # processor attribute and finally to an empty string.
                image_token = (
                    getattr(self.processor.tokenizer, "image_token", None)
                    or getattr(self.processor, "image_token", "")
                )
                if image_token and image_token not in prompt:
                    prompt = f"{image_token}\n{prompt}"
                inputs = self.processor(
                    text=prompt,
                    images=formatted_input,
                    return_tensors="pt",
                ).to(self.model.device, dtype=torch.bfloat16)
            else:
                # Plain-text input
                inputs = self.processor(
                    text=prompt,
                    return_tensors="pt",
                ).to(self.model.device, dtype=torch.bfloat16)

            # Generate a response
            with torch.inference_mode():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens; slicing by input length
            # is more reliable than string-replacing the prompt out of the output.
            input_len = inputs["input_ids"].shape[-1]
            decoded = self.processor.tokenizer.decode(
                outputs[0][input_len:], skip_special_tokens=True
            ).strip()
            return decoded
        except Exception as e:
            log.error(f"❌ Model inference failed: {e}", exc_info=True)
            return "Sorry, I hit a technical problem while handling your request. Please try again later."

    def generate(self, user_input: str, context: str = "") -> str:
        """Main generation entry point; supports multimodal input."""
        if not self.is_available():
            return "Sorry, the AI model is currently unavailable. Please try again later."

        try:
            # 1. Detect the input type
            input_type = self.detect_input_type(user_input)
            log.info(f"Detected input type: {input_type}")

            # 2. Normalize the input
            input_type, formatted_data, processed_text = self.format_input(input_type, user_input)

            # 3. Build the prompt
            if context:
                prompt = (
                    "You are a professional travel assistant. Using the background "
                    "information below, answer the user's question in Chinese in a friendly tone.\n\n"
                    f"--- Background ---\n{context}\n\n"
                    f"--- User question ---\n{processed_text}\n\n"
                    "Please give professional, practical travel advice:"
                )
            else:
                prompt = (
                    "You are a professional travel assistant. Answer the user's "
                    "question in Chinese in a friendly tone.\n\n"
                    f"User question: {processed_text}\n\n"
                    "Please give professional, practical travel advice:"
                )

            # 4. Run inference
            if input_type == "image" and formatted_data is not None:
                return self.run_inference("image", formatted_data, prompt)
            return self.run_inference("text", processed_text, prompt)
        except Exception as e:
            log.error(f"❌ Error while generating a reply: {e}", exc_info=True)
            return "Sorry, I ran into some trouble while thinking. Please try again later."
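

# --- Usage sketch (illustrative, not part of the module) ---------------------
# A minimal smoke test showing the intended call patterns, assuming the
# `utils.logger` package is importable and the environment can load the
# Gemma checkpoint. The file path and questions below are placeholders.
if __name__ == "__main__":
    ai = AIModel()
    if ai.is_available():
        # Plain-text question
        print(ai.generate("Recommend a three-day itinerary for Kyoto."))
        # Image input by local path or URL (placeholder path; a load failure
        # degrades gracefully to a text reply)
        print(ai.generate("photos/temple.jpg"))
        # Optionally ground the answer with retrieved context
        print(ai.generate(
            "Which season is best to visit?",
            context="Kyoto is famous for its autumn foliage and spring cherry blossoms.",
        ))
    else:
        log.error("Model failed to initialize; see the logs above.")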