# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

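# Python packages required by this torch.hub entrypoint module (per the hubconf.py convention).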
dependencies = ["torch", "timm", "einops"]

import os
from typing import Dict, Any, Optional, Union, List, Tuple
import warnings

import torch
from torch.hub import load_state_dict_from_url

from timm.models import clean_state_dict

from .radio.adaptor_registry import adaptor_registry
from .radio.common import DEFAULT_VERSION, RadioResource, RESOURCE_MAP
from .radio.enable_damp import configure_damp_from_args
from .radio.enable_spectral_reparam import disable_spectral_reparam, configure_spectral_reparam_from_args
from .radio.feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
from .radio.radio_model import RADIOModel, create_model_from_args
from .radio.input_conditioner import get_default_conditioner
from .radio.vitdet import apply_vitdet_arch, VitDetArgs


def radio_model(
    version: str = "",
    progress: bool = True,
    adaptor_names: Optional[Union[str, List[str]]] = None,
    vitdet_window_size: Optional[int] = None,
    return_checkpoint: bool = False,
    support_packing: bool = False,
    **kwargs,
) -> Union[RADIOModel, Tuple[RADIOModel, Dict[str, Any]]]:
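    """Create a RADIO model, loading weights from a named release or a local checkpoint.

    Args:
        version: A key into RESOURCE_MAP, or a path to a local checkpoint file.
            Defaults to DEFAULT_VERSION when empty.
        progress: Whether to display a progress bar while downloading remote weights.
        adaptor_names: Name(s) of teacher adaptors to attach to the returned model.
        vitdet_window_size: If set, convert the backbone to a ViTDet-style architecture
            using this attention window size.
        return_checkpoint: If True, also return the raw checkpoint dictionary.
        support_packing: Stored on the checkpoint args before the backbone is created
            (intended to enable sequence-packing support).

    Example (a minimal sketch; the hub repo path and version string are illustrative):
        >>> model = torch.hub.load('NVlabs/RADIO', 'radio_model', version='radio_v2.5-l')
        >>> model.eval()
    """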
    if not version:
        version = DEFAULT_VERSION

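    # `version` may be either a path to a local checkpoint file or a key into RESOURCE_MAP.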
    if os.path.isfile(version):
        chk = torch.load(version, map_location="cpu", weights_only=False)
        resource = RadioResource(version, patch_size=None, max_resolution=None, preferred_resolution=None)
    else:
        resource = RESOURCE_MAP[version]
        chk = load_state_dict_from_url(
            resource.url, progress=progress, map_location="cpu", weights_only=False,
        )

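    # Prefer the EMA (exponential moving average) weights when the checkpoint provides them.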
    if "state_dict_ema" in chk:
        state_dict = chk["state_dict_ema"]
        chk['args'].spectral_reparam = False
    else:
        state_dict = chk["state_dict"]

    args = chk["args"]
    args.support_packing = support_packing
    mod = create_model_from_args(args)

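    # The backbone weights are stored under the "base_model." prefix in the checkpoint.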
    mod_state_dict = get_prefix_state_dict(state_dict, "base_model.")

    if args.spectral_reparam:
        configure_spectral_reparam_from_args(mod, args, state_dict_guidance=mod_state_dict)

    if getattr(args, 'damp', None):
        configure_damp_from_args(mod, args)

    state_dict = clean_state_dict(state_dict)

    key_warn = mod.load_state_dict(mod_state_dict, strict=False)
    if key_warn.missing_keys:
        warnings.warn(f'Missing keys in state dict: {key_warn.missing_keys}')
    if key_warn.unexpected_keys:
        warnings.warn(f'Unexpected keys in state dict: {key_warn.unexpected_keys}')

    if chk['args'].spectral_reparam:
        # Spectral reparametrization uses PyTorch's "parametrizations" API. The idea behind
        # the method is that instead of there being a `weight` tensor for certain Linear layers
        # in the model, we make it a dynamically computed function. During training, this
        # helps stabilize the model. However, for downstream use cases, it shouldn't be necessary.
        # Disabling it here means that instead of recomputing `w' = f(w)` on every forward pass, we
        # compute it once, during this function call, and replace the parametrization with the realized weights.
        # This makes the model run faster, and also use less memory.
        disable_spectral_reparam(mod)
        chk['args'].spectral_reparam = False

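    # Restore the input conditioner state stored under the "input_conditioner." prefix.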
    conditioner = get_default_conditioner()
    conditioner.load_state_dict(get_prefix_state_dict(state_dict, "input_conditioner."))

    dtype = getattr(chk['args'], 'dtype', torch.float32)
    mod.to(dtype=dtype)
    conditioner.dtype = dtype

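    # When each teacher has its own summary (CLS) token, record the index of the first teacher
    # entry per distinct name that uses a summary; otherwise there is a single shared summary token.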
    cls_token_per_teacher = getattr(chk['args'], 'cls_token_per_teacher', True)
    if cls_token_per_teacher:
        name_to_idx_map = dict()
        for i, t in enumerate(chk['args'].teachers):
            if t.get('use_summary', True):
                name = t['name']
                if name not in name_to_idx_map:
                    name_to_idx_map[name] = i
        summary_idxs = torch.tensor(sorted(name_to_idx_map.values()), dtype=torch.int64)
    else:
        summary_idxs = torch.tensor([0], dtype=torch.int64)

    if adaptor_names is None:
        adaptor_names = []
    elif isinstance(adaptor_names, str):
        adaptor_names = [adaptor_names]

    teachers = chk["args"].teachers
    adaptors = dict()
    for adaptor_name in adaptor_names:
        for tidx, tconf in enumerate(teachers):
            if tconf["name"] == adaptor_name:
                break
        else:
            raise ValueError(f'Unable to find adaptor "{adaptor_name}". Known names: {[t["name"] for t in teachers]}')

        ttype = tconf["type"]

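        # An adaptor's weights may be stored under either the teacher index or the teacher name,
        # for both the summary head and the feature projection.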
        pf_idx_head = f'_heads.{tidx}'
        pf_name_head = f'_heads.{adaptor_name}'
        pf_idx_feat = f'_feature_projections.{tidx}'
        pf_name_feat = f'_feature_projections.{adaptor_name}'

        adaptor_state = dict()
        for k, v in state_dict.items():
            if k.startswith(pf_idx_head):
                adaptor_state['summary' + k[len(pf_idx_head):]] = v
            elif k.startswith(pf_name_head):
                adaptor_state['summary' + k[len(pf_name_head):]] = v
            elif k.startswith(pf_idx_feat):
                adaptor_state['feature' + k[len(pf_idx_feat):]] = v
            elif k.startswith(pf_name_feat):
                adaptor_state['feature' + k[len(pf_name_feat):]] = v

        adaptor = adaptor_registry.create_adaptor(ttype, chk["args"], tconf, adaptor_state)
        adaptor.head_idx = tidx if cls_token_per_teacher else 0
        adaptors[adaptor_name] = adaptor

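    # Restore the optional global and intermediate feature normalizers when their statistics
    # are present in the checkpoint.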
    feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
    feature_normalizer = None
    if feat_norm_sd:
        feature_normalizer = FeatureNormalizer(feat_norm_sd['mean'].shape[0], dtype=dtype)
        feature_normalizer.load_state_dict(feat_norm_sd)

    inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')
    inter_feature_normalizer = None
    if inter_feat_norm_sd:
        inter_feature_normalizer = IntermediateFeatureNormalizer(
            *inter_feat_norm_sd['means'].shape[:2],
            rot_per_layer=inter_feat_norm_sd['rotation'].ndim == 3,
            dtype=dtype
        )
        inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)

    radio = RADIOModel(
        mod,
        conditioner,
        summary_idxs=summary_idxs,
        patch_size=resource.patch_size,
        max_resolution=resource.max_resolution,
        window_size=vitdet_window_size,
        preferred_resolution=resource.preferred_resolution,
        adaptors=adaptors,
        feature_normalizer=feature_normalizer,
        inter_feature_normalizer=inter_feature_normalizer,
    )

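    # Optionally rewrite the backbone into a ViTDet-style architecture with windowed attention.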
    if vitdet_window_size is not None:
        apply_vitdet_arch(
            mod,
            VitDetArgs(
                vitdet_window_size,
                radio.num_summary_tokens,
                num_windowed=resource.vitdet_num_windowed,
                num_global=resource.vitdet_num_global,
            ),
        )

    if return_checkpoint:
        return radio, chk
    return radio


def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):
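    """Return the entries of `state_dict` whose keys begin with `prefix`, with the prefix stripped."""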
    mod_state_dict = {
        k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)
    }
    return mod_state_dict