File size: 3,516 Bytes
6a3bd1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import torch
import numpy as np
from PIL import Image
import cv2
from typing import List, Dict
import torchvision.transforms as transforms
class SaliencyDetectionManager:
    """Find salient image regions and keep the ones a YOLO detector missed.

    Region proposals come from Otsu thresholding of the grayscale image
    (see ``detect_salient_regions``). ``__init__`` additionally tries to load
    torchvision's DeepLabV3-ResNet50 segmentation model.

    NOTE(review): the "U2-Net" wording in the startup log message does not
    match the model actually loaded, and neither ``self.model`` nor
    ``self.transform``, ``self.threshold`` or ``self.min_saliency`` is read by
    any method of this class. Confirm whether the model load is still needed.
    """

    def __init__(self):
        """Try to load the segmentation model and set detection parameters.

        Model loading is best-effort: on any failure a warning is printed and
        ``self.model`` is set to ``None``; region detection does not use it.
        """
        print("Loading U2-Net model...")
        try:
            from torchvision.models.segmentation import deeplabv3_resnet50

            try:
                # torchvision >= 0.13 selects weights via ``weights``; the legacy
                # ``pretrained`` flag is deprecated there. "DEFAULT" resolves to
                # the same weights that ``pretrained=True`` used to load.
                self.model = deeplabv3_resnet50(weights="DEFAULT")
            except TypeError:
                # Older torchvision has no ``weights`` keyword.
                self.model = deeplabv3_resnet50(pretrained=True)
            self.model.eval()
            if torch.cuda.is_available():
                self.model = self.model.cuda()
        except Exception as e:
            # Deliberate best-effort fallback (e.g. weights cannot be downloaded).
            print(f"Warning: Cannot load deep learning model, using fallback: {e}")
            self.model = None
        # Not read by any method in this class; presumably kept for external callers.
        self.threshold = 0.5
        # Contours whose area (pixels) is below this are discarded by detect_salient_regions.
        self.min_area = 1600
        # Not read by any method in this class; presumably kept for external callers.
        self.min_saliency = 0.6
        # 320x320 resize + the usual ImageNet mean/std normalisation; unused inside this class.
        self.transform = transforms.Compose([
            transforms.Resize((320, 320)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        print("✓ SaliencyDetectionManager initialized")

    def detect_salient_regions(self, image: "Image.Image", max_regions: int = 10) -> List[Dict]:
        """Propose salient regions via Otsu thresholding of the grayscale image.

        Args:
            image: PIL image in any mode; it is converted to RGB before analysis.
            max_regions: maximum number of regions returned (default 10).

        Returns:
            Up to ``max_regions`` dicts sorted by descending ``saliency_score``:
            ``bbox`` ([x_min, y_min, x_max, y_max] as floats, pixels),
            ``area`` (contour area, pixels), ``saliency_score`` (contour area as a
            fraction of the image area, capped at 1.0) and ``image`` (the region
            cropped from the original ``image``).
        """
        # "L"/"P"/"I"/"F" images become 2-D arrays (and "P" would yield palette
        # indices), which COLOR_RGB2GRAY cannot handle; normalise to RGB first.
        img_array = np.array(image.convert("RGB"))
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        # Otsu chooses the threshold; pixels brighter than it become foreground (255).
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # OpenCV 3.x returns (image, contours, hierarchy), 4.x returns
        # (contours, hierarchy); index -2 is the contour list in both.
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

        height, width = img_array.shape[:2]
        image_area = width * height
        regions = []
        for contour in contours:
            area = cv2.contourArea(contour)
            if area < self.min_area:
                continue
            x, y, w, h = cv2.boundingRect(contour)
            bbox = [float(x), float(y), float(x + w), float(y + h)]
            regions.append({
                'bbox': bbox,
                'area': area,
                'saliency_score': min(area / image_area, 1.0),
                'image': image.crop(bbox),
            })

        regions.sort(key=lambda region: region['saliency_score'], reverse=True)
        return regions[:max_regions]

    def extract_unknown_regions(self, salient_regions: List[Dict], yolo_detections: List[Dict],
                                iou_threshold: float = 0.3) -> List[Dict]:
        """Return the salient regions that no YOLO detection substantially overlaps.

        A region is dropped as soon as one detection reaches an IoU of at least
        ``iou_threshold`` (default 0.3) with it; with no detections every region
        is kept. Each dict in both lists must have a ``'bbox'`` key in
        [x_min, y_min, x_max, y_max] form. Input order is preserved.
        """
        return [
            region
            for region in salient_regions
            if not any(
                self._calculate_iou(region['bbox'], det['bbox']) >= iou_threshold
                for det in yolo_detections
            )
        ]

    def _calculate_iou(self, box1: List[float], box2: List[float]) -> float:
        """Return the Intersection over Union of two [x_min, y_min, x_max, y_max] boxes.

        Returns 0.0 for disjoint boxes and when the union area is not positive
        (e.g. two zero-area boxes).
        """
        x1_min, y1_min, x1_max, y1_max = box1
        x2_min, y2_min, x2_max, y2_max = box2
        inter_xmin = max(x1_min, x2_min)
        inter_ymin = max(y1_min, y2_min)
        inter_xmax = min(x1_max, x2_max)
        inter_ymax = min(y1_max, y2_max)
        if inter_xmax < inter_xmin or inter_ymax < inter_ymin:
            return 0.0
        inter_area = (inter_xmax - inter_xmin) * (inter_ymax - inter_ymin)
        box1_area = (x1_max - x1_min) * (y1_max - y1_min)
        box2_area = (x2_max - x2_min) * (y2_max - y2_min)
        union_area = box1_area + box2_area - inter_area
        return inter_area / union_area if union_area > 0 else 0.0
# Import-time notice emitted once the class definition above has executed.
print("✓ SaliencyDetectionManager defined", end="\n")
|