File size: 11,916 Bytes
af60cba
 
 
 
 
a42280a
 
af60cba
 
 
 
 
 
 
 
 
 
a4da6d3
 
 
af60cba
 
8bacbbf
 
 
 
f30e96e
8bacbbf
 
 
 
f30e96e
8bacbbf
 
 
82e8be7
 
 
 
 
 
 
3154fce
82e8be7
 
a42280a
af60cba
a4da6d3
af60cba
 
a4da6d3
 
 
 
 
 
 
 
 
8bacbbf
 
af60cba
 
 
 
a4da6d3
 
2bef76a
af60cba
8bacbbf
af60cba
 
a42280a
a4da6d3
2bef76a
af60cba
8bacbbf
af60cba
8bacbbf
af60cba
 
 
 
 
 
6c0d50f
af60cba
 
 
6c0d50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af60cba
 
 
6c0d50f
af60cba
 
 
 
 
 
 
 
 
 
 
 
6c0d50f
 
af60cba
6c0d50f
 
af60cba
 
6c0d50f
 
af60cba
 
6c0d50f
 
af60cba
6c0d50f
 
af60cba
7324283
6c0d50f
af60cba
96512ae
6c0d50f
 
 
96512ae
af60cba
6c0d50f
af60cba
ce08446
af60cba
6c0d50f
 
af60cba
 
 
 
6c0d50f
af60cba
 
 
ce08446
 
 
 
 
 
 
af60cba
96512ae
8d69a10
96512ae
 
 
 
 
8d69a10
96512ae
 
 
 
 
 
 
 
 
 
af60cba
6c0d50f
96512ae
af60cba
ce08446
 
8d69a10
7324283
6c0d50f
8d69a10
6c0d50f
 
 
 
 
af60cba
 
6c0d50f
3589840
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c0d50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af60cba
 
6c0d50f
af60cba
 
6c0d50f
af60cba
 
 
 
6c0d50f
af60cba
 
6c0d50f
 
 
 
af60cba
 
 
 
 
6c0d50f
af60cba
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
# modules/ai_model.py
import torch
import base64
import requests
from io import BytesIO
import os
from huggingface_hub import login
from PIL import Image
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
from utils.logger import log
from typing import Union, Tuple

class AIModel:
    def __init__(self, model_name: str = "google/gemma-3n-e2b-it"):
        self.model_name = model_name
        self.model = None
        self.processor = None
        
        # 设置缓存目录
        self._setup_cache_dirs()
        self._initialize_model()

    def _setup_cache_dirs(self):
        """设置缓存目录"""
        cache_dir = "/app/.cache/huggingface"
        os.makedirs(cache_dir, exist_ok=True)
        
        # 设置环境变量
        os.environ["HF_HOME"] = cache_dir
        os.environ["TRANSFORMERS_CACHE"] = cache_dir
        os.environ["HF_DATASETS_CACHE"] = cache_dir
        
        log.info(f"设置缓存目录: {cache_dir}")

    def _authenticate_hf(self):

        assitant_token = os.getenv("Assitant_tocken")
        token_to_use = assitant_token

        cache_dir = "/app/.cache/huggingface"
        login(token=token_to_use, add_to_git_credential=False)
        log.info("✅ HuggingFace 认证成功")
        return token_to_use



    def _initialize_model(self):
        """初始化Gemma模型"""
        try:
            log.info(f"正在加载模型: {self.model_name}")
            
            token = self._authenticate_hf()
            
            if not token:
                log.error("❌ 无法获取有效token,模型加载失败")
                self.model = None
                self.processor = None
                return
            
            cache_dir = "/app/.cache/huggingface"
            
            self.model = Gemma3nForConditionalGeneration.from_pretrained(
                self.model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                token=token,
                cache_dir=cache_dir
            ).eval()
            
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                token=token,
                cache_dir=cache_dir
            )
            
            log.info("✅ Gemma AI 模型初始化成功")
            
        except Exception as e:
            log.error(f"❌ Gemma AI 模型初始化失败: {e}", exc_info=True)
            self.model = None
            self.processor = None

    def is_available(self) -> bool:
        
        return self.model is not None and self.processor is not None

    def detect_input_type(self, input_data: str) -> str:

        if not isinstance(input_data, str):
            return "text"
    
        image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
        if (input_data.startswith(("http://", "https://")) and 
            any(input_data.lower().endswith(ext) for ext in image_extensions)):
            return "image"
        elif any(input_data.endswith(ext) for ext in image_extensions):
            return "image"
        elif input_data.startswith("data:image/"):
            return "image"

        audio_extensions = [".wav", ".mp3", ".m4a", ".ogg", ".flac"]
        if (input_data.startswith(("http://", "https://")) and 
            any(input_data.lower().endswith(ext) for ext in audio_extensions)):
            return "audio"
        elif any(input_data.endswith(ext) for ext in audio_extensions):
            return "audio"
    
        return "text"

    def format_input(self, input_type: str, raw_input: str) -> Tuple[str, Union[str, Image.Image, None]]:
    
        if input_type == "image":
            try:
                if raw_input.startswith("data:image/"):
                    header, encoded = raw_input.split(",", 1)
                    image_data = base64.b64decode(encoded)
                    image = Image.open(BytesIO(image_data)).convert("RGB")
                elif raw_input.startswith(("http://", "https://")):
                    response = requests.get(raw_input, timeout=10)
                    response.raise_for_status()
                    image = Image.open(BytesIO(response.content)).convert("RGB")
                else:
                
                    image = Image.open(raw_input).convert("RGB")
            
                log.info("✅ 图片加载成功")
                return input_type, image, "请描述这张图片,并基于图片内容提供旅游建议。"
            
            except Exception as e:
                log.error(f"❌ 图片加载失败: {e}")
                return "text", None, f"图片加载失败,请检查路径或URL。"
            
        elif input_type == "audio":
            
            log.warning("⚠️ 音频处理功能暂未实现")
            return "text", None, "抱歉,音频输入功能正在开发中。请使用文字描述您的需求。"
        
        else:  # text
            return input_type, None, raw_input

    def run_inference(self, input_type: str, formatted_input: Union[str, Image.Image], prompt: str,temperature: float = 0.5) -> str:
   
        try:
            # 截断过长的 prompt
            if len(prompt) > 500:
                prompt = prompt[:500] + "..."
            
            # 准备输入 (处理图片或文本)
            if input_type == "image" and isinstance(formatted_input, Image.Image):
                image_token = getattr(self.processor.tokenizer, 'image_token', '<image>')
                if image_token not in prompt:
                    prompt = f"{image_token}\n{prompt}"
                inputs = self.processor(
                    text=prompt,
                    images=formatted_input,
                    return_tensors="pt"
                ).to(self.model.device, dtype=torch.bfloat16)
            else:
                inputs = self.processor(
                    text=prompt,
                    return_tensors="pt"
                ).to(self.model.device, dtype=torch.bfloat16)

            if hasattr(inputs, 'input_ids') and inputs.input_ids.shape[-1] > 512:
                log.warning(f"⚠️ 截断过长输入: {inputs.input_ids.shape[-1]} -> 512")
                inputs.input_ids = inputs.input_ids[:, :512]
                if hasattr(inputs, 'attention_mask'):
                    inputs.attention_mask = inputs.attention_mask[:, :512]

            
            with torch.inference_mode():
                generation_args = {
                    "max_new_tokens": 512,
                    "pad_token_id": self.processor.tokenizer.eos_token_id,
                    "use_cache": True
                }

                # 如果 temperature 接近0,使用贪心解码 (用于分类等确定性任务)
                if temperature < 1e-6: 
                    log.info("▶️ 使用贪心解码 (do_sample=False) 以获得确定性输出。")
                    generation_args["do_sample"] = False
                # 否则,使用采样解码 (用于创造性生成任务)
                else:
                    log.info(f"▶️ 使用采样解码 (do_sample=True),temperature={temperature}。")
                    generation_args["do_sample"] = True
                    generation_args["temperature"] = temperature
                    generation_args["top_p"] = 0.9 # top_p 只在采样时有意义

                # 使用构建好的参数字典来调用 generate
                outputs = self.model.generate(
                    **inputs,
                    **generation_args
                )
            input_length = inputs.input_ids.shape[-1]
            generated_tokens = outputs[0][input_length:]
            decoded = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            return decoded if decoded else "我理解了您的问题,请告诉我更多具体信息。"

        except RuntimeError as e:
            if "shape" in str(e):
                log.error(f"❌ Tensor形状错误: {e}")
                return "输入处理遇到问题,请尝试简化您的问题。"
            raise e
        except Exception as e:
            log.error(f"❌ 模型推理失败: {e}", exc_info=True)
            return "抱歉,处理您的请求时遇到技术问题。"
        
    def chat_completion(self, model: str, messages: list, **kwargs) -> str:
            
        if not self.is_available():
            log.error("模型未就绪,无法执行 chat_completion")
            if kwargs.get("response_format", {}).get("type") == "json_object":
                return '{"error": "Model not available"}'
            return "抱歉,AI 模型当前不可用。"

        full_prompt = "\n".join([msg.get("content", "") for msg in messages])

        temperature = kwargs.get("temperature", 0.7)

        if kwargs.get("response_format", {}).get("type") == "json_object":
            # 在 prompt 末尾添加指令,强制模型输出 JSON
            full_prompt += "\n\n请注意:你的回答必须是一个严格的、不含任何额外解释和代码块标记的 JSON 对象。"
            # 对于JSON生成任务,使用较低的 temperature 以获得更稳定、确定性的结构
            temperature = 0.1 

        log.debug(f"▶️ 执行 chat_completion (适配器), temperature={temperature}, prompt='{full_prompt[:100]}...'")

        return self.run_inference(
            input_type="text",
            formatted_input=None,
            prompt=full_prompt,
            temperature=temperature # 将处理后的 temperature 传递下去
        )


    def _build_limited_prompt(self, processed_text: str, context: str = "") -> str:
        """构建长度受限的prompt - 新增辅助方法"""
        # 限制输入长度
        if len(processed_text) > 200:
            processed_text = processed_text[:200] + "..."
    
        if context and len(context) > 300:
            context = context[:300] + "..."
    
        # 保持你原有的prompt结构
        if context:
            return (
                f"你是一个专业的旅游助手。请基于以下背景信息,用中文友好地回答用户的问题。\n\n"
                f"--- 背景信息 ---\n{context}\n\n"
                f"--- 用户问题 ---\n{processed_text}\n\n"
                f"请提供专业、实用的旅游建议:"
            )
        else:
            return (
                f"你是一个专业的旅游助手。请用中文友好地回答用户的问题。\n\n"
                f"用户问题:{processed_text}\n\n"
                f"请提供专业、实用的旅游建议:"
            )

    def generate(self, user_input: str, context: str = "") -> str:
        """主要的生成方法 - 保持原有逻辑"""
        if not self.is_available():
            return "抱歉,AI 模型当前不可用,请稍后再试。"
    
        try:
            # 1. 检测输入类型
            input_type = self.detect_input_type(user_input)
            log.info(f"检测到输入类型: {input_type}")
        
            # 2. 格式化输入
            input_type, formatted_data, processed_text = self.format_input(input_type, user_input)
        
            # 3. 构建prompt - 使用你的原有结构
            prompt = self._build_limited_prompt(processed_text, context)
        
            # 4. 执行推理
            if input_type == "image" and formatted_data is not None:
                return self.run_inference("image", formatted_data, prompt)
            else:
                return self.run_inference("text", processed_text, prompt)
            
        except Exception as e:
            log.error(f"❌ 生成回复时发生错误: {e}", exc_info=True)
            return "抱歉,我在思考时遇到了点麻烦,请稍后再试。"