# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: Keras-style layer-by-layer summary for PyTorch models.
import sys
from collections import OrderedDict

import numpy as np
import torch

# Modules that contain submodules but should still be reported as a single layer.
layer_modules = (torch.nn.MultiheadAttention,)
def summary(model, input_data=None, input_data_args=None, input_shape=None, input_dtype=torch.FloatTensor,
            batch_size=-1,
            *args, **kwargs):
    """
    Give example input data in at least one of the ways below:
    ① input_data ---> model.forward(input_data)
    ② input_data_args ---> model.forward(*input_data_args)
    ③ input_shape & input_dtype ---> model.forward(*[torch.rand(2, *size).type(input_dtype) for size in input_shape])
    """
    hooks = []
    summary = OrderedDict()

    def register_hook(module):
        def hook(module, inputs, outputs):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)
            key = "%s-%i" % (class_name, module_idx + 1)

            info = OrderedDict()
            info["id"] = id(module)
            if isinstance(outputs, (list, tuple)):
                try:
                    info["out"] = [batch_size] + list(outputs[0].size())[1:]
                except AttributeError:
                    # PackedSequence (from pack_padded_sequence / pad_packed_sequence)
                    # stores its features in the `data` attribute
                    info["out"] = [batch_size] + list(outputs[0].data.size())[1:]
            else:
                info["out"] = [batch_size] + list(outputs.size())[1:]

            info["params_nt"], info["params"] = 0, 0
            for name, param in module.named_parameters():
                info["params"] += param.nelement() * param.requires_grad
                info["params_nt"] += param.nelement() * (not param.requires_grad)

            summary[key] = info

        # ignore Sequential, ModuleList, and other containers
        if isinstance(module, layer_modules) or not module._modules:
            hooks.append(module.register_forward_hook(hook))
    model.apply(register_hook)

    # a single shape may be given for a network with multiple inputs
    if isinstance(input_shape, tuple):
        input_shape = [input_shape]

    if input_data is not None:
        x = [input_data]
    elif input_shape is not None:
        # batch_size of 2 for batchnorm
        x = [torch.rand(2, *size).type(input_dtype) for size in input_shape]
    elif input_data_args is not None:
        x = input_data_args
    else:
        x = []
    try:
        with torch.no_grad():
            model(*x) if not (kwargs or args) else model(*x, *args, **kwargs)
    except Exception:
        # This can be useful for debugging
        print("Failed to run summary...")
        raise
    finally:
        for hook in hooks:
            hook.remove()
    summary_logs = []
    summary_logs.append("--------------------------------------------------------------------------")
    line_new = "{:<30} {:>20} {:>20}".format("Layer (type)", "Output Shape", "Param #")
    summary_logs.append(line_new)
    summary_logs.append("==========================================================================")

    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        # layer, output_shape, params
        line_new = "{:<30} {:>20} {:>20}".format(
            layer,
            str(summary[layer]["out"]),
            "{0:,}".format(summary[layer]["params"] + summary[layer]["params_nt"])
        )
        total_params += (summary[layer]["params"] + summary[layer]["params_nt"])
        total_output += np.prod(summary[layer]["out"])
        trainable_params += summary[layer]["params"]
        summary_logs.append(line_new)
    # assume 4 bytes/number (float32)
    if input_data is not None:
        if torch.is_tensor(input_data):
            # sys.getsizeof would only measure the Python wrapper, not the tensor storage
            total_input_size = abs(input_data.nelement() * input_data.element_size() / (1024 ** 2.))
        else:
            total_input_size = abs(sys.getsizeof(input_data) / (1024 ** 2.))
    elif input_shape is not None:
        total_input_size = abs(np.prod(input_shape) * batch_size * 4. / (1024 ** 2.))
    else:
        total_input_size = 0.0

    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size
| summary_logs.append("==========================================================================") | |
| summary_logs.append("Total params: {0:,}".format(total_params)) | |
| summary_logs.append("Trainable params: {0:,}".format(trainable_params)) | |
| summary_logs.append("Non-trainable params: {0:,}".format(total_params - trainable_params)) | |
| summary_logs.append("--------------------------------------------------------------------------") | |
| summary_logs.append("Input size (MB): %0.6f" % total_input_size) | |
| summary_logs.append("Forward/backward pass size (MB): %0.6f" % total_output_size) | |
| summary_logs.append("Params size (MB): %0.6f" % total_params_size) | |
| summary_logs.append("Estimated Total Size (MB): %0.6f" % total_size) | |
| summary_logs.append("--------------------------------------------------------------------------") | |
| summary_info = "\n".join(summary_logs) | |
| print(summary_info) | |
| return summary_info | |
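

# A minimal usage sketch, not part of the original module: the toy network and
# the (3, 224, 224) input shape below are illustrative assumptions, demonstrating
# the three equivalent ways of supplying example inputs from the docstring.
if __name__ == "__main__":
    net = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
        torch.nn.ReLU(),
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(),
        torch.nn.Linear(16, 10),
    )
    # ① pass a ready-made input tensor
    summary(net, input_data=torch.rand(2, 3, 224, 224))
    # ② pass a list of positional forward() arguments
    summary(net, input_data_args=[torch.rand(2, 3, 224, 224)])
    # ③ or let summary() build the example batch from a shape
    summary(net, input_shape=(3, 224, 224))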