from pathlib import Path
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.torch_utilities import (
    load_pretrained_model,
    merge_matched_keys,
    create_mask_from_length,
    loss_with_mask,
    create_alignment_path,
)


class LoadPretrainedBase(nn.Module):
    """Mixin that loads pretrained checkpoints through an overridable hook.

    Subclasses customise how checkpoint tensors are mapped onto the model by
    overriding :meth:`process_state_dict`.
    """

    def process_state_dict(
        self,
        model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Custom processing function of each model that transforms `state_dict`
        loaded from a checkpoint into a state that can be used in
        `load_state_dict`. By default, `merge_matched_keys` is used to update
        parameters with matched names and shapes.

        Args:
            model_dict: The state dict of the current model, which is going to
                load pretrained parameters.
            state_dict: A dictionary of parameters from a pre-trained model.

        Returns:
            dict[str, torch.Tensor]: The updated state dict, where parameters
            with matched keys and shapes are updated with values in
            `state_dict`.
        """
        return merge_matched_keys(model_dict, state_dict)

    def load_pretrained(self, ckpt_path: str | Path) -> None:
        """Load pretrained weights from `ckpt_path` into this module.

        Delegates to `load_pretrained_model`, passing `process_state_dict` as
        the hook used to post-process the loaded state dict.
        """
        load_pretrained_model(
            self, ckpt_path, state_dict_process_fn=self.process_state_dict
        )


class CountParamsBase(nn.Module):
    """Mixin that reports how many parameter elements a module holds."""

    def count_params(self) -> tuple[int, int]:
        """Count parameter elements of this module.

        Returns:
            tuple[int, int]: ``(num_params, trainable_params)``. The first is
            the total number of elements over all parameters. The second is
            the number of elements over parameters with
            ``requires_grad=True``.
        """
        num_params = 0
        trainable_params = 0
        for param in self.parameters():
            num_params += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        return num_params, trainable_params


class SaveTrainableParamsBase(nn.Module):
    """Mixin that restricts the expected checkpoint content to trainable state.

    `param_names_to_save` lists trainable parameters plus persistent buffers.
    `load_state_dict` verifies those names are present before delegating to
    `nn.Module.load_state_dict`.
    """

    @property
    def param_names_to_save(self) -> list[str]:
        """Names of the tensors this module expects to find in a checkpoint.

        Every parameter with ``requires_grad=True`` comes first, followed by
        every persistent buffer. Buffers registered with ``persistent=False``
        are skipped. ``nn.Module.state_dict`` never contains them, so
        requiring them would make strict loading fail on a complete
        checkpoint.
        """
        names = [
            name for name, param in self.named_parameters() if param.requires_grad
        ]
        # Fully-qualified names of non-persistent buffers. `named_modules` uses
        # the same de-duplication and prefixing as `named_buffers`, so the
        # names line up.
        non_persistent = {
            f"{module_name}.{buffer_name}" if module_name else buffer_name
            for module_name, module in self.named_modules()
            for buffer_name in getattr(module, "_non_persistent_buffers_set", ())
        }
        names.extend(
            name for name, _ in self.named_buffers() if name not in non_persistent
        )
        return names

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load `state_dict` after checking that all savable names are present.

        Args:
            state_dict: Mapping from parameter/buffer names to tensors.
            strict: If True, raise when any name in `param_names_to_save` is
                missing from `state_dict`. If False, print a warning and
                continue. It is also forwarded to `nn.Module.load_state_dict`.
            **kwargs: Extra keyword arguments forwarded unchanged to
                `nn.Module.load_state_dict` (e.g. `assign` on newer PyTorch).

        Returns:
            The value returned by `nn.Module.load_state_dict`.

        Raises:
            RuntimeError: If `strict` is True and required names are missing.
                It subclasses `Exception`, so callers catching the previously
                raised bare `Exception` still handle it.
        """
        # Hoisted: the property walks the whole module tree on every access.
        missing_keys = [
            key for key in self.param_names_to_save if key not in state_dict
        ]
        if missing_keys:
            if strict:
                raise RuntimeError(
                    f"{missing_keys} not found in either pre-trained models (e.g. BERT) or resumed checkpoints (e.g. epoch_40/model.pt)"
                )
            print(f"Warning: missing keys {missing_keys}, skipping them.")
        return super().load_state_dict(state_dict, strict, **kwargs)