import collections
import datetime
import glob
import json
import logging
import os

import omegaconf
import torch
import wandb
from PIL import Image


class BaseTimer:
    """One-shot GPU timer: starts on construction, `stop()` returns seconds."""

    def __init__(self):
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)
        self.start.record()

    def stop(self):
        self.end.record()
        torch.cuda.synchronize()
        # Event.elapsed_time() returns milliseconds; convert to seconds.
        return self.start.elapsed_time(self.end) / 1000


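# A minimal usage sketch; it requires a CUDA device, since the timer is built
# on torch.cuda.Event, and model(batch) stands in for any workload:
#
#     timer = BaseTimer()
#     model(batch)
#     seconds = timer.stop()

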
class Timer:
    """Context-manager variant of BaseTimer.

    If `info` is given, the measured duration is stored under the key
    f"duration/{log_event}".
    """

    def __init__(self, info=None, log_event=None):
        self.info = info
        self.log_event = log_event

    def __enter__(self):
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)
        self.start.record()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end.record()
        torch.cuda.synchronize()
        self.duration = self.start.elapsed_time(self.end) / 1000
        # Compare against None explicitly: an empty dict-like sink is falsy,
        # and the first measurement would otherwise be dropped.
        if self.info is not None:
            self.info[f"duration/{self.log_event}"] = self.duration


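# A minimal usage sketch; "iter_train" is an arbitrary event name and
# train_step() a hypothetical workload:
#
#     info = StreamingMeans()
#     with Timer(info, "iter_train"):
#         train_step()
#     # info["duration/iter_train"] now holds the elapsed seconds.

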
class _StreamingMean:
    """Running mean that can absorb further (mean, counts) pairs."""

    def __init__(self, val=None, counts=None):
        if val is None:
            self.mean = 0.0
            self.counts = 0
        else:
            if isinstance(val, torch.Tensor):
                val = val.data.cpu().numpy()
            self.mean = val
            self.counts = counts if counts is not None else 1

    def update(self, mean, counts=1):
        if isinstance(mean, torch.Tensor):
            mean = mean.data.cpu().numpy()
        elif isinstance(mean, _StreamingMean):
            mean, counts = mean.mean, mean.counts * counts
        assert counts >= 0
        if counts == 0:
            return
        # Count-weighted combination of the stored mean and the incoming one.
        total = self.counts + counts
        self.mean = self.counts / total * self.mean + counts / total * mean
        self.counts = total

    def __add__(self, other):
        new = self.__class__(self.mean, self.counts)
        if isinstance(other, _StreamingMean):
            if other.counts == 0:
                return new
            new.update(other.mean, other.counts)
        else:
            new.update(other)
        return new


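# A worked example of the update rule: absorbing mean 5.0 (1 sample) into
# mean 2.0 (2 samples) gives (2/3) * 2.0 + (1/3) * 5.0 = 3.0 over 3 samples:
#
#     m = _StreamingMean(2.0, 2)
#     m.update(5.0)
#     assert abs(m.mean - 3.0) < 1e-9 and m.counts == 3

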
class StreamingMeans(collections.defaultdict):
    """Dict of named _StreamingMean accumulators, created on first access."""

    def __init__(self):
        super().__init__(_StreamingMean)

    def __setitem__(self, key, value):
        if isinstance(value, _StreamingMean):
            super().__setitem__(key, value)
        else:
            super().__setitem__(key, _StreamingMean(value))

    def update(self, *args, **kwargs):
        for_update = dict(*args, **kwargs)
        for k, v in for_update.items():
            self[k].update(v)

    def to_dict(self, prefix=""):
        return dict((prefix + k, v.mean) for k, v in self.items())

    def to_str(self):
        return ", ".join(f"{k} = {v:.3f}" for k, v in self.to_dict().items())


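# A minimal usage sketch; the keys are arbitrary, though ConsoleLogger's
# format_info expects the "group/name" pattern:
#
#     info = StreamingMeans()
#     info.update({"loss/train": 0.52, "duration/iter_train": 0.8})
#     info.update({"loss/train": 0.48})
#     print(info.to_str())  # loss/train = 0.500, duration/iter_train = 0.800

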
class ConsoleLogger:
    """Logs iteration/epoch summaries to the console in aligned columns."""

    def __init__(self, name):
        self.logger = logging.getLogger(name)
        self.logger.handlers = []
        self.logger.setLevel(logging.INFO)
        log_formatter = logging.Formatter(
            "%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
        )
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(log_formatter)
        self.logger.addHandler(console_handler)
        self.logger.propagate = False

    @staticmethod
    def format_info(info):
        if not info:
            return str(info)
        # Keys are expected to follow the "group/name" pattern; entries are
        # grouped by prefix and padded so the columns line up.
        log_groups = collections.defaultdict(dict)
        for k, v in info.to_dict().items():
            prefix, suffix = k.split("/", 1)
            log_groups[prefix][suffix] = f"{v:.3f}" if isinstance(v, float) else str(v)
        formatted_info = ""
        max_group_size = len(max(log_groups, key=len)) + 2
        max_k_size = max(len(max(g, key=len)) for g in log_groups.values()) + 1
        max_v_size = max(len(max(g.values(), key=len)) for g in log_groups.values()) + 1
        for group, group_info in log_groups.items():
            group_str = [
                f"{k:<{max_k_size}}={v:>{max_v_size}}" for k, v in group_info.items()
            ]
            max_g_size = len(max(group_str, key=len)) + 2
            group_str = "".join(f"{g:>{max_g_size}}" for g in group_str)
            formatted_info += f"\n{group + ':':<{max_group_size}}{group_str}"
        return formatted_info

    def log_iter(self, epoch_num, iter_num, num_iters, iter_info, event="epoch"):
        output_info = f"{event.upper()} {epoch_num}, ITER {iter_num}/{num_iters}:"
        output_info += self.format_info(iter_info)
        self.logger.info(output_info)

    def log_epoch(self, epoch_info, epoch_num):
        output_info = f"EPOCH {epoch_num}:"
        output_info += self.format_info(epoch_info)
        self.logger.info(output_info)


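# A minimal usage sketch, reusing the StreamingMeans example above; the
# output is a timestamped "EPOCH 3:" header followed by one aligned row per
# key prefix ("loss", "duration", ...):
#
#     logger = ConsoleLogger("train")
#     logger.log_epoch(info, epoch_num=3)

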
class WandbLogger:
    """Logs to Weights & Biases; expects the API key in $WANDB_KEY."""

    def __init__(self, config):
        wandb.login(key=os.environ["WANDB_KEY"].strip(), relogin=True)
        if config.train.resume_path == "":
            # Fresh run: register the config and snapshot the project source.
            config_for_logger = omegaconf.OmegaConf.to_container(config)
            self.wandb_args = {
                "id": wandb.util.generate_id(),
                "project": config.exp.wandb_project,
                "name": config.exp.name,
                "config": config_for_logger,
            }
            wandb.init(**self.wandb_args, resume="allow")

            run_dir = wandb.run.dir
            print("run_dir", run_dir)

            code = wandb.Artifact("project-source", type="code")
            added_dirs = set()
            for path in glob.glob("**/*.py", recursive=True):
                if path.startswith("wandb"):
                    continue
                dir_name = os.path.dirname(path)
                if dir_name:
                    # Add each directory once instead of once per file in it.
                    if dir_name not in added_dirs:
                        code.add_dir(dir_name, name=dir_name)
                        added_dirs.add(dir_name)
                else:
                    code.add_file(path, name=path)
            wandb.run.log_artifact(code)
        else:
            print(f"Resume training from {config.train.resume_path}")
            with open(config.train.resume_path, "r") as f:
                options = json.load(f)

            self.wandb_args = {
                "id": options["id"],
                "project": options["project"],
                "name": options["name"],
                "config": options["config"],
            }
            wandb.init(resume=True, **self.wandb_args)

    @staticmethod
    def log_epoch(iter_info, step):
        wandb.log(
            data={k: v.mean for k, v in iter_info.items()},
            step=step + 1,
            commit=True,
        )

    @staticmethod
    def log_special_pics(pics, captions, paths):
        to_log = {}
        for i, path in enumerate(paths):
            to_log[path] = wandb.Image(pics[i], caption=captions[path])
        wandb.log(to_log)


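# The resume branch above expects config.train.resume_path to point at a
# JSON file containing the four keys it reads; illustrative values only:
#
#     {"id": "abc123xy", "project": "my-project", "name": "run-1", "config": {}}

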
class BlankWandbLogger:
    """No-op drop-in for WandbLogger when wandb logging is disabled."""

    def __init__(self):
        self.wandb_args = None

    def log_epoch(self, *args, **kwargs):
        pass

    def log_special_pics(self, *args, **kwargs):
        pass


class TrainingLogger:
    """Combines console and (optional) wandb logging for a training run."""

    def __init__(self, config):
        self.console_logger = ConsoleLogger("")
        if config.exp.wandb:
            self.wandb_logger = WandbLogger(config)
        else:
            self.wandb_logger = BlankWandbLogger()

        self.training_steps = config.train.steps
        self.val_step = config.train.val_step

    def log_train_time_left(self, iter_info, step):
        float_iter_time = iter_info["duration/iter_train"].mean
        float_val_time = iter_info["duration/iter_val"].mean
        steps_left = self.training_steps - step
        # Remaining time: the remaining training iterations plus one
        # validation pass every `val_step` steps.
        time_left = str(
            datetime.timedelta(
                seconds=float_iter_time * steps_left
                + float_val_time * (steps_left // self.val_step)
            )
        )

        print()
        print(f"Step {step}/{self.training_steps}")
        print(f"Time left: {time_left}")
        print(f"Time per step: {float_iter_time:.3f}")
        print()

    def save_train_logs(self, iter_info, step):
        self.wandb_logger.log_epoch(iter_info, step)
        self.console_logger.log_epoch(iter_info, step)
        self.log_train_time_left(iter_info, step)

    def save_validation_logs(self, orig_pics, method_pics, captions, special_paths):
        # Place each original and generated image side by side before logging.
        log_pics = []
        for real_img, fake_img in zip(orig_pics, method_pics):
            concat_img = Image.new(
                "RGB", (real_img.width + fake_img.width, real_img.height)
            )
            concat_img.paste(real_img, (0, 0))
            concat_img.paste(fake_img, (real_img.width, 0))
            log_pics.append(concat_img)

        self.wandb_logger.log_special_pics(log_pics, captions, special_paths)


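# A minimal end-to-end sketch. The config fields are the ones this module
# reads; train_step() and the loop itself are hypothetical:
#
#     logger = TrainingLogger(config)  # needs config.exp.wandb,
#                                      # config.train.steps, config.train.val_step
#                                      # (and, for wandb, config.exp.name,
#                                      # config.exp.wandb_project,
#                                      # config.train.resume_path)
#     for step in range(config.train.steps):
#         info = StreamingMeans()
#         with Timer(info, "iter_train"):
#             train_step()
#         logger.save_train_logs(info, step)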