File size: 7,188 Bytes
62a2f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from .detector3d_template import Detector3DTemplate
from ..model_utils.aca_utils import AdaptiveConfidenceAggregation
import torch


class SARA3D(Detector3DTemplate):
    """3D detector that optionally rescores its boxes with an ACA module.

    The network modules come from ``build_networks()`` on the base class.
    When ``USE_ACA`` is enabled in the model config (it defaults to True),
    ``post_processing`` passes per-box features and the predicted scores
    through ``AdaptiveConfidenceAggregation`` and stores the result as the
    new ``pred_scores``.
    """

    def __init__(self, model_cfg, num_class, dataset):
        """Build the network modules and, if enabled, the ACA module.

        Args:
            model_cfg: Model configuration. ``USE_ACA`` (default True) and
                ``ACA_CONFIG`` (default ``{}``) are read here;
                ``POST_PROCESSING`` is read in ``post_processing``.
            num_class: Forwarded to the base class.
            dataset: Forwarded to the base class.
        """
        super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
        self.module_list = self.build_networks()

        # Initialize Adaptive Confidence Aggregation module if enabled
        self.use_aca = self.model_cfg.get('USE_ACA', True)
        if self.use_aca:
            self.aca_module = AdaptiveConfidenceAggregation(
                model_cfg=self.model_cfg.get('ACA_CONFIG', {})
            )

    def forward(self, batch_dict):
        """Run every module in order, then return losses or predictions.

        Returns:
            In training mode ``({'loss': loss}, tb_dict, disp_dict)``;
            otherwise ``(pred_dicts, recall_dicts)`` from ``post_processing``.
        """
        # Process through network modules
        for cur_module in self.module_list:
            batch_dict = cur_module(batch_dict)

        if self.training:
            loss, tb_dict, disp_dict = self.get_training_loss()

            ret_dict = {
                'loss': loss
            }
            return ret_dict, tb_dict, disp_dict

        pred_dicts, recall_dicts = self.post_processing(batch_dict)
        return pred_dicts, recall_dicts

    def get_training_loss(self):
        """Return ``(loss, tb_dict, disp_dict)`` using the dense head's loss.

        ``tb_dict`` is the dense head's dict with an added ``'loss_rpn'``
        scalar (a key of that name already present in the head's dict takes
        precedence); ``disp_dict`` is always empty.
        """
        disp_dict = {}

        loss_rpn, tb_dict = self.dense_head.get_loss()
        tb_dict = {
            'loss_rpn': loss_rpn.item(),
            **tb_dict
        }

        loss = loss_rpn
        return loss, tb_dict, disp_dict

    def post_processing(self, batch_dict):
        """Optionally rescore predictions with ACA and collect recall records.

        Args:
            batch_dict: Must provide ``'batch_size'`` and ``'final_box_dicts'``.
                ``'geometric_features'`` is optional; when it is missing,
                None, or cannot be converted to a tensor, features derived
                from the predicted boxes are used instead.

        Returns:
            ``(final_pred_dict, recall_dict)``. The per-sample prediction
            dicts are updated in place when ACA is enabled.
        """
        post_process_cfg = self.model_cfg.POST_PROCESSING
        batch_size = batch_dict['batch_size']
        final_pred_dict = batch_dict['final_box_dicts']
        recall_dict = {}

        # Apply Adaptive Confidence Aggregation if enabled
        if self.use_aca:
            geometric_features = self._load_geometric_features(batch_dict)
            for index in range(batch_size):
                sample_pred = self._get_sample_prediction(final_pred_dict, index)
                if sample_pred is not None:
                    self._rescore_with_aca(sample_pred, geometric_features)

        # Generate recall statistics
        for index in range(batch_size):
            sample_pred = self._get_sample_prediction(final_pred_dict, index)
            if sample_pred is None:
                continue
            recall_dict = self.generate_recall_record(
                box_preds=sample_pred['pred_boxes'],
                recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
                thresh_list=post_process_cfg.RECALL_THRESH_LIST
            )

        return final_pred_dict, recall_dict

    @staticmethod
    def _get_sample_prediction(final_pred_dict, index):
        """Return the prediction dict for sample ``index``, or None if absent.

        Handles both an int-keyed mapping and a list/tuple of per-sample
        dicts. A plain ``index in final_pred_dict`` test on a list compares
        the integer with each element and is never true for dict elements,
        which would silently skip every sample.
        """
        if isinstance(final_pred_dict, (list, tuple)):
            return final_pred_dict[index] if index < len(final_pred_dict) else None
        return final_pred_dict[index] if index in final_pred_dict else None

    def _load_geometric_features(self, batch_dict):
        """Return ``batch_dict['geometric_features']`` as a tensor, or None.

        None means "use the box-derived fallback features": it is returned
        when the entry is missing, is None, or cannot be converted.
        """
        geometric_features = batch_dict.get('geometric_features')
        if geometric_features is None or isinstance(geometric_features, torch.Tensor):
            return geometric_features

        # Convert to torch tensor if it's numpy array
        try:
            device = next(self.parameters()).device
            return torch.from_numpy(geometric_features).to(device)
        except (TypeError, ValueError, RuntimeError, StopIteration) as e:
            print(f"Warning: Error processing geometric_features: {e}")
            # Returning None selects the fallback features downstream instead
            # of failing later on an attribute access on None.
            return None

    def _box_fallback_features(self, pred_boxes):
        """Build an (N, 5) feature tensor from the predicted boxes alone.

        Column 0 holds each box's volume (product of columns 3:6) divided by
        the largest volume, column 2 is set to 1.0, all other columns are 0.
        """
        box_sizes = pred_boxes[:, 3:6]  # width, length, height
        box_volumes = box_sizes[:, 0] * box_sizes[:, 1] * box_sizes[:, 2]

        # Normalize volumes; the epsilon avoids division by zero
        normalized_volumes = box_volumes / (box_volumes.max() + 1e-6)

        try:
            device = pred_boxes.device
        except AttributeError:
            device = next(self.parameters()).device

        # Layout: [density, curvature (set to 0), normal (set to [1,0,0])]
        simple_geometric_features = torch.zeros((pred_boxes.shape[0], 5), device=device)
        simple_geometric_features[:, 0] = normalized_volumes  # Use volume as density
        simple_geometric_features[:, 2] = 1.0  # Set x-normal to 1
        return simple_geometric_features

    def _rescore_with_aca(self, sample_pred, geometric_features):
        """Replace one sample's ``pred_scores`` with ACA confidence scores.

        Args:
            sample_pred: Dict with ``'pred_boxes'`` and ``'pred_scores'``;
                updated in place.
            geometric_features: Feature tensor, or None to derive features
                from the boxes themselves.
        """
        pred_boxes = sample_pred['pred_boxes']
        pred_scores = sample_pred['pred_scores']
        num_boxes = pred_boxes.shape[0]
        if num_boxes == 0:
            return

        if geometric_features is None:
            # Fallback when proper geometric features are not available
            confidence_scores = self.aca_module(
                self._box_fallback_features(pred_boxes), pred_scores
            )
        else:
            # NOTE(review): this is a simplified pairing - the leading feature
            # rows are used for every sample; a real mapping from boxes to
            # their voxels/points is still needed.
            num_features = min(num_boxes, geometric_features.shape[0])
            if num_features == 0:
                return

            confidence_scores = self.aca_module(
                geometric_features[:num_features], pred_scores[:num_features]
            )

            if num_features < num_boxes:
                # If we have fewer features than boxes, pad with ones.
                # NOTE(review): this gives unscored boxes a confidence of 1.0;
                # confirm that keeping their original score is not preferable.
                padded_scores = torch.ones_like(pred_scores)
                padded_scores[:num_features] = confidence_scores
                confidence_scores = padded_scores

        # Update scores
        sample_pred['pred_scores'] = confidence_scores
        # Only mirror the score into box column 7 when that column exists;
        # writing it unconditionally raises IndexError for narrower boxes.
        # NOTE(review): confirm column 7 is meant to hold a score.
        if pred_boxes.ndim == 2 and pred_boxes.shape[1] > 7:
            pred_boxes[:, 7] = confidence_scores