import torch

from PIL import Image
from typing import Dict, List, Tuple

import numpy as np


class BrandDetectionOptimizer:
    """
    Intelligent brand detection optimizer, balancing performance and accuracy.
    Uses fast pre-screening to cut down on unnecessary deep detection passes.
    """

    def __init__(self, clip_manager, ocr_manager, prompt_library):
        self.clip_manager = clip_manager
        self.ocr_manager = ocr_manager
        self.prompt_library = prompt_library
    def quick_brand_prescreening(self, image: Image.Image) -> List[str]:
        """
        Fast brand pre-screening - detects only the most likely brand categories,
        greatly reducing the number of brands that need deep detection.

        Returns:
            List of brand names that are likely present
        """
        likely_brands = set()

        # Stage 1: match OCR text against known brand aliases
        ocr_results = self.ocr_manager.extract_text(image, use_brand_preprocessing=True)

        for ocr_item in ocr_results:
            text = ocr_item['text'].upper()

            # Ignore fragments too short to identify a brand reliably
            if len(text) < 2:
                continue

            for brand_name, brand_info in self.prompt_library.get_all_brands().items():
                aliases = [alias.upper() for alias in brand_info.get('aliases', [])]

                for alias in aliases:
                    # Exact alias match is the strongest OCR signal
                    if alias == text:
                        likely_brands.add(brand_name)
                        break
                    # Partial match: the shorter string must cover >60% of the longer one
                    elif len(alias) >= 3:
                        if alias in text and len(alias) / len(text) > 0.6:
                            likely_brands.add(brand_name)
                            break
                        elif text in alias and len(text) / len(alias) > 0.6:
                            likely_brands.add(brand_name)
                            break

        # Stage 2: zero-shot category classification with CLIP
        category_prompts = {
            'luxury': 'luxury brand product with monogram pattern and leather details',
            'sportswear': 'sportswear brand product with athletic logo and swoosh design',
            'tech': 'technology brand product with minimalist design and metal finish',
            'automotive': 'luxury car brand with distinctive grille and emblem',
            'watches': 'luxury watch with distinctive dial and brand logo',
            'fashion': 'fashion brand product with signature pattern or logo'
        }

        category_scores = self.clip_manager.classify_zero_shot(
            image, list(category_prompts.values())
        )

        # Keep only the two highest-scoring categories
        sorted_categories = sorted(
            category_scores.items(), key=lambda x: x[1], reverse=True
        )[:2]

        # Map each prompt text back to its category key
        category_mapping = {v: k for k, v in category_prompts.items()}

        for prompt_text, score in sorted_categories:
            if score > 0.30:
                category = category_mapping[prompt_text]
                category_brands = self.prompt_library.get_brands_by_category(category)
                likely_brands.update(category_brands.keys())

        # Fallback: if nothing was identified, scan a small set of common brands
        if not likely_brands:
            default_brands = ['Louis Vuitton', 'Gucci', 'Nike']
            likely_brands.update(default_brands)

        return list(likely_brands)
    def smart_region_selection(self, image: Image.Image,
                               saliency_regions: List[Dict]) -> List[Tuple[int, int, int, int]]:
        """
        Smart region selection - scans only regions that may contain a brand,
        replacing an inefficient grid scan.

        Args:
            image: PIL Image
            saliency_regions: Saliency detection results

        Returns:
            List of bboxes (x1, y1, x2, y2) to scan
        """
        regions_to_scan = []
        img_width, img_height = image.size

        # Use the top saliency regions, padded and clamped to the image bounds
        if saliency_regions:
            for region in saliency_regions[:3]:
                bbox = region.get('bbox')
                if bbox:
                    x1, y1, x2, y2 = bbox
                    padding = 20
                    x1 = max(0, x1 - padding)
                    y1 = max(0, y1 - padding)
                    x2 = min(img_width, x2 + padding)
                    y2 = min(img_height, y2 + padding)

                    # Keep only regions large enough to contain a legible logo
                    if (x2 - x1) > 100 and (y2 - y1) > 100:
                        regions_to_scan.append((x1, y1, x2, y2))

        # Always include a central crop, where logos most frequently appear
        center_x = img_width // 2
        center_y = img_height // 2
        center_size = min(img_width, img_height) // 2

        center_bbox = (
            max(0, center_x - center_size // 2),
            max(0, center_y - center_size // 2),
            min(img_width, center_x + center_size // 2),
            min(img_height, center_y + center_size // 2)
        )
        regions_to_scan.append(center_bbox)

        # Defensive fallback: scan the full image if no region was selected
        if not regions_to_scan:
            regions_to_scan.append((0, 0, img_width, img_height))

        return regions_to_scan
    def compute_brand_confidence_boost(self, brand_name: str,
                                       ocr_results: List[Dict],
                                       base_confidence: float) -> float:
        """
        Boost brand confidence based on OCR results.
        If OCR detects the brand name, confidence is raised substantially.

        Args:
            brand_name: Brand name
            ocr_results: OCR detection results
            base_confidence: Base confidence from visual matching

        Returns:
            Boosted confidence score
        """
        brand_info = self.prompt_library.get_brand_prompts(brand_name)
        if not brand_info:
            return base_confidence

        aliases = [alias.upper() for alias in brand_info.get('aliases', [])]

        max_boost = 0.0
        for ocr_item in ocr_results:
            text = ocr_item['text'].upper()
            ocr_conf = ocr_item['confidence']

            for alias in aliases:
                # Exact alias match earns the largest boost, scaled by OCR confidence
                if alias == text:
                    max_boost = max(max_boost, 0.40 * ocr_conf)
                # Partial match earns a smaller boost
                elif alias in text or text in alias:
                    if len(alias) > 2:
                        max_boost = max(max_boost, 0.25 * ocr_conf)

        # Cap the boosted score so OCR evidence alone cannot reach certainty
        boosted_confidence = min(base_confidence + max_boost, 0.95)
        return boosted_confidence
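
# Worked example of the boost formula above (illustrative numbers, not taken
# from a real run): with base_confidence = 0.55 and an exact OCR alias match
# at ocr_conf = 0.90, the boost is 0.40 * 0.90 = 0.36 and the result is
# min(0.55 + 0.36, 0.95) = 0.91; a partial match at the same OCR confidence
# would instead contribute 0.25 * 0.90 = 0.225, giving 0.775.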
print("✓ BrandDetectionOptimizer (performance and accuracy optimizer) defined") |