"""
SafeQwen2.5-VL Model Implementation
SafeQwen2.5-VL extends Qwen2.5-VL with multimodal safety classification capabilities.
It adds a safety classification head that operates on pooled image features to identify
potentially unsafe content across 20 safety categories.
Key features:
- Non-invasive architecture: Uses standard Qwen2.5-VL forward pass
- Post-processing safety classification on image features
- Simple pooling strategy for feature aggregation
- Full gradient flow compatibility for training
Author: SafeQwen Team
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union
from dataclasses import dataclass
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
Qwen2_5_VLConfig,
)
from .configuration_safeqwen2_5_vl import SafeQwen2_5_VLConfig
@dataclass
class SafeQwen2_5_VLOutput(CausalLMOutputWithPast):
"""
Output class for SafeQwen2.5-VL with safety classification results.
Extends the standard CausalLMOutputWithPast to include safety-related outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
Language modeling loss. The safety loss is not computed inside `forward`; see `safety_labels`.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
Cached key/value attention states.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Hidden states of the model at each layer.
attentions (`tuple(torch.FloatTensor)`, *optional*):
Attention weights at each layer.
rope_deltas (`torch.LongTensor`, *optional*):
RoPE position deltas for Qwen2.5-VL.
img_safety_logits (`torch.FloatTensor` of shape `(batch_size, num_safety_categories)`, *optional*):
Safety classification logits for each image in the batch.
img_safety_probs (`torch.FloatTensor` of shape `(batch_size, num_safety_categories)`, *optional*):
Safety classification probabilities (softmax of logits).
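Example (a minimal sketch of reading the safety fields; assumes `model` and `inputs`
were prepared as in the class-level usage example):
```python
outputs = model(**inputs, do_safety=True)
if outputs.img_safety_probs is not None:
    print(outputs.img_safety_probs.shape)  # (batch_size, num_safety_categories)
```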
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
rope_deltas: Optional[torch.LongTensor] = None
img_safety_logits: Optional[torch.FloatTensor] = None
img_safety_probs: Optional[torch.FloatTensor] = None
class SafetyMLP(nn.Module):
"""
Multi-layer perceptron for safety classification.
A simple feedforward network that maps image features to safety category logits.
Args:
input_size (`int`):
Size of input features (typically model hidden size).
hidden_size (`int`):
Size of hidden layer(s).
output_size (`int`):
Number of output safety categories.
num_hidden_layers (`int`, *optional*, defaults to 1):
Number of hidden layers in the MLP.
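Example (illustrative sizes only; the real values are derived from the model config,
e.g. `hidden_size`, `safety_head_hidden_scale`, and `num_safety_categories`):
```python
head = SafetyMLP(input_size=3584, hidden_size=14336, output_size=20)
features = torch.randn(2, 3584)  # pooled image features, one row per sample
logits = head(features)          # shape: (2, 20)
```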
"""
def __init__(
self,
input_size: int,
hidden_size: int,
output_size: int,
num_hidden_layers: int = 1
):
super().__init__()
layers = []
# First layer
layers.append(nn.Linear(input_size, hidden_size))
layers.append(nn.GELU())
layers.append(nn.Dropout(0.1))
# Additional hidden layers
for _ in range(num_hidden_layers - 1):
layers.append(nn.Linear(hidden_size, hidden_size))
layers.append(nn.GELU())
layers.append(nn.Dropout(0.1))
# Output layer
layers.append(nn.Linear(hidden_size, output_size))
self.mlp = nn.Sequential(*layers)
def forward(self, x):
return self.mlp(x)
class SafeQwen2_5_VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
"""
SafeQwen2.5-VL model for conditional generation with safety classification.
This model extends Qwen2_5_VLForConditionalGeneration with an additional safety
classification head that analyzes image content for potential safety concerns.
The model architecture:
1. Uses standard Qwen2.5-VL for vision-language modeling
2. Extracts image features from hidden states using pooling
3. Passes pooled features through a safety classification MLP
4. Returns both generation outputs and safety predictions
Key design principles:
- Non-invasive: Does not modify base Qwen2.5-VL forward pass
- Post-processing: Safety classification happens after standard forward pass
- Gradient-friendly: Maintains full gradient flow for end-to-end training
Example:
```python
from transformers import AutoModel, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
# Load model and processor
model = AutoModel.from_pretrained("your-username/SafeQwen2.5-VL-7B", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
# Prepare inputs
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": "path/to/image.jpg"},
{"type": "text", "text": "Describe this image."},
],
}
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
# Generate with safety classification
outputs = model(**inputs, do_safety=True)
# Access safety predictions
safety_probs = outputs.img_safety_probs
print(f"Safety probabilities: {safety_probs}")
# Generate text
generated_ids = model.generate(**inputs, max_new_tokens=128)
```
"""
config_class = SafeQwen2_5_VLConfig
def __init__(self, config: SafeQwen2_5_VLConfig):
super().__init__(config)
# Add safety head if safety configuration is present
num_safety_categories = getattr(config, 'num_safety_categories', None)
if num_safety_categories and num_safety_categories > 0:
hidden_size = config.hidden_size
safety_head_hidden_scale = getattr(config, 'safety_head_hidden_scale', 4.0)
safety_hidden_size = int(hidden_size * safety_head_hidden_scale)
safety_num_hidden_layers = getattr(config, 'safety_num_hidden_layers', 1)
self.img_safety_head = SafetyMLP(
input_size=hidden_size,
hidden_size=safety_hidden_size,
output_size=num_safety_categories,
num_hidden_layers=safety_num_hidden_layers
)
else:
self.img_safety_head = None
def _extract_image_features_simple(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
input_ids: Optional[torch.Tensor] = None
) -> Optional[torch.Tensor]:
"""
Extract image features using pooling over image token positions.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
Hidden states from the model.
attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*):
Attention mask (currently unused, reserved for future use).
input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`, *optional*):
Input token IDs used to identify image token positions.
Returns:
`torch.Tensor` of shape `(batch_size, hidden_size)` or `None`:
Pooled image features for each sample in the batch.
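Example:
For `hidden_states` of shape `(2, 128, hidden_size)` where both samples contain
image tokens, the result is a `(2, hidden_size)` tensor of mean-pooled
image-token features; a sample without image tokens contributes an all-zero row.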
"""
if input_ids is None:
return None
# Find image token positions (Qwen2.5-VL uses specific image token IDs)
image_token_id = getattr(self.config, 'image_token_id', 151655)
# Create mask for image tokens and move to same device as hidden_states
image_mask = (input_ids == image_token_id).to(hidden_states.device) # [batch_size, seq_len]
if not image_mask.any():
return None
# Pool image token features for each sample in batch
batch_size = hidden_states.shape[0]
hidden_size = hidden_states.shape[-1]
# Use list comprehension to avoid in-place operations
image_features_list = []
for i in range(batch_size):
sample_image_mask = image_mask[i] # [seq_len]
if sample_image_mask.any():
# Extract hidden states for image tokens
sample_image_features = hidden_states[i][sample_image_mask] # [num_image_tokens, hidden_size]
# Simple mean pooling - maintains gradients
pooled_features = sample_image_features.mean(dim=0) # [hidden_size]
image_features_list.append(pooled_features)
else:
# For samples without images, use gradient-preserving zero
zero_features = hidden_states[i, 0, :] * 0.0
image_features_list.append(zero_features)
# Stack the features - this maintains gradient flow
image_features = torch.stack(image_features_list, dim=0) # [batch_size, hidden_size]
return image_features
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
do_safety: bool = True,
safety_labels: Optional[torch.LongTensor] = None,
**kwargs
) -> Union[Tuple, SafeQwen2_5_VLOutput]:
"""
Forward pass with optional safety classification.
Args:
do_safety (`bool`, *optional*, defaults to `True`):
Whether to perform safety classification. Set to False during generation.
safety_labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Ground truth safety category labels for training (currently unused).
Returns:
`SafeQwen2_5_VLOutput` or `tuple`:
Model outputs including optional safety predictions.
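Example (a minimal sketch; `safety_labels` is accepted but not consumed inside
`forward`, so any safety loss is computed by the caller):
```python
outputs = model(**inputs, do_safety=True)
# Inference: most likely safety category per image
pred = outputs.img_safety_probs.argmax(dim=-1)
# Training: compute the safety loss externally from the returned logits
safety_loss = torch.nn.functional.cross_entropy(outputs.img_safety_logits, safety_labels)
```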
"""
# Force output_hidden_states if we need safety classification
if do_safety and self.img_safety_head is not None:
output_hidden_states = True
return_dict = True
# Standard Qwen2.5-VL forward pass - NO MODIFICATIONS
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
return_dict=True,
**kwargs
)
# Initialize safety outputs
img_safety_logits = None
img_safety_probs = None
# Post-process for safety classification
# Only do safety classification during initial forward pass, not during generation
is_generation = past_key_values is not None and len(past_key_values) > 0
# Check if we have image tokens in the input
has_image_tokens = False
if input_ids is not None:
image_token_id = getattr(self.config, 'image_token_id', 151655)
has_image_tokens = (input_ids == image_token_id).any().item()
# Only perform safety classification if:
# 1. Safety is requested
# 2. We have a safety head
# 3. We have hidden states
# 4. We have image tokens
# 5. This is NOT during text generation
should_do_safety = (
do_safety and
self.img_safety_head is not None and
outputs.hidden_states is not None and
has_image_tokens and
not is_generation
)
if should_do_safety:
# Extract image features from hidden states
last_hidden_state = outputs.hidden_states[-1] # [batch_size, seq_len, hidden_size]
image_features = self._extract_image_features_simple(
last_hidden_state, attention_mask, input_ids
)
if image_features is not None:
# Run through safety head
img_safety_logits = self.img_safety_head(image_features)
img_safety_probs = torch.softmax(img_safety_logits, dim=-1)
# Return results
if return_dict is False:
output = (outputs.loss, outputs.logits, outputs.past_key_values,
outputs.hidden_states, outputs.attentions)
if img_safety_logits is not None:
output += (img_safety_logits, img_safety_probs)
return output
else:
return SafeQwen2_5_VLOutput(
loss=outputs.loss,
logits=outputs.logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=getattr(outputs, 'rope_deltas', None),
img_safety_logits=img_safety_logits,
img_safety_probs=img_safety_probs
)