# modules/ai_model.py
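"""Gemma 3n wrapper used by the Space: loads the model from the Hugging Face Hub,
classifies user input (text / image / audio), and exposes generate() plus an
OpenAI-style chat_completion() adapter."""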
import torch
import base64
import requests
from io import BytesIO
import os
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
from utils.logger import log
from typing import Union, Tuple

class AIModel:
    def __init__(self, model_name: str = "google/gemma-3n-e2b-it"):
        self.model_name = model_name
        self.model = None
        self.processor = None
        # Set up the cache directories before loading the model
        self._setup_cache_dirs()
        self._initialize_model()

    def _setup_cache_dirs(self):
        """Set up the Hugging Face cache directories."""
        cache_dir = "/app/.cache/huggingface"
        os.makedirs(cache_dir, exist_ok=True)
        # Point the cache-related environment variables at the writable directory
        os.environ["HF_HOME"] = cache_dir
        os.environ["TRANSFORMERS_CACHE"] = cache_dir
        os.environ["HF_DATASETS_CACHE"] = cache_dir
        log.info(f"Cache directory set to: {cache_dir}")

    def _authenticate_hf(self):
        """Log in to the Hugging Face Hub with the token stored in the Space secrets."""
        # The secret is stored under the (misspelled) key "Assitant_tocken";
        # the key is kept as-is so it matches the configured environment variable.
        assistant_token = os.getenv("Assitant_tocken")
        if not assistant_token:
            log.error("❌ HuggingFace token not found in the environment")
            return None
        login(token=assistant_token, add_to_git_credential=False)
        log.info("✅ HuggingFace authentication succeeded")
        return assistant_token

    def _initialize_model(self):
        """Initialize the Gemma model and its processor."""
        try:
            log.info(f"Loading model: {self.model_name}")
            token = self._authenticate_hf()
            if not token:
                log.error("❌ No valid token available; model loading aborted")
                self.model = None
                self.processor = None
                return
            cache_dir = "/app/.cache/huggingface"
            self.model = Gemma3nForConditionalGeneration.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                token=token,
                cache_dir=cache_dir
            ).eval()
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                token=token,
                cache_dir=cache_dir
            )
            log.info("✅ Gemma AI model initialized successfully")
        except Exception as e:
            log.error(f"❌ Gemma AI model initialization failed: {e}", exc_info=True)
            self.model = None
            self.processor = None

    def is_available(self) -> bool:
        return self.model is not None and self.processor is not None

    def detect_input_type(self, input_data: str) -> str:
        """Classify the raw input as image, audio, or text using simple heuristics."""
        if not isinstance(input_data, str):
            return "text"
        image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
        if (input_data.startswith(("http://", "https://")) and
                any(input_data.lower().endswith(ext) for ext in image_extensions)):
            return "image"
        elif any(input_data.lower().endswith(ext) for ext in image_extensions):
            return "image"
        elif input_data.startswith("data:image/"):
            return "image"
        audio_extensions = [".wav", ".mp3", ".m4a", ".ogg", ".flac"]
        if (input_data.startswith(("http://", "https://")) and
                any(input_data.lower().endswith(ext) for ext in audio_extensions)):
            return "audio"
        elif any(input_data.lower().endswith(ext) for ext in audio_extensions):
            return "audio"
        return "text"

    def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[Image.Image, None], str]:
        """Normalize the raw input into (input_type, image_or_None, prompt_text)."""
        if input_type == "image":
            try:
                if raw_input.startswith("data:image/"):
                    # Base64-encoded data URL
                    header, encoded = raw_input.split(",", 1)
                    image_data = base64.b64decode(encoded)
                    image = Image.open(BytesIO(image_data)).convert("RGB")
                elif raw_input.startswith(("http://", "https://")):
                    # Remote image URL
                    response = requests.get(raw_input, timeout=10)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content)).convert("RGB")
                else:
                    # Local file path
                    image = Image.open(raw_input).convert("RGB")
                log.info("✅ Image loaded successfully")
                return input_type, image, "请描述这张图片,并基于图片内容提供旅游建议。"
            except Exception as e:
                log.error(f"❌ Failed to load image: {e}")
                return "text", None, "图片加载失败,请检查路径或URL。"
        elif input_type == "audio":
            log.warning("⚠️ Audio processing is not implemented yet")
            return "text", None, "抱歉,音频输入功能正在开发中。请使用文字描述您的需求。"
        else:  # text
            return input_type, None, raw_input

    def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image, None], prompt: str, temperature: float = 0.5) -> str:
        """Run a single generation pass for text-only or image+text input."""
        try:
            # Truncate overly long prompts
            if len(prompt) > 500:
                prompt = prompt[:500] + "..."
            # Prepare the inputs (image + text, or text only)
            if input_type == "image" and isinstance(formatted_input, Image.Image):
                # Prepend the image placeholder token if the prompt does not already contain one
                image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
                if image_token not in prompt:
                    prompt = f"{image_token}\n{prompt}"
                inputs = self.processor(
                    text=prompt,
                    images=formatted_input,
                    return_tensors="pt"
                ).to(self.model.device, dtype=torch.bfloat16)
            else:
                inputs = self.processor(
                    text=prompt,
                    return_tensors="pt"
                ).to(self.model.device)
            # Guard against overly long token sequences. Dict-style assignment is used so the
            # underlying BatchFeature data (what generate() actually receives) is updated.
            if "input_ids" in inputs and inputs["input_ids"].shape[-1] > 512:
                log.warning(f"⚠️ Truncating long input: {inputs['input_ids'].shape[-1]} -> 512")
                inputs["input_ids"] = inputs["input_ids"][:, :512]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][:, :512]
            with torch.inference_mode():
                generation_args = {
                    "max_new_tokens": 512,
                    "pad_token_id": self.processor.tokenizer.eos_token_id,
                    "use_cache": True
                }
                # A near-zero temperature means greedy decoding (for deterministic tasks such as classification)
                if temperature < 1e-6:
                    log.info("▶️ Using greedy decoding (do_sample=False) for deterministic output.")
                    generation_args["do_sample"] = False
                # Otherwise sample (for creative generation tasks)
                else:
                    log.info(f"▶️ Using sampling (do_sample=True), temperature={temperature}.")
                    generation_args["do_sample"] = True
                    generation_args["temperature"] = temperature
                    generation_args["top_p"] = 0.9  # top_p only matters when sampling
                # Call generate with the assembled argument dict
                outputs = self.model.generate(
                    **inputs,
                    **generation_args
                )
            # Strip the prompt tokens and decode only the newly generated part
            input_length = inputs["input_ids"].shape[-1]
            generated_tokens = outputs[0][input_length:]
            decoded = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
            return decoded if decoded else "我理解了您的问题,请告诉我更多具体信息。"
        except RuntimeError as e:
            if "shape" in str(e):
                log.error(f"❌ Tensor shape error: {e}")
                return "输入处理遇到问题,请尝试简化您的问题。"
            raise
        except Exception as e:
            log.error(f"❌ Model inference failed: {e}", exc_info=True)
            return "抱歉,处理您的请求时遇到技术问题。"

    def chat_completion(self, model: str, messages: list, **kwargs) -> str:
        """OpenAI-style adapter: flatten the chat messages into one prompt and run local inference."""
        if not self.is_available():
            log.error("Model not ready; cannot run chat_completion")
            if kwargs.get("response_format", {}).get("type") == "json_object":
                return '{"error": "Model not available"}'
            return "抱歉,AI 模型当前不可用。"
        full_prompt = "\n".join([msg.get("content", "") for msg in messages])
        temperature = kwargs.get("temperature", 0.7)
        if kwargs.get("response_format", {}).get("type") == "json_object":
            # Append an instruction that forces the model to output bare JSON
            full_prompt += "\n\n请注意:你的回答必须是一个严格的、不含任何额外解释和代码块标记的 JSON 对象。"
            # Use a low temperature for JSON generation to get a more stable, deterministic structure
            temperature = 0.1
        log.debug(f"▶️ Running chat_completion (adapter), temperature={temperature}, prompt='{full_prompt[:100]}...'")
        return self.run_inference(
            input_type="text",
            formatted_input=None,
            prompt=full_prompt,
            temperature=temperature  # pass the adjusted temperature downstream
        )
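
    # Illustrative call (hypothetical message list, mirroring the OpenAI chat format this adapter expects):
    #   ai.chat_completion(
    #       model="gemma-3n",
    #       messages=[{"role": "user", "content": "把这段行程整理成 JSON"}],
    #       response_format={"type": "json_object"},
    #   )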

    def _build_limited_prompt(self, processed_text: str, context: str = "") -> str:
        """Build a length-limited prompt (helper method)."""
        # Limit the input lengths
        if len(processed_text) > 200:
            processed_text = processed_text[:200] + "..."
        if context and len(context) > 300:
            context = context[:300] + "..."
        # Keep the original prompt structure
        if context:
            return (
                f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
                f"--- 背景信息 ---\n{context}\n\n"
                f"--- 用户问题 ---\n{processed_text}\n\n"
                f"请提供专业、实用的旅游建议:"
            )
        else:
            return (
                f"你是一个专业的旅游助手。请用中文友好地回答用户的问题。\n\n"
                f"用户问题:{processed_text}\n\n"
                f"请提供专业、实用的旅游建议:"
            )

    def generate(self, user_input: str, context: str = "") -> str:
        """Main generation entry point."""
        if not self.is_available():
            return "抱歉,AI 模型当前不可用,请稍后再试。"
        try:
            # 1. Detect the input type
            input_type = self.detect_input_type(user_input)
            log.info(f"Detected input type: {input_type}")
            # 2. Normalize the input
            input_type, formatted_data, processed_text = self.format_input(input_type, user_input)
            # 3. Build the prompt using the length-limited template
            prompt = self._build_limited_prompt(processed_text, context)
            # 4. Run inference
            if input_type == "image" and formatted_data is not None:
                return self.run_inference("image", formatted_data, prompt)
            else:
                return self.run_inference("text", processed_text, prompt)
        except Exception as e:
            log.error(f"❌ Error while generating a reply: {e}", exc_info=True)
            return "抱歉,我在思考时遇到了点麻烦,请稍后再试。"