"""
SafeQwen2.5-VL Model Implementation

SafeQwen2.5-VL extends Qwen2.5-VL with multimodal safety classification capabilities.
It adds a safety classification head that operates on pooled image features to identify
potentially unsafe content across 20 safety categories.

Key features:
- Non-invasive architecture: uses the standard Qwen2.5-VL forward pass
- Post-processing safety classification on image features
- Simple pooling strategy for feature aggregation
- Full gradient flow compatibility for training

Author: SafeQwen Team
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union
from dataclasses import dataclass

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen2_5_vl import (
    Qwen2_5_VLForConditionalGeneration,
    Qwen2_5_VLConfig,
)

from .configuration_safeqwen2_5_vl import SafeQwen2_5_VLConfig


@dataclass
class SafeQwen2_5_VLOutput(CausalLMOutputWithPast):
    """
    Output class for SafeQwen2.5-VL with safety classification results.

    Extends the standard CausalLMOutputWithPast to include safety-related outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Language modeling loss (and safety loss if labels are provided).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Cached key/value attention states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden states of the model at each layer.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Attention weights at each layer.
        rope_deltas (`torch.LongTensor`, *optional*):
            RoPE position deltas for Qwen2.5-VL.
        img_safety_logits (`torch.FloatTensor` of shape `(batch_size, num_safety_categories)`, *optional*):
            Safety classification logits for each image in the batch.
        img_safety_probs (`torch.FloatTensor` of shape `(batch_size, num_safety_categories)`, *optional*):
            Safety classification probabilities (softmax of the logits).
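
    Example (a minimal sketch, assuming the forward pass was run with `do_safety=True`
    on inputs that contain at least one image):

    ```python
    outputs = model(**inputs, do_safety=True)
    if outputs.img_safety_probs is not None:
        predicted_category = outputs.img_safety_probs.argmax(dim=-1)  # shape: (batch_size,)
    ```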
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None
    img_safety_logits: Optional[torch.FloatTensor] = None
    img_safety_probs: Optional[torch.FloatTensor] = None


class SafetyMLP(nn.Module):
    """
    Multi-layer perceptron for safety classification.

    A simple feedforward network that maps image features to safety category logits.

    Args:
        input_size (`int`):
            Size of input features (typically the model hidden size).
        hidden_size (`int`):
            Size of hidden layer(s).
        output_size (`int`):
            Number of output safety categories.
        num_hidden_layers (`int`, *optional*, defaults to 1):
            Number of hidden layers in the MLP.
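
    Example (a minimal sketch; the sizes below are illustrative, not tied to any
    particular checkpoint):

    ```python
    head = SafetyMLP(input_size=3584, hidden_size=14336, output_size=20)
    features = torch.randn(2, 3584)   # pooled image features
    logits = head(features)           # shape: (2, 20)
    ```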
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        num_hidden_layers: int = 1
    ):
        super().__init__()

        layers = []

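        # Input projection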
        layers.append(nn.Linear(input_size, hidden_size))
        layers.append(nn.GELU())
        layers.append(nn.Dropout(0.1))

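        # Additional hidden layers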
        for _ in range(num_hidden_layers - 1):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(0.1))

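        # Output projection to safety category logits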
        layers.append(nn.Linear(hidden_size, output_size))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)


class SafeQwen2_5_VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """
    SafeQwen2.5-VL model for conditional generation with safety classification.

    This model extends Qwen2_5_VLForConditionalGeneration with an additional safety
    classification head that analyzes image content for potential safety concerns.

    The model architecture:
    1. Uses standard Qwen2.5-VL for vision-language modeling
    2. Extracts image features from hidden states using pooling
    3. Passes pooled features through a safety classification MLP
    4. Returns both generation outputs and safety predictions

    Key design principles:
    - Non-invasive: does not modify the base Qwen2.5-VL forward pass
    - Post-processing: safety classification happens after the standard forward pass
    - Gradient-friendly: maintains full gradient flow for end-to-end training

    Example:
    ```python
    from transformers import AutoModel, AutoProcessor
    from qwen_vl_utils import process_vision_info

    # Load model and processor
    model = AutoModel.from_pretrained("your-username/SafeQwen2.5-VL-7B", trust_remote_code=True)
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

    # Prepare inputs
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "path/to/image.jpg"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Forward pass with safety classification
    outputs = model(**inputs, do_safety=True)

    # Access safety predictions
    safety_probs = outputs.img_safety_probs
    print(f"Safety probabilities: {safety_probs}")

    # Generate text
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    ```
    """

    config_class = SafeQwen2_5_VLConfig

    def __init__(self, config: SafeQwen2_5_VLConfig):
        super().__init__(config)

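        # Build the safety classification head when the config enables it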
        num_safety_categories = getattr(config, 'num_safety_categories', None)
        if num_safety_categories and num_safety_categories > 0:
            hidden_size = config.hidden_size
            safety_head_hidden_scale = getattr(config, 'safety_head_hidden_scale', 4.0)
            safety_hidden_size = int(hidden_size * safety_head_hidden_scale)
            safety_num_hidden_layers = getattr(config, 'safety_num_hidden_layers', 1)

            self.img_safety_head = SafetyMLP(
                input_size=hidden_size,
                hidden_size=safety_hidden_size,
                output_size=num_safety_categories,
                num_hidden_layers=safety_num_hidden_layers
            )
        else:
            self.img_safety_head = None

    def _extract_image_features_simple(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None
    ) -> Optional[torch.Tensor]:
        """
        Extract image features using pooling over image token positions.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
                Hidden states from the model.
            attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*):
                Attention mask (currently unused, reserved for future use).
            input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*):
                Input token IDs used to identify image token positions.

        Returns:
            `torch.Tensor` of shape `(batch_size, hidden_size)` or `None`:
                Pooled image features for each sample in the batch.
        """
        if input_ids is None:
            return None

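        # Positions of image placeholder tokens in the input sequence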
        image_token_id = getattr(self.config, 'image_token_id', 151655)

        image_mask = (input_ids == image_token_id).to(hidden_states.device)

        if not image_mask.any():
            return None

        batch_size = hidden_states.shape[0]
        hidden_size = hidden_states.shape[-1]

        image_features_list = []

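        # Mean-pool hidden states at image token positions, one sample at a time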
        for i in range(batch_size):
            sample_image_mask = image_mask[i]

            if sample_image_mask.any():
                sample_image_features = hidden_states[i][sample_image_mask]
                pooled_features = sample_image_features.mean(dim=0)
                image_features_list.append(pooled_features)
            else:
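                # No image tokens in this sample: append a zero vector so the batch
                # stays stackable and gradient flow is preserved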
                zero_features = hidden_states[i, 0, :] * 0.0
                image_features_list.append(zero_features)

        image_features = torch.stack(image_features_list, dim=0)

        return image_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        do_safety: bool = True,
        safety_labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> Union[Tuple, SafeQwen2_5_VLOutput]:
        """
        Forward pass with optional safety classification.

        Args:
            do_safety (`bool`, *optional*, defaults to `True`):
                Whether to perform safety classification. Set to `False` during generation.
            safety_labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Ground truth safety category labels for training (currently unused).

        Returns:
            `SafeQwen2_5_VLOutput` or `tuple`:
                Model outputs including optional safety predictions.
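
        Example (a minimal training sketch; `safety_labels` is not consumed inside the
        model, so any safety loss has to be computed from the returned logits):

        ```python
        import torch.nn.functional as F

        outputs = model(**inputs, do_safety=True)
        if outputs.img_safety_logits is not None:
            safety_loss = F.cross_entropy(outputs.img_safety_logits, safety_labels)
            total_loss = (outputs.loss + safety_loss) if outputs.loss is not None else safety_loss
        ```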
        """

        if do_safety and self.img_safety_head is not None:
            output_hidden_states = True
            return_dict = True

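        # Standard Qwen2.5-VL forward pass (unchanged)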
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            return_dict=True,
            **kwargs
        )

        img_safety_logits = None
        img_safety_probs = None

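        # A populated KV cache means we are inside an incremental decoding step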
        is_generation = past_key_values is not None and len(past_key_values) > 0

        has_image_tokens = False
        if input_ids is not None:
            image_token_id = getattr(self.config, 'image_token_id', 151655)
            has_image_tokens = (input_ids == image_token_id).any().item()

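        # Run safety classification only on a full forward pass that contains image tokens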
        should_do_safety = (
            do_safety and
            self.img_safety_head is not None and
            outputs.hidden_states is not None and
            has_image_tokens and
            not is_generation
        )

        if should_do_safety:
            last_hidden_state = outputs.hidden_states[-1]

            image_features = self._extract_image_features_simple(
                last_hidden_state, attention_mask, input_ids
            )

            if image_features is not None:
                img_safety_logits = self.img_safety_head(image_features)
                img_safety_probs = torch.softmax(img_safety_logits, dim=-1)

        if return_dict is False:
            output = (outputs.loss, outputs.logits, outputs.past_key_values,
                      outputs.hidden_states, outputs.attentions)
            if img_safety_logits is not None:
                output += (img_safety_logits, img_safety_probs)
            return output
        else:
            return SafeQwen2_5_VLOutput(
                loss=outputs.loss,
                logits=outputs.logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                rope_deltas=getattr(outputs, 'rope_deltas', None),
                img_safety_logits=img_safety_logits,
                img_safety_probs=img_safety_probs
            )