File size: 7,188 Bytes
62a2f1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from .detector3d_template import Detector3DTemplate
from ..model_utils.aca_utils import AdaptiveConfidenceAggregation
import torch
class SARA3D(Detector3DTemplate):
    """3D detector that optionally rescales box scores with an ACA module.

    The network modules returned by ``build_networks()`` are run in order on
    the batch dictionary. In training mode the dense-head loss is returned;
    otherwise the boxes found under ``batch_dict['final_box_dicts']`` are
    post-processed, optionally rescored by ``AdaptiveConfidenceAggregation``,
    and recall statistics are collected.
    """

    def __init__(self, model_cfg, num_class, dataset):
        """Build the network modules and, if enabled, the ACA module.

        Args:
            model_cfg: Model configuration. ``USE_ACA`` (default ``True``)
                toggles the ACA module and ``ACA_CONFIG`` (default ``{}``)
                is forwarded to it.
            num_class: Forwarded unchanged to the base-class constructor.
            dataset: Forwarded unchanged to the base-class constructor.
        """
        super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
        self.module_list = self.build_networks()

        # Initialize Adaptive Confidence Aggregation module if enabled
        self.use_aca = self.model_cfg.get('USE_ACA', True)
        if self.use_aca:
            self.aca_module = AdaptiveConfidenceAggregation(
                model_cfg=self.model_cfg.get('ACA_CONFIG', {})
            )

    def forward(self, batch_dict):
        """Run every network module and return losses or predictions.

        Args:
            batch_dict: Batch dictionary; each module receives it and returns
                the (possibly updated) dictionary for the next module.

        Returns:
            In training mode ``({'loss': loss}, tb_dict, disp_dict)``;
            otherwise ``(pred_dicts, recall_dicts)`` from ``post_processing``.
        """
        # Process through network modules
        for cur_module in self.module_list:
            batch_dict = cur_module(batch_dict)

        if self.training:
            loss, tb_dict, disp_dict = self.get_training_loss()
            ret_dict = {
                'loss': loss
            }
            return ret_dict, tb_dict, disp_dict

        pred_dicts, recall_dicts = self.post_processing(batch_dict)
        return pred_dicts, recall_dicts

    def get_training_loss(self):
        """Return the dense-head loss with its logging dictionaries.

        Returns:
            ``(loss, tb_dict, disp_dict)`` where ``loss`` is the value
            returned by ``self.dense_head.get_loss()``, ``tb_dict`` is the
            head's dictionary extended with ``'loss_rpn'`` and ``disp_dict``
            is an empty dict.
        """
        disp_dict = {}

        loss_rpn, tb_dict = self.dense_head.get_loss()
        # Keys coming from the head take precedence over the default entry.
        tb_dict = {
            'loss_rpn': loss_rpn.item(),
            **tb_dict
        }

        loss = loss_rpn
        return loss, tb_dict, disp_dict

    def post_processing(self, batch_dict):
        """Rescore predicted boxes with ACA (if enabled) and collect recall.

        Args:
            batch_dict: Must provide ``'batch_size'`` and
                ``'final_box_dicts'``; each per-sample entry of the latter
                must provide ``'pred_boxes'`` and ``'pred_scores'``.
                ``'geometric_features'`` is optional.

        Returns:
            ``(final_pred_dict, recall_dict)``. Entries of ``final_pred_dict``
            are updated in place when ACA is applied.
        """
        post_process_cfg = self.model_cfg.POST_PROCESSING
        batch_size = batch_dict['batch_size']
        final_pred_dict = batch_dict['final_box_dicts']
        recall_dict = {}

        # Apply Adaptive Confidence Aggregation if enabled
        if self.use_aca:
            # ``None`` means "not provided or unusable": fall back to features
            # derived from the boxes themselves instead of crashing.
            geometric_features = self._load_geometric_features(batch_dict)

            for index in range(batch_size):
                if not self._has_prediction(final_pred_dict, index):
                    continue
                record = final_pred_dict[index]
                pred_boxes = record['pred_boxes']
                pred_scores = record['pred_scores']
                if pred_boxes.shape[0] == 0:
                    continue

                if geometric_features is None:
                    confidence_scores = self._rescore_from_boxes(pred_boxes, pred_scores)
                elif geometric_features.shape[0] > 0:
                    confidence_scores = self._rescore_with_features(
                        geometric_features, pred_boxes, pred_scores
                    )
                else:
                    # Features were supplied but are empty: leave scores as-is.
                    continue

                self._store_scores(record, confidence_scores)

        # Generate recall statistics
        for index in range(batch_size):
            if not self._has_prediction(final_pred_dict, index):
                continue
            pred_boxes = final_pred_dict[index]['pred_boxes']
            recall_dict = self.generate_recall_record(
                box_preds=pred_boxes,
                recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
                thresh_list=post_process_cfg.RECALL_THRESH_LIST
            )

        return final_pred_dict, recall_dict

    @staticmethod
    def _has_prediction(final_pred_dict, index):
        """Tell whether ``final_pred_dict[index]`` can be looked up.

        ``index in container`` is only a key test for mappings; on a list it
        checks value membership, which would silently skip every sample.
        """
        if isinstance(final_pred_dict, dict):
            return index in final_pred_dict
        return 0 <= index < len(final_pred_dict)

    def _load_geometric_features(self, batch_dict):
        """Return ``batch_dict['geometric_features']`` as a tensor, or ``None``.

        Non-tensor input is converted with ``torch.from_numpy`` and moved to
        the device of the model parameters. ``None`` is returned when the key
        is missing, holds ``None``, or the conversion fails.
        """
        geometric_features = batch_dict.get('geometric_features')
        if geometric_features is None:
            return None
        if isinstance(geometric_features, torch.Tensor):
            return geometric_features

        try:
            device = next(self.parameters()).device
            return torch.from_numpy(geometric_features).to(device)
        except (TypeError, ValueError, RuntimeError, StopIteration) as e:
            print(f"Warning: Error processing geometric_features: {e}")
            return None

    def _rescore_with_features(self, geometric_features, pred_boxes, pred_scores):
        """Rescore boxes using the leading rows of ``geometric_features``.

        NOTE(review): boxes are paired with feature rows purely by position,
        which is a placeholder; a real mapping from boxes to voxels/points
        is still needed.
        """
        num_boxes = pred_boxes.shape[0]
        num_features = min(num_boxes, geometric_features.shape[0])

        # Get confidence scores from ACA module
        confidence_scores = self.aca_module(
            geometric_features[:num_features], pred_scores[:num_features]
        )

        if num_features < num_boxes:
            # Boxes without a feature row keep their original score rather
            # than being promoted to a confidence of 1.0.
            merged_scores = pred_scores.clone()
            merged_scores[:num_features] = confidence_scores
            confidence_scores = merged_scores

        return confidence_scores

    def _rescore_from_boxes(self, pred_boxes, pred_scores):
        """Rescore boxes using simple features built from the box sizes.

        Used when no geometric features are available. The 5-column feature
        is filled as in the original implementation: column 0 holds the
        max-normalised product of box columns 3..5, column 2 is set to 1.0
        and the remaining columns stay zero.
        """
        box_sizes = pred_boxes[:, 3:6]
        box_volumes = box_sizes[:, 0] * box_sizes[:, 1] * box_sizes[:, 2]
        # Epsilon guards against division by zero for degenerate boxes.
        normalized_volumes = box_volumes / (box_volumes.max() + 1e-6)

        try:
            device = pred_boxes.device
        except AttributeError:
            device = next(self.parameters()).device

        simple_geometric_features = torch.zeros((pred_boxes.shape[0], 5), device=device)
        simple_geometric_features[:, 0] = normalized_volumes  # Use volume as density
        simple_geometric_features[:, 2] = 1.0  # Set x-normal to 1

        return self.aca_module(simple_geometric_features, pred_scores)

    @staticmethod
    def _store_scores(record, confidence_scores):
        """Write the new scores back into a per-sample prediction record."""
        record['pred_scores'] = confidence_scores

        pred_boxes = record['pred_boxes']
        # NOTE(review): the original code always wrote the score into box
        # column 7, which raises IndexError for 7-column boxes. The write is
        # kept only when that column exists; confirm column 7 is really meant
        # to hold a score and not another box attribute.
        if len(pred_boxes.shape) == 2 and pred_boxes.shape[1] > 7:
            pred_boxes[:, 7] = confidence_scores