import logging
import os

import mxnet as mx
import mxnet.ndarray as nd

import nowcasting.config as cfg
from nowcasting.ops import *  # conv2d, conv2d_act, deconv2d, deconv2d_act, reset_regs, ...
from nowcasting.operators import *  # BaseStackRNN, ConvGRU, TrajGRU, downsample_module, upsample_module, ...
from nowcasting.operators.common import grid_generator
from nowcasting.operators.transformations import DFN
from nowcasting.prediction_base_factory import PredictionBaseFactory
from nowcasting.my_module import MyModule
# load_params and safe_eval are used below but were not imported here; this
# import path is an assumption based on the nowcasting package layout.
from nowcasting.utils import load_params, safe_eval


def get_encoder_forecaster_rnn_blocks(batch_size):
    """Build the stacked RNN blocks for the encoder, forecaster and GAN branches.

    Each branch gets one BaseStackRNN per entry in
    cfg.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.NUM_FILTER.
    """
    encoder_rnn_blocks = []
    forecaster_rnn_blocks = []
    gan_rnn_blocks = []
    CONFIG = cfg.MODEL.ENCODER_FORECASTER.RNN_BLOCKS
    for vec, block_prefix in [(encoder_rnn_blocks, "ebrnn"),
                              (forecaster_rnn_blocks, "fbrnn"),
                              (gan_rnn_blocks, "dbrnn")]:
        for i in range(len(CONFIG.NUM_FILTER)):
            name = "%s%d" % (block_prefix, i + 1)
            if CONFIG.LAYER_TYPE[i] == "ConvGRU":
                rnn_block = BaseStackRNN(base_rnn_class=ConvGRU,
                                         stack_num=CONFIG.STACK_NUM[i],
                                         name=name,
                                         residual_connection=CONFIG.RES_CONNECTION,
                                         num_filter=CONFIG.NUM_FILTER[i],
                                         b_h_w=(batch_size,
                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i],
                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i]),
                                         h2h_kernel=CONFIG.H2H_KERNEL[i],
                                         h2h_dilate=CONFIG.H2H_DILATE[i],
                                         i2h_kernel=CONFIG.I2H_KERNEL[i],
                                         i2h_pad=CONFIG.I2H_PAD[i],
                                         act_type=cfg.MODEL.RNN_ACT_TYPE)
            elif CONFIG.LAYER_TYPE[i] == "TrajGRU":
                rnn_block = BaseStackRNN(base_rnn_class=TrajGRU,
                                         stack_num=CONFIG.STACK_NUM[i],
                                         name=name,
                                         L=CONFIG.L[i],
                                         residual_connection=CONFIG.RES_CONNECTION,
                                         num_filter=CONFIG.NUM_FILTER[i],
                                         b_h_w=(batch_size,
                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i],
                                                cfg.MODEL.ENCODER_FORECASTER.FEATMAP_SIZE[i]),
                                         h2h_kernel=CONFIG.H2H_KERNEL[i],
                                         h2h_dilate=CONFIG.H2H_DILATE[i],
                                         i2h_kernel=CONFIG.I2H_KERNEL[i],
                                         i2h_pad=CONFIG.I2H_PAD[i],
                                         act_type=cfg.MODEL.RNN_ACT_TYPE)
            else:
                raise NotImplementedError(
                    "LAYER_TYPE=%s is not supported!" % CONFIG.LAYER_TYPE[i])
            vec.append(rnn_block)
    return encoder_rnn_blocks, forecaster_rnn_blocks, gan_rnn_blocks
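
# Illustrative sketch (not executed here): the three returned lists are
# parallel, one BaseStackRNN per configured block. `batch_size=4` is an
# arbitrary example value.
#
#   e_blocks, f_blocks, g_blocks = get_encoder_forecaster_rnn_blocks(batch_size=4)
#   assert len(e_blocks) == len(f_blocks) == len(g_blocks) \
#       == len(cfg.MODEL.ENCODER_FORECASTER.RNN_BLOCKS.NUM_FILTER)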


class EncoderForecasterBaseFactory(PredictionBaseFactory):
    def __init__(self,
                 batch_size,
                 in_seq_len,
                 out_seq_len,
                 height,
                 width,
                 ctx_num=1,
                 name="encoder_forecaster"):
        super(EncoderForecasterBaseFactory, self).__init__(batch_size=batch_size,
                                                           in_seq_len=in_seq_len,
                                                           out_seq_len=out_seq_len,
                                                           height=height,
                                                           width=width,
                                                           name=name)
        self._ctx_num = ctx_num

    def _init_rnn(self):
        self._encoder_rnn_blocks, self._forecaster_rnn_blocks, self._gan_rnn_blocks =\
            get_encoder_forecaster_rnn_blocks(batch_size=self._batch_size)
        return self._encoder_rnn_blocks + self._forecaster_rnn_blocks + self._gan_rnn_blocks

    @property
    def init_encoder_state_info(self):
        """Name/shape/layout info for the encoder RNN blocks' initial states."""
        init_state_info = []
        for block in self._encoder_rnn_blocks:
            for state in block.init_state_vars():
                init_state_info.append({'name': state.name,
                                        'shape': state.attr('__shape__'),
                                        '__layout__': state.list_attr()['__layout__']})
        return init_state_info

    @property
    def init_forecaster_state_info(self):
        """Name/shape/layout info for the forecaster RNN blocks' initial states."""
        init_state_info = []
        for block in self._forecaster_rnn_blocks:
            for state in block.init_state_vars():
                init_state_info.append({'name': state.name,
                                        'shape': state.attr('__shape__'),
                                        '__layout__': state.list_attr()['__layout__']})
        return init_state_info

    @property
    def init_gan_state_info(self):
        """Name/shape/layout info for the GAN branch RNN blocks' initial states."""
        init_gan_state_info = []
        for block in self._gan_rnn_blocks:
            for state in block.init_state_vars():
                init_gan_state_info.append({'name': state.name,
                                            'shape': state.attr('__shape__'),
                                            '__layout__': state.list_attr()['__layout__']})
        return init_gan_state_info
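
    # Each entry in the init_*_state_info lists is a plain dict; the values
    # below are illustrative only (real names/shapes come from the symbols):
    #
    #   {'name': 'ebrnn1_begin_state_h',
    #    'shape': '(4, 64, 96, 96)',   # string attribute, parsed via safe_eval
    #    '__layout__': 'NCHW'}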

    def stack_rnn_encode(self, data):
        CONFIG = cfg.MODEL.ENCODER_FORECASTER
        pre_encoded_data = self._pre_encode_frame(frame_data=data, seqlen=self._in_seq_len)
        # Merge the sequence axis into the batch axis: (T, N, C, H, W) -> (T*N, C, H, W).
        reshape_data = mx.sym.Reshape(pre_encoded_data, shape=(-1, 0, 0, 0), reverse=True)
        conv1 = conv2d_act(data=reshape_data,
                           num_filter=CONFIG.FIRST_CONV[0],
                           kernel=(CONFIG.FIRST_CONV[1], CONFIG.FIRST_CONV[1]),
                           stride=(CONFIG.FIRST_CONV[2], CONFIG.FIRST_CONV[2]),
                           pad=(CONFIG.FIRST_CONV[3], CONFIG.FIRST_CONV[3]),
                           act_type=cfg.MODEL.CNN_ACT_TYPE,
                           name="econv1")
        rnn_block_num = len(CONFIG.RNN_BLOCKS.NUM_FILTER)
        encoder_rnn_block_states = []
        for i in range(rnn_block_num):
            if i == 0:
                inputs = conv1
            else:
                inputs = downsample
            rnn_out, states = self._encoder_rnn_blocks[i].unroll(
                length=self._in_seq_len,
                inputs=inputs,
                begin_states=None,
                ret_mid=False)
            encoder_rnn_block_states.append(states)
            if i < rnn_block_num - 1:
                downsample = downsample_module(data=rnn_out[-1],
                                               num_filter=CONFIG.RNN_BLOCKS.NUM_FILTER[i + 1],
                                               kernel=(CONFIG.DOWNSAMPLE[i][0],
                                                       CONFIG.DOWNSAMPLE[i][0]),
                                               stride=(CONFIG.DOWNSAMPLE[i][1],
                                                       CONFIG.DOWNSAMPLE[i][1]),
                                               pad=(CONFIG.DOWNSAMPLE[i][2],
                                                    CONFIG.DOWNSAMPLE[i][2]),
                                               b_h_w=(self._batch_size,
                                                      CONFIG.FEATMAP_SIZE[i + 1],
                                                      CONFIG.FEATMAP_SIZE[i + 1]),
                                               name="edown%d" % (i + 1))
        return encoder_rnn_block_states
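
    # Sketch of the encode pass for a three-block configuration (names match
    # the symbols created above):
    #
    #   data -> econv1 -> ebrnn1 -> edown1 -> ebrnn2 -> edown2 -> ebrnn3
    #
    # The returned list holds the final states of ebrnn1..ebrnn3, which seed
    # the forecaster.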

    def stack_rnn_forecast(self, block_state_list, last_frame):
        CONFIG = cfg.MODEL.ENCODER_FORECASTER
        block_state_list = [self._forecaster_rnn_blocks[i].to_split(block_state_list[i])
                            for i in range(len(self._forecaster_rnn_blocks))]
        rnn_block_num = len(CONFIG.RNN_BLOCKS.NUM_FILTER)
        rnn_block_outputs = []

        # Unroll from the deepest block down to the first, upsampling in between.
        curr_inputs = None
        for i in range(rnn_block_num - 1, -1, -1):
            rnn_out, rnn_state = self._forecaster_rnn_blocks[i].unroll(
                length=self._out_seq_len, inputs=curr_inputs,
                begin_states=block_state_list[i][::-1],
                ret_mid=False)
            rnn_block_outputs.append(rnn_out)
            if i > 0:
                upsample = upsample_module(data=rnn_out[-1],
                                           num_filter=CONFIG.RNN_BLOCKS.NUM_FILTER[i],
                                           kernel=(CONFIG.UPSAMPLE[i - 1][0],
                                                   CONFIG.UPSAMPLE[i - 1][0]),
                                           stride=(CONFIG.UPSAMPLE[i - 1][1],
                                                   CONFIG.UPSAMPLE[i - 1][1]),
                                           pad=(CONFIG.UPSAMPLE[i - 1][2],
                                                CONFIG.UPSAMPLE[i - 1][2]),
                                           b_h_w=(self._batch_size,
                                                  CONFIG.FEATMAP_SIZE[i - 1],
                                                  CONFIG.FEATMAP_SIZE[i - 1]),
                                           name="fup%d" % i)
                curr_inputs = upsample

        if cfg.MODEL.OUT_TYPE == "DFN":
            concat_fbrnn1_out = mx.sym.concat(*rnn_out[-1], dim=0)
            # 121 = 11 * 11 filter weights per pixel, matching K=11 below.
            dynamic_filter = deconv2d(data=concat_fbrnn1_out,
                                      num_filter=121,
                                      kernel=(CONFIG.LAST_DECONV[1], CONFIG.LAST_DECONV[1]),
                                      stride=(CONFIG.LAST_DECONV[2], CONFIG.LAST_DECONV[2]),
                                      pad=(CONFIG.LAST_DECONV[3], CONFIG.LAST_DECONV[3]))
            flow = dynamic_filter
            dynamic_filter = mx.sym.SliceChannel(dynamic_filter, axis=0,
                                                 num_outputs=self._out_seq_len)
            # Apply the predicted local filters autoregressively, starting from
            # the last observed frame.
            prev_frame = last_frame
            preds = []
            for i in range(self._out_seq_len):
                pred_ele = DFN(data=prev_frame, local_kernels=dynamic_filter[i],
                               K=11, batch_size=self._batch_size)
                preds.append(pred_ele)
                prev_frame = pred_ele
            pred = mx.sym.concat(*preds, dim=0)
        elif cfg.MODEL.OUT_TYPE == "direct":
            flow = None
            deconv1 = deconv2d_act(data=mx.sym.concat(*rnn_out[-1], dim=0),
                                   num_filter=CONFIG.LAST_DECONV[0],
                                   kernel=(CONFIG.LAST_DECONV[1], CONFIG.LAST_DECONV[1]),
                                   stride=(CONFIG.LAST_DECONV[2], CONFIG.LAST_DECONV[2]),
                                   pad=(CONFIG.LAST_DECONV[3], CONFIG.LAST_DECONV[3]),
                                   act_type=cfg.MODEL.CNN_ACT_TYPE,
                                   name="fdeconv1")
            conv_final = conv2d_act(data=deconv1,
                                    num_filter=CONFIG.LAST_DECONV[0],
                                    kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                                    act_type=cfg.MODEL.CNN_ACT_TYPE, name="conv_final")
            pred = conv2d(data=conv_final,
                          num_filter=1, kernel=(1, 1), name="out")
        else:
            raise NotImplementedError("OUT_TYPE=%s is not supported!" % cfg.MODEL.OUT_TYPE)
        pred = mx.sym.Reshape(pred,
                              shape=(self._out_seq_len, self._batch_size,
                                     1, self._height, self._width),
                              __layout__="TNCHW")
        return pred, flow
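
    # Sketch of the forecast pass, mirroring the encoder top-down:
    #
    #   fbrnn3 -> fup2 -> fbrnn2 -> fup1 -> fbrnn1 -> (DFN head | fdeconv1 -> conv_final -> out)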

    def encoder_sym(self):
        self.reset_all()
        data = mx.sym.Variable('data')
        block_state_list = self.stack_rnn_encode(data=data)
        states = []
        for i, rnn_block in enumerate(self._encoder_rnn_blocks):
            states.extend(rnn_block.flatten_add_layout(block_state_list[i]))
        return mx.sym.Group(states)

    def encoder_data_desc(self):
        ret = list()
        ret.append(mx.io.DataDesc(name='data',
                                  shape=(self._in_seq_len,
                                         self._batch_size * self._ctx_num,
                                         1,
                                         self._height,
                                         self._width),
                                  layout="TNCHW"))
        for info in self.init_encoder_state_info:
            state_shape = safe_eval(info['shape'])
            assert info['__layout__'].find('N') == 0,\
                "Layout=%s is not supported!" % info["__layout__"]
            state_shape = (state_shape[0] * self._ctx_num, ) + state_shape[1:]
            ret.append(mx.io.DataDesc(name=info['name'],
                                      shape=state_shape,
                                      layout=info['__layout__']))
        return ret
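
    # Illustrative result (values depend on cfg; the names/shapes shown here
    # are made up):
    #
    #   [DataDesc(name='data', shape=(in_seq_len, batch*ctx_num, 1, H, W), layout='TNCHW'),
    #    DataDesc(name='ebrnn1_begin_state_h', shape=(batch*ctx_num, C, h, w), layout='NCHW'),
    #    ...]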

    def forecaster_sym(self):
        self.reset_all()
        block_state_list = []
        for block in self._forecaster_rnn_blocks:
            block_state_list.append(block.init_state_vars())

        if cfg.MODEL.OUT_TYPE == "direct":
            pred, _ = self.stack_rnn_forecast(block_state_list=block_state_list,
                                              last_frame=None)
            return mx.sym.Group([pred])
        else:
            last_frame = mx.sym.Variable('last_frame')
            pred, flow = self.stack_rnn_forecast(block_state_list=block_state_list,
                                                 last_frame=last_frame)
            return mx.sym.Group([pred, mx.sym.BlockGrad(flow)])

    def forecaster_data_desc(self):
        ret = list()
        for info in self.init_forecaster_state_info:
            state_shape = safe_eval(info['shape'])
            assert info['__layout__'].find('N') == 0, \
                "Layout=%s is not supported!" % info["__layout__"]
            state_shape = (state_shape[0] * self._ctx_num,) + state_shape[1:]
            ret.append(mx.io.DataDesc(name=info['name'],
                                      shape=state_shape,
                                      layout=info['__layout__']))
        if cfg.MODEL.OUT_TYPE != "direct":
            ret.append(mx.io.DataDesc(name="last_frame",
                                      shape=(self._ctx_num * self._batch_size,
                                             1, self._height, self._width),
                                      layout="NCHW"))
        return ret

    def loss_sym(self):
        raise NotImplementedError

    def loss_data_desc(self):
        ret = list()
        ret.append(mx.io.DataDesc(name='pred',
                                  shape=(self._out_seq_len,
                                         self._ctx_num * self._batch_size,
                                         1,
                                         self._height,
                                         self._width),
                                  layout="TNCHW"))
        return ret

    def loss_label_desc(self):
        ret = list()
        ret.append(mx.io.DataDesc(name='target',
                                  shape=(self._out_seq_len,
                                         self._ctx_num * self._batch_size,
                                         1,
                                         self._height,
                                         self._width),
                                  layout="TNCHW"))
        if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
            ret.append(mx.io.DataDesc(name='mask',
                                      shape=(self._out_seq_len,
                                             self._ctx_num * self._batch_size,
                                             1,
                                             self._height,
                                             self._width),
                                      layout="TNCHW"))
        return ret


def init_optimizer_using_cfg(net, for_finetune):
    if not for_finetune:
        lr_scheduler = mx.lr_scheduler.FactorScheduler(step=cfg.MODEL.TRAIN.LR_DECAY_ITER,
                                                       factor=cfg.MODEL.TRAIN.LR_DECAY_FACTOR,
                                                       stop_factor_lr=cfg.MODEL.TRAIN.MIN_LR)
        if cfg.MODEL.TRAIN.OPTIMIZER.lower() == "adam":
            net.init_optimizer(optimizer="adam",
                               optimizer_params={'learning_rate': cfg.MODEL.TRAIN.LR,
                                                 'beta1': cfg.MODEL.TRAIN.BETA1,
                                                 'rescale_grad': 1.0,
                                                 'epsilon': cfg.MODEL.TRAIN.EPS,
                                                 'lr_scheduler': lr_scheduler,
                                                 'wd': cfg.MODEL.TRAIN.WD})
        elif cfg.MODEL.TRAIN.OPTIMIZER.lower() == "rmsprop":
            net.init_optimizer(optimizer="rmsprop",
                               optimizer_params={'learning_rate': cfg.MODEL.TRAIN.LR,
                                                 'gamma1': cfg.MODEL.TRAIN.GAMMA1,
                                                 'rescale_grad': 1.0,
                                                 'epsilon': cfg.MODEL.TRAIN.EPS,
                                                 'lr_scheduler': lr_scheduler,
                                                 'wd': cfg.MODEL.TRAIN.WD})
        elif cfg.MODEL.TRAIN.OPTIMIZER.lower() == "sgd":
            net.init_optimizer(optimizer="sgd",
                               optimizer_params={'learning_rate': cfg.MODEL.TRAIN.LR,
                                                 'momentum': 0.0,
                                                 'rescale_grad': 1.0,
                                                 'lr_scheduler': lr_scheduler,
                                                 'wd': cfg.MODEL.TRAIN.WD})
        elif cfg.MODEL.TRAIN.OPTIMIZER.lower() == "adagrad":
            net.init_optimizer(optimizer="adagrad",
                               optimizer_params={'learning_rate': cfg.MODEL.TRAIN.LR,
                                                 'eps': cfg.MODEL.TRAIN.EPS,
                                                 'rescale_grad': 1.0,
                                                 'wd': cfg.MODEL.TRAIN.WD})
        else:
            raise NotImplementedError
    else:
        # Online finetuning uses the TEST.ONLINE hyperparameters and no LR schedule.
        if cfg.MODEL.TEST.ONLINE.OPTIMIZER.lower() == "adam":
            net.init_optimizer(optimizer="adam",
                               optimizer_params={'learning_rate': cfg.MODEL.TEST.ONLINE.LR,
                                                 'beta1': cfg.MODEL.TEST.ONLINE.BETA1,
                                                 'rescale_grad': 1.0,
                                                 'epsilon': cfg.MODEL.TEST.ONLINE.EPS,
                                                 'wd': cfg.MODEL.TEST.ONLINE.WD})
        elif cfg.MODEL.TEST.ONLINE.OPTIMIZER.lower() == "rmsprop":
            net.init_optimizer(optimizer="rmsprop",
                               optimizer_params={'learning_rate': cfg.MODEL.TEST.ONLINE.LR,
                                                 'gamma1': cfg.MODEL.TEST.ONLINE.GAMMA1,
                                                 'rescale_grad': 1.0,
                                                 'epsilon': cfg.MODEL.TEST.ONLINE.EPS,
                                                 'wd': cfg.MODEL.TEST.ONLINE.WD})
        elif cfg.MODEL.TEST.ONLINE.OPTIMIZER.lower() == "sgd":
            net.init_optimizer(optimizer="sgd",
                               optimizer_params={'learning_rate': cfg.MODEL.TEST.ONLINE.LR,
                                                 'momentum': 0.0,
                                                 'rescale_grad': 1.0,
                                                 'wd': cfg.MODEL.TEST.ONLINE.WD})
        elif cfg.MODEL.TEST.ONLINE.OPTIMIZER.lower() == "adagrad":
            net.init_optimizer(optimizer="adagrad",
                               optimizer_params={'learning_rate': cfg.MODEL.TEST.ONLINE.LR,
                                                 'eps': cfg.MODEL.TEST.ONLINE.EPS,  # was TRAIN.EPS, likely a copy-paste slip
                                                 'rescale_grad': 1.0,
                                                 'wd': cfg.MODEL.TEST.ONLINE.WD})
        else:
            raise NotImplementedError
    return net
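
# Minimal usage sketch, assuming `net` is an already-bound MyModule:
#
#   net = init_optimizer_using_cfg(net, for_finetune=False)  # offline training
#   net = init_optimizer_using_cfg(net, for_finetune=True)   # online finetuning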


def encoder_forecaster_build_networks(factory, context,
                                      shared_encoder_net=None,
                                      shared_forecaster_net=None,
                                      shared_loss_net=None,
                                      for_finetune=False):
    """Build and bind the encoder, forecaster and loss modules.

    Parameters
    ----------
    factory : EncoderForecasterBaseFactory
    context : list
    shared_encoder_net : MyModule or None
    shared_forecaster_net : MyModule or None
    shared_loss_net : MyModule or None
    for_finetune : bool

    Returns
    -------
    encoder_net : MyModule
    forecaster_net : MyModule
    loss_net : MyModule
    """
    encoder_net = MyModule(factory.encoder_sym(),
                           data_names=[ele.name for ele in factory.encoder_data_desc()],
                           label_names=[],
                           context=context,
                           name="encoder_net")
    encoder_net.bind(data_shapes=factory.encoder_data_desc(),
                     label_shapes=None,
                     inputs_need_grad=True,
                     shared_module=shared_encoder_net)
    if shared_encoder_net is None:
        encoder_net.init_params(mx.init.MSRAPrelu(slope=0.2))
        init_optimizer_using_cfg(encoder_net, for_finetune=for_finetune)
    forecaster_net = MyModule(factory.forecaster_sym(),
                              data_names=[ele.name for ele in
                                          factory.forecaster_data_desc()],
                              label_names=[],
                              context=context,
                              name="forecaster_net")
    forecaster_net.bind(data_shapes=factory.forecaster_data_desc(),
                        label_shapes=None,
                        inputs_need_grad=True,
                        shared_module=shared_forecaster_net)
    if shared_forecaster_net is None:
        forecaster_net.init_params(mx.init.MSRAPrelu(slope=0.2))
        init_optimizer_using_cfg(forecaster_net, for_finetune=for_finetune)

    loss_net = MyModule(factory.loss_sym(),
                        data_names=[ele.name for ele in
                                    factory.loss_data_desc()],
                        label_names=[ele.name for ele in
                                     factory.loss_label_desc()],
                        context=context,
                        name="loss_net")
    loss_net.bind(data_shapes=factory.loss_data_desc(),
                  label_shapes=factory.loss_label_desc(),
                  inputs_need_grad=True,
                  shared_module=shared_loss_net)
    if shared_loss_net is None:
        loss_net.init_params()
    return encoder_net, forecaster_net, loss_net
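
# Typical construction, as a sketch. `factory` must be a subclass of
# EncoderForecasterBaseFactory that implements loss_sym():
#
#   encoder_net, forecaster_net, loss_net = encoder_forecaster_build_networks(
#       factory=factory, context=[mx.gpu(0)])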


class EncoderForecasterStates(object):
    def __init__(self, factory, ctx):
        self._factory = factory
        self._ctx = ctx
        self._encoder_state_info = factory.init_encoder_state_info
        self._forecaster_state_info = factory.init_forecaster_state_info
        self._states_nd = []
        for info in self._encoder_state_info:
            state_shape = safe_eval(info['shape'])
            state_shape = (state_shape[0] * factory._ctx_num, ) + state_shape[1:]
            self._states_nd.append(mx.nd.zeros(shape=state_shape, ctx=ctx))

    def reset_all(self):
        for ele in self._states_nd:
            ele[:] = 0

    def reset_batch(self, batch_id):
        for ele in self._states_nd:
            ele[batch_id][:] = 0

    def update(self, states_nd):
        for target, src in zip(self._states_nd, states_nd):
            target[:] = src

    def get_encoder_states(self):
        return self._states_nd

    def get_forecaster_state(self):
        # The forecaster is initialized directly from the encoder's final
        # states, so both getters return the same ndarray list.
        return self._states_nd
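
# Usage sketch: allocate zero-filled states once and roll them across steps.
#
#   states = EncoderForecasterStates(factory=factory, ctx=mx.gpu(0))
#   states.reset_all()              # e.g. at the start of an epoch
#   states.reset_batch(batch_id=2)  # when one sequence in the batch restarts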


def train_step(batch_size, encoder_net, forecaster_net,
               loss_net, init_states,
               data_nd, gt_nd, mask_nd, iter_id=None):
    """Train the encoder and forecaster for one step.

    Parameters
    ----------
    batch_size : int
    encoder_net : MyModule
    forecaster_net : MyModule
    loss_net : MyModule
    init_states : EncoderForecasterStates
    data_nd : mx.nd.NDArray
    gt_nd : mx.nd.NDArray
    mask_nd : mx.nd.NDArray
    iter_id : int or None

    Returns
    -------
    init_states : EncoderForecasterStates
    loss_dict : dict
    """
    # Forward the encoder and roll its final states into init_states.
    encoder_net.forward(is_train=True,
                        data_batch=mx.io.DataBatch(data=[data_nd] + init_states.get_encoder_states()))
    encoder_states_nd = encoder_net.get_outputs()
    init_states.update(encoder_states_nd)

    # Forward the forecaster from the encoder states (plus the last observed
    # frame when the output type needs it, e.g. DFN).
    if cfg.MODEL.OUT_TYPE == "direct":
        forecaster_net.forward(is_train=True,
                               data_batch=mx.io.DataBatch(data=init_states.get_forecaster_state()))
    else:
        last_frame_nd = data_nd[data_nd.shape[0] - 1]
        forecaster_net.forward(is_train=True,
                               data_batch=mx.io.DataBatch(data=init_states.get_forecaster_state() +
                                                               [last_frame_nd]))
    forecaster_outputs = forecaster_net.get_outputs()
    pred_nd = forecaster_outputs[0]

    # Forward/backward the loss and pull the gradient w.r.t. the prediction.
    if cfg.MODEL.ENCODER_FORECASTER.HAS_MASK:
        loss_net.forward_backward(data_batch=mx.io.DataBatch(data=[pred_nd],
                                                             label=[gt_nd, mask_nd]))
    else:
        loss_net.forward_backward(data_batch=mx.io.DataBatch(data=[pred_nd],
                                                             label=[gt_nd]))
    pred_grad = loss_net.get_input_grads()[0]
    loss_dict = loss_net.get_output_dict()
    for k in loss_dict:
        loss_dict[k] = nd.mean(loss_dict[k]).asscalar()

    # Backpropagate through the forecaster and then the encoder.
    forecaster_net.backward(out_grads=[pred_grad])
    if cfg.MODEL.OUT_TYPE == "direct":
        encoder_states_grad_nd = forecaster_net.get_input_grads()
    else:
        # Drop the gradient w.r.t. last_frame; only the state grads flow back.
        encoder_states_grad_nd = forecaster_net.get_input_grads()[:-1]
    encoder_net.backward(encoder_states_grad_nd)

    # Clip gradients globally, then update both modules.
    forecaster_grad_norm = forecaster_net.clip_by_global_norm(max_norm=cfg.MODEL.TRAIN.GRAD_CLIP)
    encoder_grad_norm = encoder_net.clip_by_global_norm(max_norm=cfg.MODEL.TRAIN.GRAD_CLIP)
    forecaster_net.update()
    encoder_net.update()
    loss_str = ", ".join(["%s=%g" % (k, v) for k, v in loss_dict.items()])
    if iter_id is not None:
        logging.info("Iter:%d, %s, e_gnorm=%g, f_gnorm=%g"
                     % (iter_id, loss_str, encoder_grad_norm, forecaster_grad_norm))
    return init_states, loss_dict
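
# One training iteration, as a sketch. `train_iter`, which should yield TNCHW
# ndarrays for data/target/mask, is an assumed data source, not part of this
# module:
#
#   states = EncoderForecasterStates(factory=factory, ctx=mx.gpu(0))
#   for iter_id, (data_nd, gt_nd, mask_nd) in enumerate(train_iter):
#       states, loss_dict = train_step(batch_size, encoder_net, forecaster_net,
#                                      loss_net, states, data_nd, gt_nd, mask_nd,
#                                      iter_id=iter_id)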


def load_encoder_forecaster_params(load_dir, load_iter, encoder_net, forecaster_net):
    logging.info("Loading parameters from {}, Iter = {}"
                 .format(os.path.realpath(load_dir), load_iter))
    encoder_arg_params, encoder_aux_params = load_params(prefix=os.path.join(load_dir,
                                                                             "encoder_net"),
                                                         epoch=load_iter)
    encoder_net.init_params(arg_params=encoder_arg_params, aux_params=encoder_aux_params,
                            allow_missing=False, force_init=True)
    forecaster_arg_params, forecaster_aux_params = load_params(prefix=os.path.join(load_dir,
                                                                                   "forecaster_net"),
                                                               epoch=load_iter)
    forecaster_net.init_params(arg_params=forecaster_arg_params,
                               aux_params=forecaster_aux_params,
                               allow_missing=False,
                               force_init=True)
    logging.info("Loading Complete!")
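
# Sketch: restore a checkpoint saved under `load_dir` with the module prefixes
# used above ("encoder_net"/"forecaster_net"); "./ckpt" and 20000 are example
# values:
#
#   load_encoder_forecaster_params(load_dir="./ckpt", load_iter=20000,
#                                  encoder_net=encoder_net,
#                                  forecaster_net=forecaster_net)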