|
|
from pathlib import Path
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.torch_utilities import (
    load_pretrained_model, merge_matched_keys, create_mask_from_length,
    loss_with_mask, create_alignment_path
)
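
# `merge_matched_keys`, imported above from `utils.torch_utilities`, is the
# default strategy `LoadPretrainedBase` uses to merge a pretrained checkpoint
# into the current model. Its real implementation lives elsewhere in the repo;
# as a rough mental model only (an assumption, not the actual helper), it
# copies a checkpoint tensor into the model's state dict whenever both the key
# and the tensor shape match, and leaves every other entry untouched:
#
#     def merge_matched_keys(model_dict, state_dict):
#         for key, value in state_dict.items():
#             if key in model_dict and model_dict[key].shape == value.shape:
#                 model_dict[key] = value
#         return model_dict
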
class LoadPretrainedBase(nn.Module):
    def process_state_dict(
        self, model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor]
    ):
        """
        Custom processing function of each model that transforms the
        `state_dict` loaded from a checkpoint into a form that can be used
        by `load_state_dict`. By default, `merge_matched_keys` is used to
        update parameters with matched names and shapes.

        Args:
            model_dict:
                The state dict of the current model, which is going to load
                pretrained parameters.
            state_dict:
                A dictionary of parameters from a pre-trained model.

        Returns:
            dict[str, torch.Tensor]:
                The updated state dict, where parameters with matched keys
                and shapes are updated with values in `state_dict`.
        """
        state_dict = merge_matched_keys(model_dict, state_dict)
        return state_dict

    def load_pretrained(self, ckpt_path: str | Path):
        load_pretrained_model(
            self, ckpt_path, state_dict_process_fn=self.process_state_dict
        )
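
# A minimal sketch of how a subclass can customize checkpoint loading by
# overriding `process_state_dict`. The "encoder." key prefix below is purely
# hypothetical, for illustration only:
#
#     class MyEncoder(LoadPretrainedBase):
#         def process_state_dict(self, model_dict, state_dict):
#             # Strip a hypothetical "encoder." prefix from checkpoint keys,
#             # then fall back to the default matched-key merge.
#             state_dict = {
#                 k.removeprefix("encoder."): v
#                 for k, v in state_dict.items()
#             }
#             return super().process_state_dict(model_dict, state_dict)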


class CountParamsBase(nn.Module):
    def count_params(self):
        # Return (total parameter count, trainable parameter count).
        num_params = 0
        trainable_params = 0
        for param in self.parameters():
            num_params += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        return num_params, trainable_params
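
# Example usage (with a hypothetical model that inherits CountParamsBase):
#
#     total, trainable = model.count_params()
#     print(f"Trainable params: {trainable} / {total}")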


class SaveTrainableParamsBase(nn.Module):
    @property
    def param_names_to_save(self):
        # Trainable parameters plus all buffers: buffers (e.g. running
        # statistics) are never trainable but still have to be persisted.
        names = []
        for name, param in self.named_parameters():
            if param.requires_grad:
                names.append(name)
        for name, _ in self.named_buffers():
            names.append(name)
        return names

    def load_state_dict(self, state_dict, strict=True):
        # Check that every parameter this module is responsible for saving
        # is present before delegating to the default loading logic.
        missing_keys = []
        for key in self.param_names_to_save:
            if key not in state_dict:
                missing_keys.append(key)

        if strict and len(missing_keys) > 0:
            raise RuntimeError(
                f"{missing_keys} not found in either pre-trained models "
                "(e.g. BERT) or resumed checkpoints (e.g. epoch_40/model.pt)"
            )
        elif len(missing_keys) > 0:
            print(f"Warning: missing keys {missing_keys}, skipping them.")

        return super().load_state_dict(state_dict, strict=strict)
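
# The saving side this class implies, as a sketch (the actual call site in
# the training loop is an assumption, not shown in this file):
#
#     to_save = {
#         name: tensor
#         for name, tensor in model.state_dict().items()
#         if name in model.param_names_to_save
#     }
#     torch.save(to_save, "epoch_40/model.pt")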
|
|
|