| | import os
|
| | import torch
|
| | import threading
|
| | import re
|
| | from typing import List, Dict, Any, Optional
|
| | from transformers import AutoModelForCausalLM, AutoTokenizer
|
| |
|
| |
|
def resolve_model_path(model_id: str) -> str:
    """Resolve *model_id* to a local model directory.

    If *model_id* already names an existing local directory it is returned
    unchanged; otherwise the model is downloaded from ModelScope.

    Args:
        model_id: A ModelScope model id or a local directory path.

    Returns:
        Path to the local model directory.
    """
    if not os.path.isdir(model_id):
        print(f"从 ModelScope 下载模型: {model_id} ...")
        # Imported lazily so the dependency is only needed when downloading.
        from modelscope import snapshot_download
        local_path = snapshot_download(model_id)
        print(f"模型已下载到: {local_path}")
        return local_path

    print(f"使用本地模型: {model_id}")
    return model_id
|
| |
|
| |
|
class VisionLanguageModel:
    """Qwen3-VL vision-language model wrapper for describing image content.

    Two execution modes are supported:
      - online API mode: DashScope OpenAI-compatible endpoint (fast, no GPU)
      - local mode: model weights loaded for local GPU/CPU inference

    When the online API quota (``api_max_calls``) is exhausted, calls are
    automatically downgraded to the local model (lazily loaded on demand).
    """

    # Default structured description prompt (user-facing runtime string,
    # kept verbatim).
    DEFAULT_PROMPT = (
        "请按以下结构如实描述这张图片,仅提取事实内容,不要做任何风险分析或价值判断:\n\n"
        "【图片文字】逐字提取图片中出现的所有文字(包括标题、正文、水印、"
        "对话气泡、标语、商标等),保持原文不做任何修改。如果没有文字请注明。\n\n"
        "【视觉内容】描述场景、人物、动作、表情、物体、符号等所有可见元素。"
        "如果包含敏感、暴力、色情等内容,请如实描述,不要回避。\n\n"
        "【内容类型】判断图片类型(如:表情包、聊天截图、广告、新闻、普通照片等)。"
    )

    def __init__(
        self,
        model_path: str = None,
        device: str = "auto",
        use_api: bool = False,
        api_base: str = None,
        api_key: str = None,
        api_model: str = None,
        load_local: bool = True,
        api_max_calls: int = 200,
    ):
        """
        Args:
            model_path: Local path or ModelScope id of the vision model;
                required for local inference and for API-quota fallback.
            device: Device map passed to ``from_pretrained``.
            use_api: Accepted for backward compatibility; mode selection is
                actually driven by ``api_base``/``api_key`` and the
                ``use_api`` argument of :meth:`describe_image`.
            api_base: Base URL of the OpenAI-compatible endpoint.
            api_key: API key for the endpoint.
            api_model: Model name used for API calls.
            load_local: Whether to eagerly load the local model (only if
                ``model_path`` is set).
            api_max_calls: Hard cap on online API calls.
        """
        self.model_path = model_path
        self.device = device
        self.model = None
        self.processor = None
        # Serializes local generation: HF models are not safe for
        # concurrent generate() calls.
        self._lock = threading.Lock()

        # Online-API usage accounting, guarded by its own lock so quota
        # checks do not contend with the (long-held) generation lock.
        self._api_call_count = 0
        self._api_max_calls = api_max_calls
        self._api_count_lock = threading.Lock()

        self.api_client = None
        self.api_model = api_model
        if api_base and api_key:
            self._init_api_client(api_base, api_key, api_model)

        self.local_loaded = False
        if load_local and model_path:
            self._load_local_model()

    def _init_api_client(self, api_base: str, api_key: str, api_model: str):
        """Initialize the DashScope OpenAI-compatible API client."""
        from openai import OpenAI
        self.api_client = OpenAI(
            api_key=api_key,
            base_url=api_base,
        )
        self.api_model = api_model
        print(f"视觉语言模型 API 已就绪: {api_base} / {api_model}")
        print(f"API 调用次数上限: {self._api_max_calls}")

    @property
    def api_call_count(self) -> int:
        """Number of API calls made so far."""
        with self._api_count_lock:
            return self._api_call_count

    @property
    def api_remaining(self) -> int:
        """Number of API calls still available (never negative)."""
        with self._api_count_lock:
            return max(0, self._api_max_calls - self._api_call_count)

    @property
    def api_limit_reached(self) -> bool:
        """Whether the API call quota has been exhausted."""
        with self._api_count_lock:
            return self._api_call_count >= self._api_max_calls

    def _increment_api_count(self):
        """Increment the API call counter (thread-safe) and warn on low quota.

        Bug fix: the hard-limit check must come first. In the original
        ordering the "limit reached" branch was unreachable, because
        ``remaining == 0`` already satisfied the ``remaining <= 10`` check.
        """
        with self._api_count_lock:
            self._api_call_count += 1
            remaining = self._api_max_calls - self._api_call_count
            if self._api_call_count >= self._api_max_calls:
                print(f"[警告] 在线 API 调用次数已达上限 ({self._api_max_calls}),后续将自动降级为本地模型")
            elif 0 <= remaining <= 10:
                print(f"[警告] 在线 API 剩余调用次数: {remaining}/{self._api_max_calls}")

    @staticmethod
    def _image_to_data_url(image_path: str) -> str:
        """Convert a local image file to a base64 data URL.

        The MIME type is inferred from the file extension; unknown
        extensions fall back to ``image/png``.
        """
        import base64
        with open(image_path, "rb") as f:
            data = base64.b64encode(f.read()).decode()
        ext = os.path.splitext(image_path)[1].lower()
        mime_map = {
            ".jpg": "image/jpeg", ".jpeg": "image/jpeg",
            ".png": "image/png", ".gif": "image/gif",
            ".webp": "image/webp", ".bmp": "image/bmp",
        }
        mime = mime_map.get(ext, "image/png")
        return f"data:{mime};base64,{data}"

    def _describe_image_api(self, image_path: str, prompt: str) -> str:
        """Generate an image description through the online API.

        Raises:
            RuntimeError: If the API client was never configured.
        """
        if self.api_client is None:
            raise RuntimeError("在线 API 未配置,请检查 vl_api_base / vl_api_key 设置")

        data_url = self._image_to_data_url(image_path)

        response = self.api_client.chat.completions.create(
            model=self.api_model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": data_url}},
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
            max_tokens=512,
        )
        return response.choices[0].message.content

    def _load_local_model(self):
        """Load the local Qwen3-VL model and its processor."""
        from transformers import Qwen3VLForConditionalGeneration

        local_path = resolve_model_path(self.model_path)
        print(f"正在加载本地视觉语言模型: {local_path}...")

        self.processor = self._load_processor(local_path)
        self.model = Qwen3VLForConditionalGeneration.from_pretrained(
            local_path,
            torch_dtype="auto",
            device_map=self.device,
            trust_remote_code=True,
        ).eval()
        self.local_loaded = True
        print("本地视觉语言模型加载完成。")

    def _load_processor(self, local_path: str):
        """Load the processor with a multi-level fallback chain.

        Some transformers versions leave VIDEO_PROCESSOR_MAPPING_NAMES
        uninitialized, which makes ``AutoProcessor.from_pretrained`` raise a
        TypeError; this method works around that, and finally falls back to
        assembling the processor by hand.

        Raises:
            RuntimeError: If every loading strategy fails.
        """
        # Strategy 1: plain AutoProcessor.
        try:
            from transformers import AutoProcessor
            return AutoProcessor.from_pretrained(
                local_path,
                trust_remote_code=True,
            )
        except TypeError as e:
            # Only swallow the known video-processor incompatibility.
            if "NoneType" in str(e):
                print(f"AutoProcessor 遇到视频处理器兼容性问题: {e}")
            else:
                raise

        # Strategy 2: patch the broken mapping and retry.
        try:
            from transformers.models.auto import video_processing_auto
            if video_processing_auto.VIDEO_PROCESSOR_MAPPING_NAMES is None:
                video_processing_auto.VIDEO_PROCESSOR_MAPPING_NAMES = {}
                print("已修复 VIDEO_PROCESSOR_MAPPING_NAMES 初始化问题,重新加载...")
            from transformers import AutoProcessor
            return AutoProcessor.from_pretrained(
                local_path,
                trust_remote_code=True,
            )
        except Exception as e:
            print(f"修复后重试仍失败: {e}")

        # Strategy 3: assemble tokenizer + image processor manually.
        print("回退方案: 手动组装处理器...")
        from transformers import AutoTokenizer, AutoImageProcessor
        tokenizer = AutoTokenizer.from_pretrained(
            local_path, trust_remote_code=True
        )
        image_processor = AutoImageProcessor.from_pretrained(
            local_path, trust_remote_code=True
        )
        try:
            from transformers import Qwen3VLProcessor
            processor = Qwen3VLProcessor(
                image_processor=image_processor,
                tokenizer=tokenizer,
            )
            print("手动组装处理器成功。")
            return processor
        except Exception as e:
            # Fix: `(ImportError, Exception)` was redundant — Exception
            # already covers ImportError; also chain the original cause.
            raise RuntimeError(
                f"处理器加载失败: {e}\n"
                "请尝试: pip install -U transformers torchvision qwen-vl-utils"
            ) from e

    def _describe_image_local(self, image_path: str, prompt: str) -> str:
        """Generate an image description with the local model.

        Raises:
            RuntimeError: If the local model has not been loaded.
        """
        if not self.local_loaded:
            raise RuntimeError(
                "本地视觉模型未加载。请设置 XGUARD_VL_USE_API=false 重启,或切换为在线 API 模式。"
            )

        with self._lock:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image_path},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.model.device)

            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=False,
                )

            # Strip the prompt tokens so only the new completion is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            return output_text[0]

    def _ensure_local_model(self):
        """Ensure the local model is loaded (lazy load on API-quota exhaustion).

        Raises:
            RuntimeError: If no local model path is configured to fall back to.
        """
        if self.local_loaded:
            return
        if not self.model_path:
            raise RuntimeError(
                "在线 API 调用次数已达上限,且未配置本地模型路径 (XGUARD_VL_MODEL_PATH),"
                "无法降级到本地模型。请配置本地模型或重启服务以重置 API 计数。"
            )
        print("[自动降级] API 次数耗尽,正在加载本地视觉语言模型...")
        self._load_local_model()
        print("[自动降级] 本地视觉语言模型加载完成。")

    def describe_image(self, image_path: str, prompt: str = None, use_api: bool = None) -> str:
        """Generate an image description (unified entry point).

        Args:
            image_path: Path to the image file.
            prompt: Custom description prompt; falls back to DEFAULT_PROMPT.
            use_api: Whether to use the online API; when None, the API is
                used iff an API client is configured.

        Returns:
            The textual description of the image.

        Note:
            When ``use_api`` is True but the API quota is exhausted, the
            call transparently downgrades to the local model; callers can
            inspect ``self.api_limit_reached`` to detect this.
        """
        if not prompt:
            prompt = self.DEFAULT_PROMPT

        if use_api is None:
            use_api = self.api_client is not None

        if use_api and self.api_limit_reached:
            print(
                f"[API 限流] 在线 API 调用已达上限 "
                f"({self._api_call_count}/{self._api_max_calls}),自动降级到本地模型"
            )
            self._ensure_local_model()
            use_api = False

        if use_api:
            self._increment_api_count()
            return self._describe_image_api(image_path, prompt)
        else:
            return self._describe_image_local(image_path, prompt)
|
| |
|
class XGuardModel:
    """YuFeng-XGuard safety-detection model wrapper.

    The inference logic mirrors the official implementation:
      - ``apply_chat_template`` accepts ``policy`` / ``reason_first``
      - risk labels are matched against ``id2risk`` via decoded text
        (not via a token-id indirection)
      - in ``reason_first`` mode the score position of the risk token is
        located correctly
    """

    def __init__(self, model_path: str, device: str = "auto"):
        """
        Args:
            model_path: Local path or ModelScope id of the XGuard model.
            device: Device map passed to ``from_pretrained``.
        """
        self.model_path = model_path
        self.device = device
        self.model = None
        self.tokenizer = None
        self.id2risk = None
        # Serializes generate() calls; HF models are not thread-safe.
        self._lock = threading.Lock()
        self._load_model()

    def _load_model(self):
        """Load model and tokenizer, then extract the id2risk mapping."""
        local_path = resolve_model_path(self.model_path)

        print(f"正在加载安全检测模型: {local_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            local_path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            local_path,
            torch_dtype="auto",
            device_map=self.device,
            trust_remote_code=True
        ).eval()

        # The risk-token -> risk-name table ships in the tokenizer config.
        # (Removed a leftover "##########" debug print of the full mapping.)
        self.id2risk = self.tokenizer.init_kwargs.get('id2risk', {})
        print(f"id2risk 映射条目数: {len(self.id2risk)}")
        if self.id2risk:
            print(f"示例映射: {list(self.id2risk.items())[:5]}")

    def infer(self, messages: List[Dict[str, str]], policy=None,
              max_new_tokens: int = 1, reason_first: bool = False) -> Dict[str, Any]:
        """Official inference interface, aligned with the upstream XGuard logic.

        Args:
            messages: Chat message list.
            policy: Optional dynamic policy for runtime-customized rules.
            max_new_tokens: Maximum number of generated tokens.
            reason_first: Generate the attribution analysis before the
                risk token.

        Returns:
            {
                'response': str,                      # fully decoded text
                'token_score': {text: prob, ...},     # top-k scores at the risk position
                'risk_score': {risk_name: prob, ...}  # scores for tokens found in id2risk
            }
        """
        with self._lock:
            rendered_query = self.tokenizer.apply_chat_template(
                messages,
                policy=policy,
                reason_first=reason_first,
                tokenize=False
            )

            model_inputs = self.tokenizer(
                [rendered_query], return_tensors="pt"
            ).to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **model_inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    output_scores=True,
                    return_dict_in_generate=True
                )

            batch_idx = 0
            input_length = model_inputs['input_ids'].shape[1]

            # Decode only the newly generated portion.
            output_ids = outputs["sequences"].tolist()[batch_idx][input_length:]
            response = self.tokenizer.decode(output_ids, skip_special_tokens=True)

            # Per-step softmax over logits, then keep the top-10 candidates.
            generated_tokens = outputs.sequences[:, input_length:]
            scores = torch.stack(outputs.scores, dim=1)
            scores = scores.softmax(dim=-1)
            scores_topk_value, scores_topk_index = scores.topk(k=10, dim=-1)

            generated_tokens_with_probs = []
            for generated_token, score_topk_value, score_topk_index in zip(
                generated_tokens, scores_topk_value, scores_topk_index
            ):
                generated_tokens_with_prob = []
                for token, topk_value, topk_index in zip(
                    generated_token, score_topk_value, score_topk_index
                ):
                    token = int(token.cpu())
                    if token == self.tokenizer.pad_token_id:
                        continue

                    # Keep the best candidate plus anything above 1e-4.
                    res_topk_score = {}
                    for ii, (value, index) in enumerate(zip(topk_value, topk_index)):
                        if ii == 0 or value.cpu().numpy() > 1e-4:
                            text = self.tokenizer.decode(index.cpu().numpy())
                            res_topk_score[text] = {
                                "id": str(int(index.cpu().numpy())),
                                "prob": round(float(value.cpu().numpy()), 4),
                            }

                    generated_tokens_with_prob.append(res_topk_score)
                generated_tokens_with_probs.append(generated_tokens_with_prob)

            # reason_first: the risk token is the second-to-last generated
            # token; otherwise it is the very first one.
            score_idx = (
                max(len(generated_tokens_with_probs[batch_idx]) - 2, 0)
                if reason_first else 0
            )

            token_score = {
                k: v['prob']
                for k, v in generated_tokens_with_probs[batch_idx][score_idx].items()
            }
            risk_score = {
                self.id2risk[k]: v['prob']
                for k, v in generated_tokens_with_probs[batch_idx][score_idx].items()
                if k in self.id2risk
            }

            return {
                'response': response,
                'token_score': token_score,
                'risk_score': risk_score,
            }

    def parse_explanation(self, response: str) -> Optional[str]:
        """Parse the attribution-analysis section out of a model response.

        With ``reason_first=False`` XGuard emits:
            [risk token][attribution text]
        where the risk token is a short key from ``id2risk`` (e.g. 'sec',
        'pc') and the remainder is natural-language attribution text.

        Returns:
            The explanation text, or None if none could be extracted.
        """
        if not response or not response.strip():
            return None

        # Preferred: explicit <explanation>...</explanation> markup.
        match = re.search(r'<explanation>(.*?)</explanation>', response, re.DOTALL)
        if match:
            return match.group(1).strip()

        text = response.strip()

        # Otherwise strip a leading risk-token prefix (longest key first so
        # e.g. 'sec2' wins over 'sec').
        if self.id2risk:
            for key in sorted(self.id2risk.keys(), key=len, reverse=True):
                if text.startswith(key):
                    remainder = text[len(key):].strip()
                    if remainder:
                        return remainder
                    break

        # Fallback: treat any reasonably long text as the explanation.
        if len(text) > 8:
            return text

        return None

    def analyze(self, messages: List[Dict[str, str]], tools: List[Dict[str, Any]],
                enable_reasoning: bool = False, policy=None) -> Dict[str, Any]:
        """High-level analysis interface returning a structured verdict.

        Args:
            messages: Chat message list.
            tools: Tool information (already merged into messages; currently
                unused here).
            enable_reasoning: Generate attribution text (more tokens).
            policy: Optional dynamic policy.

        Returns:
            Dict with is_safe / risk_level / confidence / risk_type /
            reason / detail_scores / response, plus 'explanation' when
            reasoning is enabled and parseable.
        """
        # 1 token is enough for the risk label alone; reasoning needs room.
        max_new_tokens = 512 if enable_reasoning else 1

        infer_result = self.infer(
            messages,
            policy=policy,
            max_new_tokens=max_new_tokens,
            reason_first=False
        )
        risk_scores = infer_result.get("risk_score", {})
        response = infer_result.get("response", "")

        SAFE_CATEGORY = "Safe-Safe"
        safe_prob = risk_scores.get(SAFE_CATEGORY, 0.0)

        # Rank all non-safe categories by probability.
        risk_items = {k: v for k, v in risk_scores.items() if k != SAFE_CATEGORY}
        sorted_risks = sorted(risk_items.items(), key=lambda x: x[1], reverse=True)

        top_risk_name = sorted_risks[0][0] if sorted_risks else ""
        top_risk_prob = sorted_risks[0][1] if sorted_risks else 0.0

        if safe_prob >= top_risk_prob and safe_prob >= 0.5:
            # Safe dominates and is confident.
            is_safe = 1
            risk_level = "safe"
        elif safe_prob >= top_risk_prob:
            # Safe dominates but with low confidence — treat as low risk.
            is_safe = 0
            risk_level = "low"
        else:
            # A risk category dominates; grade by its probability.
            is_safe = 0
            if top_risk_prob >= 0.5:
                risk_level = "high"
            elif top_risk_prob >= 0.3:
                risk_level = "medium"
            else:
                risk_level = "low"

        confidence = safe_prob if is_safe == 1 else top_risk_prob

        # Report up to 3 risks when unsafe, at most 1 when safe.
        if is_safe == 0:
            top_risks = sorted_risks[:3]
        else:
            top_risks = sorted_risks[:1] if sorted_risks else []

        risk_types = [r[0] for r in top_risks]
        reason = "; ".join([f"{r}: {s}" for r, s in top_risks])

        result = {
            "is_safe": is_safe,
            "risk_level": risk_level,
            "confidence": round(confidence, 4),
            "risk_type": risk_types,
            "reason": reason,
            "detail_scores": risk_scores,
            "response": response
        }

        if enable_reasoning:
            explanation = self.parse_explanation(response)
            if explanation:
                result["explanation"] = explanation

        return result
|
| |
|