File size: 12,799 Bytes
6a3bd1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 |
import torch
import json
import re
from PIL import Image
from typing import List, Dict, Tuple
from datetime import datetime
from caption_generation_manager import CaptionGenerationManager
class BrandVerificationManager:
    """VLM-based brand verification and three-way voting system.

    Uses the processor/model held by a CaptionGenerationManager to ask a
    vision-language model (VLM) about brands in an image, then fuses the
    OpenCLIP, OCR and VLM opinions into a single ranked brand list.
    """

    # Weights for the per-source weighted average in three_way_voting():
    # VLM is trusted most, OpenCLIP next, OCR least. Unknown sources use 0.4.
    _SOURCE_WEIGHTS = {'vlm': 1.0, 'openclip': 0.6, 'ocr': 0.4}

    # Keys a parsed verification result always carries, each holding a list.
    _RESULT_LIST_KEYS = ('verified_brands', 'false_positives', 'additional_brands')

    # Rule-based fallback: a response line counts as a positive identification
    # when it contains one of these whole words. Word boundaries stop
    # "incorrect" from matching "correct" and "eyes" from matching "yes".
    _POSITIVE_LINE_RE = re.compile(
        r'\b(?:correct(?:ly)?|yes|visible|see[ns]?|identified)\b', re.IGNORECASE
    )

    # Rule-based fallback: runs of Capitalized Words are candidate brand names.
    _CAPITALIZED_PHRASE_RE = re.compile(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b')

    # Capitalized words that typically open sentences in a VLM answer and are
    # not brand names; dropped by the rule-based fallback (compared lowercase).
    _NON_BRAND_WORDS = frozenset({
        'yes', 'the', 'this', 'that', 'these', 'those', 'there', 'here',
        'correct', 'correctly', 'incorrect', 'visible', 'identified',
        'brand', 'brands', 'logo', 'image', 'based', 'however', 'also',
    })

    def __init__(self, caption_generator: "CaptionGenerationManager" = None):
        """
        Args:
            caption_generator: CaptionGenerationManager instance whose
                ``processor`` and ``model`` are used for VLM inference.
                When None, a fresh CaptionGenerationManager() is created.
        """
        if caption_generator is None:
            caption_generator = CaptionGenerationManager()
        self.caption_generator = caption_generator

        # Maps the VLM's verbal confidence labels to numeric scores.
        self.confidence_map = {
            'high': 0.9,
            'medium': 0.7,
            'low': 0.5,
            'very high': 0.95,
            'very low': 0.3
        }
        print("✓ Brand Verification Manager initialized with VLM")

    def verify_brands(self, image: "Image.Image",
                      detected_brands: List[Tuple[str, float, list]]) -> Dict:
        """
        Use VLM to verify detected brands.

        Only the first three entries of ``detected_brands`` are put in the
        prompt (presumably the caller passes them sorted by confidence --
        TODO confirm).

        Args:
            image: PIL Image
            detected_brands: List of (brand_name, confidence, bbox) tuples

        Returns:
            Dictionary with 'verified_brands' (list of dicts with 'name',
            'confidence', 'evidence'), 'false_positives' and
            'additional_brands'. An empty ``detected_brands`` returns empty
            lists plus 'confidence': 0.0 without calling the VLM. If the VLM
            call or parsing raises, every detected brand is echoed back with
            'medium' confidence.
        """
        if not detected_brands:
            return {
                'verified_brands': [],
                'false_positives': [],
                'additional_brands': [],
                'confidence': 0.0
            }

        # Construct verification prompt
        brand_list = ', '.join([f"{brand[0]} (confidence: {brand[1]:.2f})"
                                for brand in detected_brands[:3]])  # Top 3 brands
        verification_prompt = f"""Analyze this image carefully. Our computer vision system detected the following brands: {brand_list}.
Please verify each brand identification:
1. Are these brand identifications correct based on visible logos, patterns, text, or distinctive features?
2. If incorrect, what brands do you actually see (if any)?
3. Describe the visual evidence (logo shape, text, pattern, color scheme, hardware) that supports your conclusion.
Respond in JSON format:
{{
"verified_brands": [
{{"name": "Brand Name", "confidence": "high/medium/low", "evidence": "description of visual evidence"}}
],
"false_positives": ["brand names that were incorrectly detected"],
"additional_brands": ["brands we missed but you can see"]
}}
IMPORTANT: Only include brands you can clearly identify with visual evidence. If unsure, use "low" confidence."""

        try:
            response = self._generate_vlm_response(image, verification_prompt)
            return self._parse_verification_response(response)
        except Exception as e:
            # Best-effort: VLM failures must not break the detection pipeline.
            print(f"VLM verification error: {e}")
            # Fallback to original detections
            return {
                'verified_brands': [
                    {'name': brand[0], 'confidence': 'medium', 'evidence': 'VLM verification failed'}
                    for brand in detected_brands
                ],
                'false_positives': [],
                'additional_brands': []
            }

    def three_way_voting(self, openclip_brands: List[Tuple], ocr_brands: Dict,
                         vlm_result: Dict) -> List[Tuple[str, float, list]]:
        """
        Three-way voting: OpenCLIP vs OCR vs VLM.

        Each OpenCLIP or OCR detection casts one vote and each VLM-verified
        brand casts two. A brand the VLM lists as a false positive loses two
        votes (floored at 0), and brands left with no votes are dropped.
        Survivors are scored by a weighted average of their per-source
        confidences (VLM 1.0, OpenCLIP 0.6, OCR 0.4). The score gets a 15%
        boost when at least two distinct sources voted for the brand, is
        capped at 0.95, and the brand is kept only if the score exceeds 0.30.
        Brand names are matched exactly (case-sensitive) across sources.

        Args:
            openclip_brands: List of (brand_name, confidence, bbox) from OpenCLIP
            ocr_brands: Dict of {brand_name: (text_score, ocr_conf)} from OCR
            vlm_result: Verification result from VLM (e.g. from verify_brands).
                Malformed content is skipped rather than raising: non-list
                fields, entries without a string 'name', and non-string
                false positives. A numeric 'confidence' in [0, 1] is used
                as-is.

        Returns:
            List of (brand_name, final_confidence, bbox) tuples, sorted by
            confidence, highest first. bbox is the first OpenCLIP bbox for the
            brand, or None when OpenCLIP did not report it.
        """
        votes = {}  # brand_name -> {'votes': int, 'sources': list, 'bbox': list or None}
        confidence_scores = {}  # brand_name -> list of (source, confidence)

        def cast(brand_name, source, n_votes, score, bbox=None):
            """Record n_votes and one (source, score) pair for brand_name."""
            if brand_name not in votes:
                votes[brand_name] = {'votes': 0, 'sources': [], 'bbox': bbox}
                confidence_scores[brand_name] = []
            votes[brand_name]['votes'] += n_votes
            votes[brand_name]['sources'].append(source)
            confidence_scores[brand_name].append((source, score))

        # Vote 1: OpenCLIP (confidence discounted by 0.8)
        for brand_name, confidence, bbox in openclip_brands or []:
            cast(brand_name, 'openclip', 1, confidence * 0.8, bbox)

        # Vote 2: OCR (mean of text score and OCR confidence, discounted by 0.7)
        for brand_name, (text_score, ocr_conf) in (ocr_brands or {}).items():
            cast(brand_name, 'ocr', 1, (text_score + ocr_conf) / 2 * 0.7)

        vlm_result = vlm_result or {}

        # Vote 3: VLM (double vote - most reliable). It may also add brands
        # missed by both OpenCLIP and OCR.
        # NOTE(review): vlm_result['additional_brands'] is not counted here.
        verified = vlm_result.get('verified_brands')
        for brand_info in verified if isinstance(verified, list) else []:
            parsed = self._parse_vlm_brand_entry(brand_info)
            if parsed is not None:
                cast(parsed[0], 'vlm', 2, parsed[1])

        # Penalize brands the VLM flagged as false positives
        false_positives = vlm_result.get('false_positives')
        for false_positive in false_positives if isinstance(false_positives, list) else []:
            if not isinstance(false_positive, str):
                continue  # e.g. a dict: unhashable, cannot name a brand
            entry = votes.get(false_positive.strip())
            if entry is not None:
                entry['votes'] = max(0, entry['votes'] - 2)

        # Calculate final scores
        final_brands = []
        for brand_name, vote_info in votes.items():
            if vote_info['votes'] <= 0:
                continue  # Skip brands with no votes left
            scores = confidence_scores.get(brand_name, [])
            if not scores:
                continue

            weighted_sum = 0.0
            weight_total = 0.0
            for source, score in scores:
                weight = self._SOURCE_WEIGHTS.get(source, 0.4)
                weighted_sum += score * weight
                weight_total += weight
            avg_confidence = weighted_sum / weight_total if weight_total > 0 else 0.0

            # Boost only when at least two *different* sources agree. The VLM's
            # double vote alone reaches 2 votes but is not agreement.
            if vote_info['votes'] >= 2 and len(set(vote_info['sources'])) >= 2:
                avg_confidence *= 1.15  # 15% boost for agreement

            avg_confidence = min(avg_confidence, 0.95)
            if avg_confidence > 0.30:
                final_brands.append((brand_name, avg_confidence, vote_info['bbox']))

        final_brands.sort(key=lambda x: x[1], reverse=True)
        return final_brands

    def _parse_vlm_brand_entry(self, brand_info):
        """Return (brand_name, numeric_confidence) for one VLM entry, or None if unusable.

        Accepts a {'name': ..., 'confidence': ...} dict or a bare brand-name
        string (treated as 'medium'). Verbal labels go through confidence_map,
        and unknown or missing labels default to 0.7.
        """
        if isinstance(brand_info, str):
            brand_name, level = brand_info, 'medium'
        elif isinstance(brand_info, dict) and isinstance(brand_info.get('name'), str):
            brand_name, level = brand_info['name'], brand_info.get('confidence', 'medium')
        else:
            return None

        brand_name = brand_name.strip()
        if not brand_name:
            return None

        if isinstance(level, (int, float)) and not isinstance(level, bool) and 0.0 <= level <= 1.0:
            return brand_name, float(level)
        return brand_name, self.confidence_map.get(str(level).strip().lower(), 0.7)

    def extract_visual_evidence(self, image: "Image.Image", brand_name: str) -> Dict:
        """
        Ask the VLM to describe the visual evidence for an identified brand.

        Args:
            image: PIL Image
            brand_name: Identified brand name

        Returns:
            Dictionary with 'brand', 'evidence_description' (the raw VLM text,
            or an "Evidence extraction failed: ..." message if the call raised)
            and 'timestamp' (local time, ISO format).
        """
        evidence_prompt = f"""You identified {brand_name} in this image. Please describe the specific visual evidence:
1. Logo appearance: Describe the logo's shape, style, color, and exact location in the image
2. Text elements: What text did you see? (exact wording, font style, placement)
3. Distinctive patterns: Any signature patterns, textures, or design elements
4. Color scheme: Brand-specific colors used
5. Product features: Distinctive product design characteristics
Be specific and detailed. Focus on objective visual features."""

        try:
            evidence_description = self._generate_vlm_response(image, evidence_prompt)
            return {
                'brand': brand_name,
                'evidence_description': evidence_description,
                'timestamp': datetime.now().isoformat()
            }
        except Exception as e:
            return {
                'brand': brand_name,
                'evidence_description': f"Evidence extraction failed: {str(e)}",
                'timestamp': datetime.now().isoformat()
            }

    def _generate_vlm_response(self, image: "Image.Image", prompt: str) -> str:
        """
        Generate a VLM response for the given image and prompt.

        Builds a single-turn chat message (image + text), runs it through
        ``caption_generator.processor`` and ``caption_generator.model.generate``
        (at most 300 new tokens), and decodes only the newly generated tokens.

        Args:
            image: PIL Image
            prompt: Text prompt

        Returns:
            VLM response string
        """
        from qwen_vl_utils import process_vision_info

        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }]

        text = self.caption_generator.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.caption_generator.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        ).to(self.caption_generator.model.device)

        # Generate with low temperature for factual responses.
        # NOTE(review): in Hugging Face `generate`, temperature/top_p only take
        # effect when sampling is enabled (do_sample in the call or in the
        # model's generation config) -- confirm this is the intended decoding.
        generation_config = {
            'temperature': 0.3,  # Low temperature for factual verification
            'top_p': 0.9,
            'max_new_tokens': 300,
            'repetition_penalty': 1.1
        }
        generated_ids = self.caption_generator.model.generate(
            **inputs,
            **generation_config
        )

        # Trim the echoed prompt tokens from each generated sequence
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.caption_generator.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        return output_text

    def _parse_verification_response(self, response: str) -> Dict:
        """
        Parse a VLM verification response.

        Scans for the first JSON object in the text that carries at least one
        of the expected keys. Unlike a greedy first-'{'-to-last-'}' match,
        this tolerates code fences and trailing prose containing braces.

        Args:
            response: VLM response string

        Returns:
            Parsed dictionary whose 'verified_brands', 'false_positives' and
            'additional_brands' are guaranteed to be lists. Falls back to
            _rule_based_parse when no suitable JSON object is found.
        """
        text = response or ''
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text[start:])
            except json.JSONDecodeError:
                candidate = None
            if isinstance(candidate, dict) and any(k in candidate for k in self._RESULT_LIST_KEYS):
                for key in self._RESULT_LIST_KEYS:
                    if not isinstance(candidate.get(key), list):
                        candidate[key] = []
                return candidate
            start = text.find('{', start + 1)

        # Fallback: rule-based parsing
        return self._rule_based_parse(text)

    def _rule_based_parse(self, response: str) -> Dict:
        """
        Fallback rule-based parsing when no usable JSON is found.

        For every line containing a positive marker word, the runs of
        Capitalized Words on that same line are taken as brand names. Each name
        is kept once (first occurrence), and a small set of common
        sentence-opening words is skipped. Every name is reported with
        'medium' confidence.

        Args:
            response: VLM response string

        Returns:
            Parsed dictionary (false_positives/additional_brands always empty)
        """
        result = {
            'verified_brands': [],
            'false_positives': [],
            'additional_brands': []
        }

        seen = set()
        for line in (response or '').split('\n'):
            if not self._POSITIVE_LINE_RE.search(line):
                continue
            for phrase in self._CAPITALIZED_PHRASE_RE.findall(line):
                if len(phrase) <= 2 or phrase in seen:
                    continue  # Avoid short words and duplicate votes
                if phrase.lower() in self._NON_BRAND_WORDS:
                    continue
                seen.add(phrase)
                result['verified_brands'].append({
                    'name': phrase,
                    'confidence': 'medium',
                    'evidence': 'Extracted from VLM response'
                })

        return result
# Import-time notice that BrandVerificationManager has been defined.
print(
    "✓ BrandVerificationManager (VLM verification and voting) defined"
)
|