"""Module for training utilities.

This module contains utility functions for training models. For example, saving model checkpoints.
"""

import logging
import os
from typing import Any, Union

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def maybe_unwrap_dist_model(model: nn.Module, use_distributed: bool) -> nn.Module:
    """Return the underlying model if it is wrapped in a distributed wrapper (e.g. DDP)."""
    return model.module if use_distributed else model


def get_state_dict(model: nn.Module, drop_untrained_params: bool = True) -> dict[str, Any]:
    """Get the model state dict.

    Optionally drop parameters that do not require gradients (e.g. frozen weights),
    keeping only trainable parameters and buffers.

    Args:
        model: Model to get the state dict from.
        drop_untrained_params: Whether to drop parameters that do not require gradients.

    Returns:
        dict: Model state dict.
    """
    if not drop_untrained_params:
        return model.state_dict()

    param_grad_dict = {k: v.requires_grad for k, v in model.named_parameters()}
    state_dict = model.state_dict()

    for k in list(state_dict.keys()):
        # Drop parameters that do not require gradients; buffers are not in
        # param_grad_dict and are therefore kept.
        if k in param_grad_dict and not param_grad_dict[k]:
            del state_dict[k]

    return state_dict


def save_model_checkpoint(
    model: nn.Module,
    save_path: Union[str, os.PathLike],
    use_distributed: bool = False,
    drop_untrained_params: bool = False,
    **objects_to_save,
) -> None:
    """Save model checkpoint.

    Args:
        model (nn.Module): Model to save
        output_dir (str): Output directory to save checkpoint
        use_distributed (bool): Whether the model is distributed, if so, unwrap it. Default: False.
        is_best (bool): Whether the model is the best in the training run. Default: False.
        drop_untrained_params (bool): Whether to drop untrained parameters to save. Default: True.
        prefix (str): Prefix to add to the checkpoint file name. Default: "".
        extention (str): Extension to use for the checkpoint file. Default: "pth".
        **objects_to_save: Additional objects to save, e.g. optimizer state dict, etc.
    """
    if not os.path.exists(os.path.dirname(save_path)):
        raise FileNotFoundError(f"Directory {os.path.dirname(save_path)} does not exist.")

    model_no_ddp = maybe_unwrap_dist_model(model, use_distributed)
    state_dict = get_state_dict(model_no_ddp, drop_untrained_params)
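    # The checkpoint is a plain dict: model weights under "model", plus any extra
    # objects (optimizer state, epoch counter, etc.) passed by the caller.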
    save_obj = {
        "model": state_dict,
        **objects_to_save,
    }

    logger.info("Saving checkpoint to {}.".format(save_path))
    torch.save(save_obj, save_path)
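

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the toy model, paths, and the
    # "epoch" entry below are assumptions for demonstration, not part of this module.
    import tempfile

    toy_model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    for p in toy_model[0].parameters():
        # Freeze the first layer; with drop_untrained_params=True its weights
        # are omitted from the saved state dict.
        p.requires_grad = False

    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = os.path.join(tmp_dir, "checkpoint.pth")
        save_model_checkpoint(
            toy_model,
            ckpt_path,
            use_distributed=False,
            drop_untrained_params=True,
            epoch=0,
        )
        loaded = torch.load(ckpt_path, map_location="cpu")
        # Only the trainable layer's weights remain under "model".
        print(sorted(loaded["model"].keys()), loaded["epoch"])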