# Uploaded by yxc97 via huggingface_hub (commit 62a2f1c).
import warnings

import torch

from .detector3d_template import Detector3DTemplate
from ..model_utils.aca_utils import AdaptiveConfidenceAggregation
class SARA3D(Detector3DTemplate):
    """3D detector whose inference scores can be refined by Adaptive Confidence Aggregation.

    Network modules are built by ``Detector3DTemplate.build_networks``. When
    ``model_cfg.USE_ACA`` is true (the default), ``post_processing`` passes each
    sample's predicted scores through an ``AdaptiveConfidenceAggregation`` module
    configured from ``model_cfg.ACA_CONFIG``.
    """

    def __init__(self, model_cfg, num_class, dataset):
        super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
        self.module_list = self.build_networks()

        # ACA re-scoring is on unless the config explicitly disables it.
        self.use_aca = self.model_cfg.get('USE_ACA', True)
        if self.use_aca:
            self.aca_module = AdaptiveConfidenceAggregation(
                model_cfg=self.model_cfg.get('ACA_CONFIG', {})
            )

    def forward(self, batch_dict):
        """Run every network module in order, then compute losses or predictions.

        Returns:
            In training mode, ``(ret_dict, tb_dict, disp_dict)`` where ``ret_dict``
            holds the total ``'loss'``. Otherwise ``(pred_dicts, recall_dicts)``
            from :meth:`post_processing`.
        """
        for cur_module in self.module_list:
            batch_dict = cur_module(batch_dict)

        if self.training:
            loss, tb_dict, disp_dict = self.get_training_loss()
            ret_dict = {
                'loss': loss
            }
            return ret_dict, tb_dict, disp_dict

        pred_dicts, recall_dicts = self.post_processing(batch_dict)
        return pred_dicts, recall_dicts

    def get_training_loss(self):
        """Return ``(loss, tb_dict, disp_dict)``; the loss is the dense-head loss alone."""
        disp_dict = {}
        loss_rpn, tb_dict = self.dense_head.get_loss()
        tb_dict = {
            'loss_rpn': loss_rpn.item(),
            **tb_dict
        }
        loss = loss_rpn
        return loss, tb_dict, disp_dict

    def post_processing(self, batch_dict):
        """Optionally re-score the head's final boxes with ACA, then gather recall stats.

        Args:
            batch_dict: must contain ``'batch_size'`` and ``'final_box_dicts'``, a
                per-sample container indexed by batch index whose entries hold
                ``'pred_boxes'`` and ``'pred_scores'``. It may also contain
                ``'geometric_features'`` (a torch tensor or a numpy array).

        Returns:
            ``(final_pred_dict, recall_dict)``. ``final_pred_dict`` is the same
            object as ``batch_dict['final_box_dicts']``. When ACA is enabled, its
            ``'pred_scores'`` entries are replaced in place.
        """
        post_process_cfg = self.model_cfg.POST_PROCESSING
        batch_size = batch_dict['batch_size']
        final_pred_dict = batch_dict['final_box_dicts']
        recall_dict = {}

        if self.use_aca:
            geometric_features = self._load_geometric_features(batch_dict)
            for index in range(batch_size):
                if self._has_sample(final_pred_dict, index):
                    self._apply_aca(final_pred_dict[index], geometric_features)

        for index in range(batch_size):
            if self._has_sample(final_pred_dict, index):
                recall_dict = self.generate_recall_record(
                    box_preds=final_pred_dict[index]['pred_boxes'],
                    recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
                    thresh_list=post_process_cfg.RECALL_THRESH_LIST
                )

        return final_pred_dict, recall_dict

    @staticmethod
    def _has_sample(final_pred_dict, index):
        """Return True if ``final_pred_dict`` holds an entry for batch ``index``.

        For a dict, membership tests keys. For a sequence (e.g. a list with one dict
        per sample), ``index in seq`` would compare the int against the element
        dicts and always be False, so test the index range instead.
        """
        if isinstance(final_pred_dict, dict):
            return index in final_pred_dict
        return 0 <= index < len(final_pred_dict)

    def _load_geometric_features(self, batch_dict):
        """Return ``batch_dict['geometric_features']`` as a tensor, or None if unusable.

        Non-tensor inputs are converted with ``torch.from_numpy`` and moved to the
        model's device. If conversion fails, a warning is emitted and None is
        returned so callers fall back to box-derived features.
        """
        geometric_features = batch_dict.get('geometric_features')
        if geometric_features is None or isinstance(geometric_features, torch.Tensor):
            return geometric_features
        try:
            device = next(self.parameters()).device
            return torch.from_numpy(geometric_features).to(device)
        except (TypeError, ValueError, RuntimeError, StopIteration) as e:
            warnings.warn(
                f"Error processing geometric_features: {e}; "
                "falling back to box-derived features.",
                RuntimeWarning,
            )
            return None

    def _apply_aca(self, sample_pred, geometric_features):
        """Replace ``sample_pred['pred_scores']`` with ACA-refined scores, in place.

        Uses ``geometric_features`` when it is a non-empty tensor. Otherwise
        (None or zero rows), features are synthesized from the boxes themselves.
        Samples with no boxes are left untouched.
        """
        pred_boxes = sample_pred['pred_boxes']
        pred_scores = sample_pred['pred_scores']
        num_boxes = pred_boxes.shape[0]
        if num_boxes == 0:
            return

        if geometric_features is not None and geometric_features.shape[0] > 0:
            # Simplification: there is no box -> voxel/point mapping, so the first
            # min(num_boxes, num_features) feature rows are paired with boxes by
            # position, and the same feature tensor is reused for every sample.
            num_features = min(num_boxes, geometric_features.shape[0])
            refined = self.aca_module(
                geometric_features[:num_features], pred_scores[:num_features]
            )
            if num_features < num_boxes:
                # Boxes without a feature row keep their original score instead
                # of being promoted to a confidence of 1.0.
                padded = pred_scores.clone()
                padded[:num_features] = refined
                refined = padded
        else:
            refined = self.aca_module(self._box_geometric_features(pred_boxes), pred_scores)

        # Scores live only in 'pred_scores'. Writing them into pred_boxes[:, 7]
        # would raise IndexError for 7-column boxes and clobber column 7 otherwise.
        sample_pred['pred_scores'] = refined

    @staticmethod
    def _box_geometric_features(pred_boxes):
        """Build fallback ``(N, 5)`` features from box sizes when none are provided.

        Column 0 ("density") is the box volume normalized by the largest volume in
        the sample, column 1 (curvature) is 0, and columns 2-4 (normal) are fixed
        to [1, 0, 0].
        """
        box_sizes = pred_boxes[:, 3:6]  # presumably the three box extents - confirm encoding
        box_volumes = box_sizes[:, 0] * box_sizes[:, 1] * box_sizes[:, 2]
        features = torch.zeros((pred_boxes.shape[0], 5), device=pred_boxes.device)
        features[:, 0] = box_volumes / (box_volumes.max() + 1e-6)
        features[:, 2] = 1.0
        return features