|
|
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
|
|
|
|
class AdaptiveConfidenceAggregation(nn.Module):
    """
    Adaptive Confidence Aggregation (ACA) module for enhancing bounding box prediction confidence
    based on geometric properties of point clouds.

    Simplified version to avoid matrix multiplication errors: the module has no learnable
    parameters and combines the geometric cues with fixed multiplicative factors.

    Config keys read from ``model_cfg`` (via ``.get``):
        USE_DENSITY (default True): scale confidence by ``0.5 + 0.5 * density``.
        USE_CURVATURE (default False): scale confidence by ``0.8 + 0.2 * (1 - curvature)``.
        USE_NORMALS (default False): scale confidence by ``0.9 + 0.1 * |normal_z|``.
        BASE_SCORE_WEIGHT (default 0.7): weight given to ``base_scores`` when blending them
            with the geometric confidence in ``forward``.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        self.use_density = model_cfg.get('USE_DENSITY', True)
        self.use_curvature = model_cfg.get('USE_CURVATURE', False)
        self.use_normals = model_cfg.get('USE_NORMALS', False)

        # Weight of base_scores in the final blend; 0.7 reproduces the original 0.3/0.7 mix.
        self.base_score_weight = float(model_cfg.get('BASE_SCORE_WEIGHT', 0.7))

        # NOTE(review): these weights are not referenced by forward(), which uses fixed
        # factors instead. They are kept as attributes in case external code reads them.
        self.density_weight = 1.0
        self.curvature_weight = 0.5
        self.normals_weight = 0.3

    def forward(self, geometric_features, base_scores=None):
        """
        Compute per-box confidence from geometric features, optionally blended with base scores.

        Args:
            geometric_features: (N, 5) tensor (or array-like) with
                [density, curvature, normal_x, normal_y, normal_z]. Fewer columns are
                zero-padded, extra columns are dropped, and a 1-D input is treated as a
                single row. Non-finite values are replaced (nan -> 0, +inf -> 1, -inf -> 0).
            base_scores: Optional (N,) tensor (or array-like / scalar) with base confidence
                scores to refine. A scalar is broadcast to all boxes; a longer input is
                truncated and a shorter one is padded with 1.0.

        Returns:
            confidence_scores: (N,) tensor clamped to [0, 1]. With base_scores the result is
            ``(1 - BASE_SCORE_WEIGHT) * geometric + BASE_SCORE_WEIGHT * base``. If
            base_scores cannot be processed, the geometric confidence alone is returned.
            If the geometric computation itself fails, a warning is logged and a best-effort
            fallback is returned (see ``_fallback_scores``).
        """
        try:
            device = self._resolve_device(geometric_features)
            features = self._prepare_features(geometric_features, device)
            confidence_scores = self._geometric_confidence(features)

            if base_scores is not None:
                try:
                    base = self._prepare_base_scores(
                        base_scores,
                        confidence_scores.shape[0],
                        confidence_scores.dtype,
                        device,
                    )
                    weight = self.base_score_weight
                    confidence_scores = (1.0 - weight) * confidence_scores + weight * base
                except Exception as e:
                    # Best-effort: a malformed base_scores should not discard the
                    # geometric confidence that was already computed.
                    self._logger.warning(
                        "Error processing base_scores: %s. "
                        "Using computed confidence scores only.", e)

            confidence_scores = torch.nan_to_num(
                confidence_scores, nan=1.0, posinf=1.0, neginf=0.0)
            return torch.clamp(confidence_scores, 0.0, 1.0)

        except Exception as e:
            # Deliberate top-level boundary: this module must never break the detector.
            self._logger.warning(
                "Error in AdaptiveConfidenceAggregation: %s. Using fallback.", e)
            return self._fallback_scores(geometric_features, base_scores)

    def _resolve_device(self, *candidates):
        """Pick the compute device: module params/buffers, then the first tensor input, then CPU."""
        # The base module registers no parameters, so `next(self.parameters())` would raise
        # StopIteration; iterate defensively instead.
        for tensor in self.parameters():
            return tensor.device
        for tensor in self.buffers():
            return tensor.device
        for candidate in candidates:
            if isinstance(candidate, torch.Tensor):
                return candidate.device
        return torch.device('cpu')

    @staticmethod
    def _prepare_features(geometric_features, device):
        """Convert input to a finite floating (N, 5) tensor on ``device``."""
        if geometric_features is None:
            raise ValueError("geometric_features is None")

        features = torch.as_tensor(geometric_features, device=device)
        if not torch.is_floating_point(features):
            # Integer inputs would otherwise be mixed with float padding / factors.
            features = features.float()

        if features.dim() <= 1:
            # A single feature vector (or a scalar) is treated as one sample.
            features = features.reshape(1, features.numel())

        num_cols = features.shape[1]
        if num_cols < 5:
            padding = features.new_zeros(features.shape[0], 5 - num_cols)
            features = torch.cat([features, padding], dim=1)
        elif num_cols > 5:
            features = features[:, :5]

        return torch.nan_to_num(features, nan=0.0, posinf=1.0, neginf=0.0)

    def _geometric_confidence(self, features):
        """Multiply together the enabled geometric confidence factors for each row."""
        scores = features.new_ones(features.shape[0])
        if self.use_density:
            scores = scores * (0.5 + 0.5 * features[:, 0])
        if self.use_curvature:
            # Lower curvature (flatter surface) yields a higher factor.
            scores = scores * (0.8 + 0.2 * (1.0 - features[:, 1]))
        if self.use_normals:
            scores = scores * (0.9 + 0.1 * torch.abs(features[:, 4]))
        return scores

    @staticmethod
    def _prepare_base_scores(base_scores, num_boxes, dtype, device):
        """Return ``base_scores`` as a finite 1-D tensor of length ``num_boxes``."""
        scores = torch.as_tensor(base_scores, dtype=dtype, device=device)
        if scores.dim() > 1:
            scores = scores.squeeze()
        if scores.dim() == 0:
            # A scalar (or a squeezed (1, 1) tensor) applies to every box.
            scores = scores.reshape(1).expand(num_boxes)
        if scores.dim() != 1:
            raise ValueError(
                f"base_scores must be 1-D after squeezing, got shape {tuple(scores.shape)}")

        if scores.shape[0] > num_boxes:
            scores = scores[:num_boxes]
        elif scores.shape[0] < num_boxes:
            # Missing entries get a neutral score of 1.0.
            padding = scores.new_ones(num_boxes - scores.shape[0])
            scores = torch.cat([scores, padding])

        return torch.nan_to_num(scores, nan=1.0, posinf=1.0, neginf=0.0)

    def _fallback_scores(self, geometric_features, base_scores):
        """Best-effort scores after a failure: base_scores if usable, otherwise all ones."""
        if isinstance(base_scores, torch.Tensor):
            return base_scores

        device = self._resolve_device(geometric_features)
        if base_scores is not None:
            try:
                return torch.as_tensor(
                    base_scores, dtype=torch.float32, device=device).reshape(-1)
            except (TypeError, ValueError, RuntimeError):
                # Deliberately best-effort: fall through to neutral scores below.
                pass

        # Size the neutral scores to the number of boxes when it can be determined.
        num_boxes = 1
        if isinstance(geometric_features, torch.Tensor) and geometric_features.dim() >= 2:
            num_boxes = geometric_features.shape[0]
        return torch.ones(num_boxes, device=device)

    @staticmethod
    def apply_confidence_to_boxes(boxes, confidence_scores, score_thresh=0.1):
        """
        Scale the score column of ``boxes`` by ``confidence_scores`` and filter by threshold.

        Args:
            boxes: (N, 7+C) [x, y, z, dx, dy, dz, heading, ...]; column 7 is treated as the
                score and must exist (C >= 1).
            confidence_scores: (N,) confidence scores. An (N, 1) tensor or a single value
                is also accepted; it is moved to the boxes' device and dtype.
            score_thresh: Threshold for filtering boxes (kept if score >= threshold).

        Returns:
            filtered_boxes: copy of the boxes whose rescaled score is >= score_thresh.
            The input ``boxes`` tensor is not modified.

        Raises:
            IndexError: if ``boxes`` has no score column (fewer than 8 columns).
        """
        if boxes.shape[0] == 0:
            return boxes
        if boxes.dim() < 2 or boxes.shape[1] < 8:
            raise IndexError(
                "boxes must have shape (N, 8+) with the score in column 7, "
                f"got {tuple(boxes.shape)}")

        scores = torch.as_tensor(
            confidence_scores, dtype=boxes.dtype, device=boxes.device).reshape(-1)

        boxes_with_conf = boxes.clone()
        boxes_with_conf[:, 7] = boxes_with_conf[:, 7] * scores

        mask = boxes_with_conf[:, 7] >= score_thresh
        return boxes_with_conf[mask]