File size: 7,512 Bytes
62a2f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveConfidenceAggregation(nn.Module):
    """
    Adaptive Confidence Aggregation (ACA) module for enhancing bounding box prediction confidence
    based on geometric properties of point clouds.

    Simplified version to avoid matrix multiplication errors: the module owns no
    learnable parameters and scores every row with fixed, hand-written formulas.
    """

    # Column layout expected by forward(): [density, curvature, normal_x, normal_y, normal_z]
    _NUM_FEATURES = 5

    def __init__(self, model_cfg):
        """
        Args:
            model_cfg: mapping-like config exposing ``.get(key, default)``. The keys read are
                ``USE_DENSITY`` (default True), ``USE_CURVATURE`` (default False) and
                ``USE_NORMALS`` (default False).
        """
        super().__init__()
        self.model_cfg = model_cfg
        self.use_density = model_cfg.get('USE_DENSITY', True)
        self.use_curvature = model_cfg.get('USE_CURVATURE', False)  # Disabled by default
        self.use_normals = model_cfg.get('USE_NORMALS', False)  # Disabled by default

        # Fixed weights for geometric properties (no learning to avoid matrix multiplication errors).
        # NOTE: forward() does not read these attributes; they are kept so existing code that
        # accesses them keeps working.
        self.density_weight = 1.0
        self.curvature_weight = 0.5
        self.normals_weight = 0.3

    def _resolve_device(self, *candidates):
        """Pick the device to compute on without ever raising.

        Preference order: the module's own parameters/buffers (if a subclass adds any),
        then the first tensor among ``candidates``, then CPU. This module registers no
        parameters itself, so ``next(self.parameters())`` without a default would raise
        ``StopIteration``.
        """
        owned = next(self.parameters(), None)
        if owned is None:
            owned = next(self.buffers(), None)
        if owned is not None:
            return owned.device
        for candidate in candidates:
            if isinstance(candidate, torch.Tensor):
                return candidate.device
        return torch.device('cpu')

    def _prepare_features(self, geometric_features, device):
        """Return ``geometric_features`` as a finite floating-point (N, 5) tensor on ``device``."""
        if geometric_features is None:
            raise ValueError("geometric_features is None")

        # Convert to tensor if it's not already
        if not isinstance(geometric_features, torch.Tensor):
            geometric_features = torch.as_tensor(geometric_features, dtype=torch.float32)

        features = geometric_features.to(device)
        if not torch.is_floating_point(features):
            # Integer/bool inputs would otherwise mix dtypes with the float padding below.
            features = features.float()

        # A 1D tensor is treated as the feature vector of a single item
        if features.dim() == 1:
            features = features.unsqueeze(0)
        if features.dim() != 2:
            raise ValueError(
                f"geometric_features must be 1D or 2D, got shape {tuple(features.shape)}"
            )

        num_cols = features.shape[1]
        if num_cols < self._NUM_FEATURES:
            # Pad missing feature columns with zeros (same dtype/device as the input)
            padding = features.new_zeros(features.shape[0], self._NUM_FEATURES - num_cols)
            features = torch.cat([features, padding], dim=1)
        elif num_cols > self._NUM_FEATURES:
            # Keep only the first 5 columns
            features = features[:, :self._NUM_FEATURES]

        # Handle NaN or Inf values
        return torch.nan_to_num(features, nan=0.0, posinf=1.0, neginf=0.0)

    def _geometric_confidence(self, features):
        """Compute the per-row confidence multiplier from the enabled geometric cues."""
        confidence_scores = features.new_ones(features.shape[0])

        if self.use_density:
            density = features[:, 0]
            confidence_scores = confidence_scores * (0.5 + 0.5 * density)

        if self.use_curvature:
            curvature = features[:, 1]
            confidence_scores = confidence_scores * (0.8 + 0.2 * (1.0 - curvature))

        if self.use_normals:
            # Use only the z-component of the normal for simplicity
            normal_z = features[:, 4]
            confidence_scores = confidence_scores * (0.9 + 0.1 * torch.abs(normal_z))

        return confidence_scores

    @staticmethod
    def _blend_with_base_scores(confidence_scores, base_scores):
        """Blend geometric confidence with ``base_scores`` as ``0.3 * conf + 0.7 * base``."""
        num_items = confidence_scores.shape[0]

        base = torch.as_tensor(base_scores, device=confidence_scores.device)
        if not torch.is_floating_point(base):
            base = base.to(confidence_scores.dtype)

        if base.dim() == 0:
            # A scalar score applies to every item
            base = base.unsqueeze(0).expand(num_items)
        else:
            # Flatten instead of squeeze(): squeeze() would turn a (1, 1) tensor into a
            # 0-dim tensor and break the length handling below.
            base = base.reshape(-1)

        # Ensure base has the same length as confidence_scores
        if base.shape[0] > num_items:
            base = base[:num_items]
        elif base.shape[0] < num_items:
            # Pad with ones
            base = torch.cat([base, base.new_ones(num_items - base.shape[0])])

        # Handle NaN or Inf values
        base = torch.nan_to_num(base, nan=1.0, posinf=1.0, neginf=0.0)

        # Combine scores - use a weighted average instead of multiplication
        return 0.3 * confidence_scores + 0.7 * base

    def forward(self, geometric_features, base_scores=None):
        """
        Args:
            geometric_features: (N, 5) tensor with [density, curvature, normal_x, normal_y, normal_z].
                A 1D input is treated as a single row; fewer than 5 columns are zero-padded and
                extra columns are dropped. Non-tensor inputs are converted to a float tensor.
            base_scores: Optional (N,) tensor with base confidence scores to refine. Scalars are
                broadcast, other shapes are flattened, and the result is truncated or padded
                with ones to length N.

        Returns:
            confidence_scores: (N,) tensor with refined confidence scores clamped to [0, 1].
            If processing fails, a warning is printed and the fallback is returned:
            ``base_scores`` unchanged when it is a tensor, otherwise a tensor of ones with
            shape (1,).
        """
        device = self._resolve_device(geometric_features, base_scores)

        # Deliberately best-effort: any failure degrades to the fallback instead of
        # aborting the surrounding pipeline.
        try:
            features = self._prepare_features(geometric_features, device)
            confidence_scores = self._geometric_confidence(features)

            # If base scores are provided, combine them with our confidence scores
            if base_scores is not None:
                try:
                    confidence_scores = self._blend_with_base_scores(confidence_scores, base_scores)
                except Exception as e:
                    print(f"Warning: Error processing base_scores: {e}. Using computed confidence scores only.")

            # Final check for NaN or Inf values
            confidence_scores = torch.nan_to_num(confidence_scores, nan=1.0, posinf=1.0, neginf=0.0)

            # Ensure confidence scores are in [0, 1]
            return torch.clamp(confidence_scores, 0.0, 1.0)

        except Exception as e:
            print(f"Warning: Error in AdaptiveConfidenceAggregation: {e}. Using fallback.")
            # Fallback: return base scores or ones
            if isinstance(base_scores, torch.Tensor):
                return base_scores
            return torch.ones(1, device=device)

    @staticmethod
    def apply_confidence_to_boxes(boxes, confidence_scores, score_thresh=0.1):
        """
        Apply confidence scores to boxes and filter by threshold

        Args:
            boxes: (N, 7+C) [x, y, z, dx, dy, dz, heading, ...]; column index 7 is read and
                rescaled as the box score, so at least 8 columns are required.
            confidence_scores: (N,) confidence scores
            score_thresh: Threshold for filtering boxes

        Returns:
            filtered_boxes: Copy of the boxes whose rescaled score is >= score_thresh.
            An empty ``boxes`` input is returned as-is; the input is never modified.
        """
        if boxes.shape[0] == 0:
            return boxes

        # Apply confidence scores to box scores (assuming score is at index 7)
        boxes_with_conf = boxes.clone()
        boxes_with_conf[:, 7] = boxes_with_conf[:, 7] * confidence_scores

        # Filter boxes by score threshold
        mask = boxes_with_conf[:, 7] >= score_thresh
        filtered_boxes = boxes_with_conf[mask]

        return filtered_boxes