|
|
import os |
|
|
import glob |
|
|
import logging |
|
|
import importlib |
|
|
from tqdm import tqdm |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher |
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR |
|
|
from core.dataset import TrainDataset |
|
|
|
|
|
from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss |
|
|
|
|
|
|
|
|
from model.canny.canny_filter import Canny |
|
|
from RAFT.utils.flow_viz_pt import flow_to_image |
|
|
|
|
|
|
|
|
class Trainer:
    """Trainer for the recurrent flow-completion network with edge supervision.

    Builds the training data pipeline, a RAFT flow estimator (used when the
    dataset provides no precomputed flows), the flow/edge losses, a Canny
    edge detector and the network from ``model.<config['model']['net']>``;
    optionally wraps the network in DistributedDataParallel; resumes from the
    newest checkpoint in ``config['save_dir']``; and runs the training loop.
    """

    def __init__(self, config):
        """Set up data, losses, networks, optimiser, scheduler and loggers.

        Args:
            config: nested configuration dict. Keys read here include
                ``train_data_loader`` (``num_local_frames``,
                ``num_ref_frames`` and whatever ``TrainDataset`` needs),
                ``trainer`` (``batch_size``, ``num_workers``,
                ``num_prefetch_queue``, ...), ``model.net``, ``distributed``,
                ``world_size``, ``global_rank``, ``local_rank``, ``device``
                and ``save_dir``.
        """
        self.config = config
        self.epoch = 0
        self.iteration = 0

        self.num_local_frames = config['train_data_loader']['num_local_frames']
        self.num_ref_frames = config['train_data_loader']['num_ref_frames']

        # ---- data pipeline ----
        self.train_dataset = TrainDataset(config['train_data_loader'])

        self.train_sampler = None
        self.train_args = config['trainer']
        if config['distributed']:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=config['world_size'],
                rank=config['global_rank'])

        dataloader_args = dict(
            dataset=self.train_dataset,
            # the configured batch size is split across processes
            batch_size=self.train_args['batch_size'] // config['world_size'],
            # DataLoader does not accept shuffle=True together with a sampler
            shuffle=(self.train_sampler is None),
            num_workers=self.train_args['num_workers'],
            sampler=self.train_sampler,
            drop_last=True)

        self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args)
        self.prefetcher = CPUPrefetcher(self.train_loader)

        # ---- flow estimator, losses, edge detector ----
        self.fix_raft = RAFT_bi(device=self.config['device'])
        self.flow_loss = FlowLoss()
        self.edge_loss = EdgeLoss()
        self.canny = Canny(sigma=(2, 2), low_threshold=0.1, high_threshold=0.2)

        # ---- network to train ----
        net = importlib.import_module('model.' + config['model']['net'])
        self.netG = net.RecurrentFlowCompleteNet()
        self.netG = self.netG.to(self.config['device'])

        self.setup_optimizers()
        self.setup_schedulers()
        # Resume before DDP wrapping: save() stores the unwrapped module's
        # state dict, so the keys match the plain network here.
        self.load()

        if config['distributed']:
            self.netG = DDP(self.netG,
                            device_ids=[self.config['local_rank']],
                            output_device=self.config['local_rank'],
                            broadcast_buffers=True,
                            find_unused_parameters=True)

        # ---- TensorBoard (main process only) ----
        self.dis_writer = None
        self.gen_writer = None
        self.summary = {}
        if self.config['global_rank'] == 0 or (not config['distributed']):
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen'))

    def setup_optimizers(self):
        """Create ``self.optimG`` (Adam) over all trainable ``netG`` parameters.

        Parameters with ``requires_grad=False`` are skipped and reported.
        Learning rate and betas come from ``config['trainer']``.
        """
        trainer_cfg = self.config['trainer']
        backbone_params = []
        for name, param in self.netG.named_parameters():
            if not param.requires_grad:
                print(f'Params {name} will not be optimized.')
                continue
            backbone_params.append(param)

        optim_params = [{'params': backbone_params, 'lr': trainer_cfg['lr']}]
        self.optimG = torch.optim.Adam(optim_params,
                                       betas=(trainer_cfg['beta1'],
                                              trainer_cfg['beta2']))

    def setup_schedulers(self):
        """Create ``self.scheG`` from ``config['trainer']['scheduler']``.

        The ``type`` key selects the scheduler (``MultiStepLR`` /
        ``MultiStepRestartLR`` or ``CosineAnnealingRestartLR``); the other
        keys are its arguments. A copy of the options is used, so the shared
        config keeps its ``type`` entry and this method can be called again.

        Raises:
            NotImplementedError: for any other scheduler type.
        """
        scheduler_opt = dict(self.config['trainer']['scheduler'])
        scheduler_type = scheduler_opt.pop('type')

        if scheduler_type in ('MultiStepLR', 'MultiStepRestartLR'):
            self.scheG = MultiStepRestartLR(
                self.optimG,
                milestones=scheduler_opt['milestones'],
                gamma=scheduler_opt['gamma'])
        elif scheduler_type == 'CosineAnnealingRestartLR':
            self.scheG = CosineAnnealingRestartLR(
                self.optimG,
                periods=scheduler_opt['periods'],
                restart_weights=scheduler_opt['restart_weights'])
        else:
            raise NotImplementedError(
                f'Scheduler {scheduler_type} is not implemented yet.')

    def update_learning_rate(self):
        """Advance the learning-rate scheduler by one step."""
        self.scheG.step()

    def get_lr(self):
        """Return the learning rate of the optimizer's first parameter group."""
        return self.optimG.param_groups[0]['lr']

    def add_summary(self, writer, name, val):
        """Accumulate scalar ``val`` under ``name``; flush its mean periodically.

        Every ``trainer.log_freq`` iterations (and only when ``writer`` is not
        None) the accumulated sum divided by ``log_freq`` is written to
        ``writer`` and the accumulator is reset.
        """
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        n = self.train_args['log_freq']
        if writer is not None and self.iteration % n == 0:
            writer.add_scalar(name, self.summary[name] / n, self.iteration)
            self.summary[name] = 0

    @staticmethod
    def _latest_complete_ckpt_id(model_path):
        """Return the largest id N with both gen_N.pth and opt_N.pth in model_path, else None."""
        ids = {'gen': set(), 'opt': set()}
        for path in glob.glob(os.path.join(glob.escape(model_path), '*.pth')):
            stem = os.path.basename(path).split('.pth')[0]
            prefix, sep, num = stem.partition('_')
            # ignore foreign weights such as 'raft-things.pth'
            if sep and prefix in ids and num.isdigit():
                ids[prefix].add(num)
        complete = ids['gen'] & ids['opt']
        return max(complete, key=int) if complete else None

    def load(self):
        """Resume ``netG``, optimizer, scheduler, epoch and iteration.

        The checkpoint id is the last non-empty line of
        ``<save_dir>/latest.ckpt`` when that file exists; otherwise it is the
        largest id N for which both ``gen_N.pth`` and ``opt_N.pth`` exist in
        ``save_dir`` (other ``.pth`` files are ignored). Without any
        checkpoint the freshly initialised model is kept.
        """
        model_path = self.config['save_dir']
        latest_epoch = None

        latest_file = os.path.join(model_path, 'latest.ckpt')
        if os.path.isfile(latest_file):
            with open(latest_file, 'r') as f:
                lines = [line.strip() for line in f.read().splitlines() if line.strip()]
            if lines:
                latest_epoch = lines[-1]

        if latest_epoch is None:
            latest_epoch = self._latest_complete_ckpt_id(model_path)

        if latest_epoch is None:
            if self.config['global_rank'] == 0:
                print('Warning: There is no trained model found. '
                      'An initialized model will be used.')
            return

        gen_path = os.path.join(model_path, f'gen_{int(latest_epoch):06d}.pth')
        opt_path = os.path.join(model_path, f'opt_{int(latest_epoch):06d}.pth')

        if self.config['global_rank'] == 0:
            print(f'Loading model from {gen_path}...')
        dataG = torch.load(gen_path, map_location=self.config['device'])
        self.netG.load_state_dict(dataG)

        data_opt = torch.load(opt_path, map_location=self.config['device'])
        self.optimG.load_state_dict(data_opt['optimG'])
        self.scheG.load_state_dict(data_opt['scheG'])

        self.epoch = data_opt['epoch']
        self.iteration = data_opt['iteration']

    def save(self, it):
        """Write checkpoint files for iteration ``it`` (rank 0 only).

        Produces ``gen_{it:06d}.pth`` (network weights, unwrapped from
        DataParallel/DDP), ``opt_{it:06d}.pth`` (epoch, iteration, optimizer
        and scheduler state) and ``latest.ckpt`` holding the zero-padded id,
        all inside ``config['save_dir']``.

        Args:
            it: iteration number used to name the checkpoint.
        """
        if self.config['global_rank'] != 0:
            return

        save_dir = self.config['save_dir']
        gen_path = os.path.join(save_dir, f'gen_{it:06d}.pth')
        opt_path = os.path.join(save_dir, f'opt_{it:06d}.pth')
        print(f'\nsaving model to {gen_path} ...')

        # store the plain module so checkpoints load without the wrapper prefix
        if isinstance(self.netG, (nn.DataParallel, DDP)):
            netG = self.netG.module
        else:
            netG = self.netG

        torch.save(netG.state_dict(), gen_path)
        torch.save(
            {
                'epoch': self.epoch,
                'iteration': self.iteration,
                'optimG': self.optimG.state_dict(),
                'scheG': self.scheG.state_dict()
            }, opt_path)

        # Write the pointer file directly (same content as `echo`): no shell
        # is spawned and paths with spaces/metacharacters work.
        latest_path = os.path.join(save_dir, 'latest.ckpt')
        with open(latest_path, 'w') as f:
            f.write(f'{it:06d}\n')

    def _log_filename(self):
        """Return 'logs/<run name>.log', run name = last component of save_dir."""
        # normpath drops a trailing separator, which would otherwise give ''
        run_name = os.path.basename(os.path.normpath(self.config['save_dir']))
        return f"logs/{run_name}.log"

    def train(self):
        """Training entry point.

        Runs epochs via ``_train_epoch`` until ``self.iteration`` exceeds
        ``trainer.iterations``. Rank 0 shows a tqdm progress bar; log records
        go to ``logs/<run name>.log``.
        """
        pbar = range(int(self.train_args['iterations']))
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar,
                        initial=self.iteration,
                        dynamic_ncols=True,
                        smoothing=0.01)

        os.makedirs('logs', exist_ok=True)

        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s %(filename)s[line:%(lineno)d]"
            "%(levelname)s %(message)s",
            datefmt="%a, %d %b %Y %H:%M:%S",
            filename=self._log_filename(),
            filemode='w')

        while True:
            self.epoch += 1
            self.prefetcher.reset()
            if self.config['distributed']:
                # a new epoch value gives the sampler a different shuffle
                self.train_sampler.set_epoch(self.epoch)
            self._train_epoch(pbar)
            if self.iteration > self.train_args['iterations']:
                break

        if self.config['global_rank'] == 0:
            pbar.close()
        print('\nEnd training....')

    def get_edges(self, flows):
        """Return Canny edge maps of the flow-magnitude images.

        Args:
            flows: tensor of shape ``(b, t, 2, h, w)``; the two channels are
                the flow components. Non-contiguous tensors are accepted.

        Returns:
            Edge tensor of shape ``(b, t, 1, h, w)`` from ``self.canny``.
        """
        b, t, _, h, w = flows.shape
        # reshape instead of view: flows may arrive non-contiguous
        flows = flows.reshape(-1, 2, h, w)
        flows_gray = (flows[:, 0, None] ** 2 + flows[:, 1, None] ** 2) ** 0.5
        # Normalise by the largest magnitude in the whole batch; if all motion
        # is below 1 the detector is given an all-zero image instead.
        peak = flows_gray.max()
        if peak < 1:
            flows_gray = flows_gray * 0
        else:
            flows_gray = flows_gray / peak

        magnitude, edges = self.canny(flows_gray.float())
        return edges.reshape(b, t, 1, h, w)

    @staticmethod
    def _is_missing_flow(flows):
        """True when the batch marks absent precomputed flow with the string 'None'."""
        return isinstance(flows[0], str) and flows[0] == 'None'

    def _log_visuals(self, gt_flows_bi, pred_flows_bi, gt_edges_bi,
                     pred_edges_bi, local_masks):
        """Write 'GT | masked GT | prediction' panels of sample 0 to TensorBoard.

        Forward panels are masked with mask ``t``, backward panels with mask
        ``t + 1``, where ``t`` is 5 or the last index valid for both.
        """
        # Clamp the displayed index so short clips don't raise IndexError;
        # t + 1 must still index a mask.
        # NOTE(review): assumes predictions have as many steps as the GT flows.
        last_t = min(gt_flows_bi[0].shape[1], local_masks.shape[1] - 1) - 1
        if last_t < 0:
            return
        t = min(5, last_t)

        with torch.no_grad():
            keep_f = 1 - local_masks[0][t].cpu()
            keep_b = 1 - local_masks[0][t + 1].cpu()
            panels = (
                ('img/flow-f:gt-pred', flow_to_image(gt_flows_bi[0][0]).cpu(),
                 flow_to_image(pred_flows_bi[0][0]).cpu(), keep_f),
                ('img/flow-b:gt-pred', flow_to_image(gt_flows_bi[1][0]).cpu(),
                 flow_to_image(pred_flows_bi[1][0]).cpu(), keep_b),
                ('img/edge-f:gt-pred', gt_edges_bi[0][0].cpu(),
                 pred_edges_bi[0][0].cpu(), keep_f),
                ('img/edge-b:gt-pred', gt_edges_bi[1][0].cpu(),
                 pred_edges_bi[1][0].cpu(), keep_b),
            )
            for tag, gt, pred, keep in panels:
                masked = (gt[t] * keep).to(gt)
                results = torch.cat([gt[t], masked, pred[t]], 1)
                self.gen_writer.add_image(tag, results, self.iteration)

    def _train_epoch(self, pbar):
        """Run training steps until the prefetcher is exhausted or the budget is exceeded.

        Per batch: take the precomputed bidirectional flows (or estimate
        them with ``self.fix_raft`` when absent) for the first
        ``num_local_frames`` frames, derive their edges with ``get_edges``,
        let the network complete flows and edges under the masks, and
        minimise ``flow_weight*flow + warp_weight*warp + edge_weight*edge``.
        ``warp_weight``/``edge_weight`` default to 0.01/1.0 when missing from
        ``config['losses']``.

        Args:
            pbar: progress bar advanced once per step on rank 0; unused
                on other ranks.
        """
        device = self.config['device']
        loss_cfg = self.config['losses']
        warp_weight = loss_cfg.get('warp_weight', 0.01)
        edge_weight = loss_cfg.get('edge_weight', 1.0)
        # forward_bidirect_flow is a method of the inner network; only a
        # DDP/DataParallel wrapper has `.module` (non-distributed runs don't).
        # NOTE(review): calling the inner module bypasses DDP.forward; confirm
        # gradient synchronisation is as intended in distributed runs.
        if isinstance(self.netG, (nn.DataParallel, DDP)):
            net = self.netG.module
        else:
            net = self.netG

        train_data = self.prefetcher.next()
        while train_data is not None:
            self.iteration += 1
            frames, masks, flows_f, flows_b, _ = train_data
            frames, masks = frames.to(device), masks.to(device)
            masks = masks.float()

            l_t = self.num_local_frames
            gt_local_frames = frames[:, :l_t, ...]
            local_masks = masks[:, :l_t, ...].contiguous()

            if self._is_missing_flow(flows_f) or self._is_missing_flow(flows_b):
                gt_flows_bi = self.fix_raft(gt_local_frames)
            else:
                gt_flows_bi = (flows_f.to(device), flows_b.to(device))

            gt_edges_bi = [self.get_edges(gt_flows_bi[0]),
                           self.get_edges(gt_flows_bi[1])]

            pred_flows_bi, pred_edges_bi = net.forward_bidirect_flow(gt_flows_bi, local_masks)

            self.optimG.zero_grad()

            flow_loss, warp_loss = self.flow_loss(pred_flows_bi, gt_flows_bi, local_masks, gt_local_frames)
            flow_loss = flow_loss * loss_cfg['flow_weight']
            warp_loss = warp_loss * warp_weight
            self.add_summary(self.gen_writer, 'loss/flow_loss', flow_loss.item())
            self.add_summary(self.gen_writer, 'loss/warp_loss', warp_loss.item())

            edge_loss = self.edge_loss(pred_edges_bi, gt_edges_bi, local_masks)
            edge_loss = edge_loss * edge_weight
            self.add_summary(self.gen_writer, 'loss/edge_loss', edge_loss.item())

            loss = flow_loss + warp_loss + edge_loss
            loss.backward()
            self.optimG.step()
            self.update_learning_rate()

            if self.iteration % 200 == 0 and self.gen_writer is not None:
                self._log_visuals(gt_flows_bi, pred_flows_bi, gt_edges_bi,
                                  pred_edges_bi, local_masks)

            if self.config['global_rank'] == 0:
                pbar.update(1)
                pbar.set_description((f"flow: {flow_loss.item():.3f}; "
                                      f"warp: {warp_loss.item():.3f}; "
                                      f"edge: {edge_loss.item():.3f}; "
                                      f"lr: {self.get_lr()}"))

                if self.iteration % self.train_args['log_freq'] == 0:
                    logging.info(f"[Iter {self.iteration}] "
                                 f"flow: {flow_loss.item():.4f}; "
                                 f"warp: {warp_loss.item():.4f}")

            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration))

            # same stop rule as train(): stop once the budget is exceeded
            if self.iteration > self.train_args['iterations']:
                break

            train_data = self.prefetcher.next()