# File size: 7,475 bytes (snapshot of revision 6a3bd1f)
import torch
from PIL import Image
from typing import Dict, List, Tuple
import numpy as np
class BrandDetectionOptimizer:
    """
    Smart brand-detection optimizer balancing speed and accuracy.

    Runs cheap pre-screening steps (OCR alias matching, a coarse zero-shot
    CLIP category check, saliency-guided region picking) so that expensive
    deep brand detection only runs on likely brands and regions.
    """

    # Coarse visual categories for zero-shot CLIP pre-screening. Keys are
    # passed to ``prompt_library.get_brands_by_category``; values are the
    # text prompts given to ``clip_manager.classify_zero_shot``.
    CATEGORY_PROMPTS = {
        'luxury': 'luxury brand product with monogram pattern and leather details',
        'sportswear': 'sportswear brand product with athletic logo and swoosh design',
        'tech': 'technology brand product with minimalist design and metal finish',
        'automotive': 'luxury car brand with distinctive grille and emblem',
        'watches': 'luxury watch with distinctive dial and brand logo',
        'fashion': 'fashion brand product with signature pattern or logo'
    }

    # How many top-scoring categories are considered, and the minimum CLIP
    # score a category needs before its brands are added (raised from 0.15
    # to 0.30 to reduce false positives).
    TOP_CATEGORIES = 2
    CATEGORY_SCORE_THRESHOLD = 0.30

    # Fallback candidates used only when OCR and CLIP yield no lead at all:
    # brands with very distinctive and common visual signatures.
    FALLBACK_BRANDS = ('Louis Vuitton', 'Gucci', 'Nike')

    # OCR-based boosts (each scaled by the OCR confidence) and the ceiling a
    # boost may raise the confidence to.
    EXACT_MATCH_BOOST = 0.40
    PARTIAL_MATCH_BOOST = 0.25
    MAX_BOOSTED_CONFIDENCE = 0.95

    def __init__(self, clip_manager, ocr_manager, prompt_library):
        """
        Store the collaborating components.

        Args:
            clip_manager: Object exposing ``classify_zero_shot(image, prompts)``.
            ocr_manager: Object exposing
                ``extract_text(image, use_brand_preprocessing=...)``.
            prompt_library: Object exposing ``get_all_brands()``,
                ``get_brands_by_category(category)`` and
                ``get_brand_prompts(brand_name)``.
        """
        self.clip_manager = clip_manager
        self.ocr_manager = ocr_manager
        self.prompt_library = prompt_library

    @staticmethod
    def _is_strong_alias_match(text: str, alias: str) -> bool:
        """
        Return True if upper-cased OCR ``text`` strongly matches ``alias``.

        Exact matches always count. Substring matches (in either direction)
        count only for aliases of at least 3 characters, and only when the
        contained string covers more than 60% of the containing one.
        """
        if alias == text:
            return True
        if len(alias) < 3:
            return False
        if alias in text:
            return len(alias) / len(text) > 0.6
        if text in alias:
            return len(text) / len(alias) > 0.6
        return False

    def _upper_aliases_by_brand(self) -> Dict[str, List[str]]:
        """Map each brand name to its upper-cased aliases from the prompt library."""
        return {
            brand_name: [alias.upper() for alias in brand_info.get('aliases', [])]
            for brand_name, brand_info in self.prompt_library.get_all_brands().items()
        }

    def quick_brand_prescreening(self, image: "Image.Image") -> List[str]:
        """
        Quickly narrow down which brands are worth deep detection.

        Candidates come from three sources, in order:
          1. OCR text that strongly matches a brand alias;
          2. every brand of the top ``TOP_CATEGORIES`` CLIP categories whose
             score exceeds ``CATEGORY_SCORE_THRESHOLD``;
          3. ``FALLBACK_BRANDS``, only if steps 1 and 2 found nothing.

        Args:
            image: Image to screen; passed unchanged to the OCR and CLIP
                managers.

        Returns:
            Sorted list of unique candidate brand names. Sorting makes the
            result deterministic: iterating a set of strings directly can
            vary between processes because of str hash randomization.
        """
        likely_brands = set()

        # Step 1: OCR scan (fastest and most precise signal).
        ocr_results = self.ocr_manager.extract_text(image, use_brand_preprocessing=True)
        # Build the alias table once rather than once per OCR item.
        brand_aliases = self._upper_aliases_by_brand() if ocr_results else {}
        for ocr_item in ocr_results:
            # Strip surrounding whitespace so " NIKE " can still match exactly.
            text = ocr_item['text'].strip().upper()
            # Skip very short fragments to avoid spurious matches.
            if len(text) < 2:
                continue
            for brand_name, aliases in brand_aliases.items():
                if any(self._is_strong_alias_match(text, alias) for alias in aliases):
                    likely_brands.add(brand_name)

        # Step 2: coarse visual category classification with CLIP.
        # NOTE(review): assumes the returned mapping is keyed by the exact
        # prompt strings passed in (the reverse lookup below relies on it).
        category_scores = self.clip_manager.classify_zero_shot(
            image, list(self.CATEGORY_PROMPTS.values())
        )
        top_categories = sorted(
            category_scores.items(), key=lambda item: item[1], reverse=True
        )[:self.TOP_CATEGORIES]
        prompt_to_category = {
            prompt: category for category, prompt in self.CATEGORY_PROMPTS.items()
        }
        for prompt_text, score in top_categories:
            if score > self.CATEGORY_SCORE_THRESHOLD:
                category = prompt_to_category[prompt_text]
                # Add every brand of the category; iterating the returned
                # mapping yields its brand-name keys.
                likely_brands.update(self.prompt_library.get_brands_by_category(category))

        # Step 3: with no lead at all, fall back to a few visually distinctive
        # brands rather than returning nothing.
        if not likely_brands:
            likely_brands.update(self.FALLBACK_BRANDS)

        # No count limit here; downstream quality filtering decides.
        return sorted(likely_brands)

    def smart_region_selection(self, image: "Image.Image",
                               saliency_regions: List[Dict],
                               *,
                               max_salient_regions: int = 3,
                               padding: int = 20,
                               min_region_size: int = 100) -> List[Tuple[int, int, int, int]]:
        """
        Pick the image regions worth scanning for brands.

        This replaces an exhaustive grid scan with:
          1. the first ``max_salient_regions`` saliency boxes, each padded by
             ``padding`` pixels (clamped to the image) and kept only if both
             sides exceed ``min_region_size``;
          2. a centered square whose side is about half the shorter image
             side (skipped if it would be empty on a tiny image);
          3. the full image, when step 1 produced no usable salient region.

        Args:
            image: Image being scanned; only ``image.size`` is read.
            saliency_regions: Saliency detection results; each item's
                ``'bbox'`` entry, if truthy, is unpacked as (x1, y1, x2, y2).
            max_salient_regions: How many leading saliency results to use.
            padding: Context margin added around each salient box.
            min_region_size: Minimum width and height (exclusive) of a
                padded salient box.

        Returns:
            List of (x1, y1, x2, y2) boxes to scan.
        """
        img_width, img_height = image.size
        regions_to_scan = []

        # Strategy 1: salient regions are the most likely to contain a brand.
        for region in (saliency_regions or [])[:max_salient_regions]:
            bbox = region.get('bbox')
            if not bbox:
                continue
            x1, y1, x2, y2 = bbox
            # Expand to include surrounding context, clamped to the image.
            x1 = max(0, x1 - padding)
            y1 = max(0, y1 - padding)
            x2 = min(img_width, x2 + padding)
            y2 = min(img_height, y2 + padding)
            if (x2 - x1) > min_region_size and (y2 - y1) > min_region_size:
                regions_to_scan.append((x1, y1, x2, y2))
        found_salient = bool(regions_to_scan)

        # Strategy 2: brands are often near the center of the frame.
        center_x = img_width // 2
        center_y = img_height // 2
        half_side = (min(img_width, img_height) // 2) // 2
        center_bbox = (
            max(0, center_x - half_side),
            max(0, center_y - half_side),
            min(img_width, center_x + half_side),
            min(img_height, center_y + half_side),
        )
        # On very small images the square collapses to zero area; skip it.
        if center_bbox[2] > center_bbox[0] and center_bbox[3] > center_bbox[1]:
            regions_to_scan.append(center_bbox)

        # Strategy 3: without any salient region, also scan the whole image.
        # This must test the saliency outcome rather than emptiness of the
        # list; the center box would otherwise always make it dead code.
        if not found_salient:
            regions_to_scan.append((0, 0, img_width, img_height))

        return regions_to_scan

    def compute_brand_confidence_boost(self, brand_name: str,
                                       ocr_results: List[Dict],
                                       base_confidence: float) -> float:
        """
        Raise a brand's confidence when OCR text matches one of its aliases.

        An exact alias match adds ``EXACT_MATCH_BOOST * ocr_confidence``. A
        substring match (either direction) adds
        ``PARTIAL_MATCH_BOOST * ocr_confidence``, but only when the contained
        string is longer than 2 characters. This stops empty or 1–2 character
        OCR fragments from boosting every brand whose alias contains them.
        Only the single largest boost is applied.

        Args:
            brand_name: Brand to look up in the prompt library.
            ocr_results: OCR detections with ``'text'`` and ``'confidence'``.
            base_confidence: Confidence from visual matching.

        Returns:
            The boosted confidence, capped at ``MAX_BOOSTED_CONFIDENCE``. It is
            never lower than ``base_confidence``; a base already above the cap
            is returned unchanged, matching the unknown-brand path.
        """
        brand_info = self.prompt_library.get_brand_prompts(brand_name)
        if not brand_info:
            return base_confidence

        aliases = [alias.upper() for alias in brand_info.get('aliases', [])]
        max_boost = 0.0
        for ocr_item in ocr_results:
            text = ocr_item['text'].strip().upper()
            if not text:
                continue
            ocr_conf = ocr_item['confidence']
            for alias in aliases:
                if alias == text:
                    max_boost = max(max_boost, self.EXACT_MATCH_BOOST * ocr_conf)
                elif (alias in text or text in alias) and min(len(alias), len(text)) > 2:
                    # The shorter (contained) side must be > 2 chars to
                    # avoid short-string false positives.
                    max_boost = max(max_boost, self.PARTIAL_MATCH_BOOST * ocr_conf)

        boosted_confidence = min(base_confidence + max_boost, self.MAX_BOOSTED_CONFIDENCE)
        # A boost must never lower the original confidence.
        return max(base_confidence, boosted_confidence)
# Module-level side effect: prints a confirmation banner every time this file
# is executed or imported, signalling that the class definition above ran.
print("✓ BrandDetectionOptimizer (performance and accuracy optimizer) defined")