File size: 1,553 Bytes
62a2f1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from collections import namedtuple
import numpy as np
import torch
from .detectors import build_detector
# kornia is an optional dependency: it is only required by CaDDN (it is used
# to convert 'images' batches in load_data_to_gpu).
try:
    import kornia
except Exception:
    # Best effort: keep this package importable when kornia is missing or
    # fails to import. Catching Exception rather than using a bare
    # ``except:`` still lets KeyboardInterrupt and SystemExit propagate.
    pass
def build_network(model_cfg, num_class, dataset):
    """Build the detector network described by ``model_cfg``.

    Forwards all three arguments, by keyword, to ``build_detector`` and
    returns whatever it constructs.

    Args:
        model_cfg: Model configuration passed through to ``build_detector``.
        num_class: Number of classes passed through to ``build_detector``.
        dataset: Dataset object passed through to ``build_detector``.

    Returns:
        The object returned by ``build_detector``.
    """
    return build_detector(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
def load_data_to_gpu(batch_dict):
    """Move the numeric entries of ``batch_dict`` onto the GPU, in place.

    Per-key rules, checked in this order:

    * ``'camera_imgs'``: ``.cuda()`` is called on the value as-is (presumably
      it is already a torch tensor; this is not checked here).
    * Values that are not ``np.ndarray`` are left untouched.
    * Bookkeeping keys (``frame_id``, ``metadata``, ``calib``,
      ``image_paths``, ``ori_shape``, ``img_process_infos``) are left
      untouched even when they hold arrays.
    * ``'images'``: converted with ``kornia.image_to_tensor``, cast to
      float32, moved to the GPU and made contiguous.
    * ``'image_shape'``: converted to an int32 tensor on the GPU.
    * Any other ndarray: converted to a float32 tensor on the GPU.

    Args:
        batch_dict: Mapping from field name to batch value. Values are
            replaced in place; no keys are added or removed.

    Raises:
        ImportError: if an ``'images'`` ndarray is present but the optional
            ``kornia`` dependency cannot be imported.
    """
    skip_keys = {'frame_id', 'metadata', 'calib', 'image_paths', 'ori_shape', 'img_process_infos'}
    # Only values of existing keys are reassigned, so iterating items() while
    # writing back into the dict is safe (its size never changes).
    for key, val in batch_dict.items():
        if key == 'camera_imgs':
            batch_dict[key] = val.cuda()
        elif not isinstance(val, np.ndarray):
            continue
        elif key in skip_keys:
            continue
        elif key == 'images':
            # Import here instead of relying on the module-level optional
            # import: when kornia is missing that name is silently undefined
            # and would fail with an opaque NameError. This raises a clear
            # ImportError instead; once installed, the import is a cached lookup.
            import kornia
            batch_dict[key] = kornia.image_to_tensor(val).float().cuda().contiguous()
        elif key == 'image_shape':
            batch_dict[key] = torch.from_numpy(val).int().cuda()
        else:
            batch_dict[key] = torch.from_numpy(val).float().cuda()
def model_fn_decorator():
    """Create the per-iteration forward function used for training.

    Returns:
        A callable ``model_func(model, batch_dict)``. It does the following:

        * moves ``batch_dict`` to the GPU in place via ``load_data_to_gpu``;
        * runs ``model(batch_dict)``, which must return a 3-tuple
          ``(ret_dict, tb_dict, disp_dict)``;
        * averages ``ret_dict['loss']``;
        * advances the global step on ``model``, or on ``model.module`` when
          ``model`` itself lacks ``update_global_step``;
        * returns ``ModelReturn(loss, tb_dict, disp_dict)``.
    """
    ModelReturn = namedtuple('ModelReturn', ['loss', 'tb_dict', 'disp_dict'])

    def model_func(model, batch_dict):
        load_data_to_gpu(batch_dict)
        ret_dict, tb_dict, disp_dict = model(batch_dict)
        mean_loss = ret_dict['loss'].mean()

        # Presumably ``model`` may be a DataParallel/DistributedDataParallel
        # style wrapper whose underlying network lives at ``.module``.
        step_owner = model if hasattr(model, 'update_global_step') else model.module
        step_owner.update_global_step()

        return ModelReturn(mean_loss, tb_dict, disp_dict)

    return model_func
|