| from .detector3d_template import Detector3DTemplate | |
class CenterPoint(Detector3DTemplate):
    """CenterPoint detector built on top of ``Detector3DTemplate``.

    The network components are created by the inherited ``build_networks()``
    and executed in order by ``forward``. The training loss comes entirely
    from ``self.dense_head.get_loss()``.
    """

    def __init__(self, model_cfg, num_class, dataset):
        """Initialise the base template, then build the module pipeline.

        Args:
            model_cfg: Model configuration; ``post_processing`` later reads
                ``model_cfg.POST_PROCESSING.RECALL_THRESH_LIST`` from it.
            num_class: Number of object classes, forwarded to the base class.
            dataset: Dataset object, forwarded to the base class.
        """
        super().__init__(
            model_cfg=model_cfg,
            num_class=num_class,
            dataset=dataset,
        )
        self.module_list = self.build_networks()

    def forward(self, batch_dict):
        """Run every module on ``batch_dict``, then compute loss or predictions.

        Args:
            batch_dict: Dict threaded through each module of
                ``self.module_list``; every module returns the dict that is
                fed to the next one.

        Returns:
            In training mode (``self.training`` truthy — presumably the
            ``nn.Module`` flag): ``({'loss': loss}, tb_dict, disp_dict)``.
            Otherwise: ``(pred_dicts, recall_dicts)`` from ``post_processing``.
        """
        for module in self.module_list:
            batch_dict = module(batch_dict)

        if not self.training:
            pred_dicts, recall_dicts = self.post_processing(batch_dict)
            return pred_dicts, recall_dicts

        loss, tb_dict, disp_dict = self.get_training_loss()
        return {'loss': loss}, tb_dict, disp_dict

    def get_training_loss(self):
        """Collect the training loss from the dense head.

        Returns:
            ``(loss, tb_dict, disp_dict)``. ``loss`` is the dense-head loss
            unchanged. ``tb_dict`` is the head's logging dict with a scalar
            ``'loss_rpn'`` entry (from ``.item()``) placed first. The head's
            own keys win on collision. ``disp_dict`` is always empty.
        """
        loss_rpn, head_tb_dict = self.dense_head.get_loss()
        # .item() turns the loss into a plain Python number for logging.
        tb_dict = {'loss_rpn': loss_rpn.item(), **head_tb_dict}
        return loss_rpn, tb_dict, {}

    def post_processing(self, batch_dict):
        """Return the final boxes already in ``batch_dict`` plus recall stats.

        No box filtering happens here. The per-sample dicts in
        ``batch_dict['final_box_dicts']`` are returned as-is. Presumably an
        earlier module (the dense head) produced them — TODO confirm.

        Args:
            batch_dict: Must contain ``'batch_size'`` and
                ``'final_box_dicts'``, a sequence indexed by sample whose
                entries each hold a ``'pred_boxes'`` value.

        Returns:
            ``(final_box_dicts, recall_dict)``. ``recall_dict`` is accumulated
            across samples by the inherited ``generate_recall_record``.
        """
        cfg = self.model_cfg.POST_PROCESSING
        num_samples = batch_dict['batch_size']
        box_dicts = batch_dict['final_box_dicts']

        recall_record = {}
        for batch_idx in range(num_samples):
            # generate_recall_record returns the updated accumulator,
            # which is fed back in for the next sample.
            recall_record = self.generate_recall_record(
                box_preds=box_dicts[batch_idx]['pred_boxes'],
                recall_dict=recall_record,
                batch_index=batch_idx,
                data_dict=batch_dict,
                thresh_list=cfg.RECALL_THRESH_LIST,
            )
        return box_dicts, recall_record