# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
dependencies = ["torch", "timm", "einops"]
import os
from typing import Dict, Any, Optional, Union, List
import warnings
import torch
from torch.hub import load_state_dict_from_url
from timm.models import clean_state_dict
from .radio.adaptor_registry import adaptor_registry
from .radio.common import DEFAULT_VERSION, RadioResource, RESOURCE_MAP
from .radio.enable_damp import configure_damp_from_args
from .radio.enable_spectral_reparam import disable_spectral_reparam, configure_spectral_reparam_from_args
from .radio.feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
from .radio.radio_model import RADIOModel, create_model_from_args
from .radio.input_conditioner import get_default_conditioner
from .radio.vitdet import apply_vitdet_arch, VitDetArgs
def radio_model(
version: str = "",
progress: bool = True,
    adaptor_names: Optional[Union[str, List[str]]] = None,
vitdet_window_size: Optional[int] = None,
return_checkpoint: bool = False,
    support_packing: bool = False,
**kwargs,
) -> RADIOModel:
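    """Construct a RADIO model from a pretrained checkpoint.

    Args:
        version: A key into RESOURCE_MAP, or a path to a local checkpoint file.
            Defaults to DEFAULT_VERSION when empty.
        progress: Display a progress bar while downloading remote weights.
        adaptor_names: Optional teacher adaptor name(s) to attach to the model.
        vitdet_window_size: If set, apply ViTDet-style windowed attention with this window size.
        return_checkpoint: If True, return `(model, checkpoint)` instead of just the model.
        support_packing: Forwarded to the model constructor via the checkpoint args.
    """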
if not version:
version = DEFAULT_VERSION
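
    # `version` may name a bundled resource (RESOURCE_MAP) or point at a local checkpoint file.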
if os.path.isfile(version):
chk = torch.load(version, map_location="cpu", weights_only=False)
resource = RadioResource(version, patch_size=None, max_resolution=None, preferred_resolution=None)
else:
resource = RESOURCE_MAP[version]
chk = load_state_dict_from_url(
resource.url, progress=progress, map_location="cpu", weights_only=False,
)
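
    # Prefer the EMA weights when the checkpoint provides them; in that case the stored args are
    # updated so the spectral-reparametrization handling below is skipped for those weights.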
if "state_dict_ema" in chk:
state_dict = chk["state_dict_ema"]
chk['args'].spectral_reparam = False
else:
state_dict = chk["state_dict"]
args = chk["args"]
args.support_packing = support_packing
mod = create_model_from_args(args)
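
    # The backbone weights are stored under the "base_model." prefix in the checkpoint.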
mod_state_dict = get_prefix_state_dict(state_dict, "base_model.")
if args.spectral_reparam:
configure_spectral_reparam_from_args(mod, args, state_dict_guidance=mod_state_dict)
if getattr(args, 'damp', None):
configure_damp_from_args(mod, args)
state_dict = clean_state_dict(state_dict)
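    # Load non-strictly and surface any key mismatches as warnings instead of hard errors.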
key_warn = mod.load_state_dict(mod_state_dict, strict=False)
if key_warn.missing_keys:
warnings.warn(f'Missing keys in state dict: {key_warn.missing_keys}')
if key_warn.unexpected_keys:
warnings.warn(f'Unexpected keys in state dict: {key_warn.unexpected_keys}')
if chk['args'].spectral_reparam:
        # Spectral reparametrization uses PyTorch's "parametrizations" API. Instead of certain
        # Linear layers owning a plain `weight` tensor, the weight becomes a dynamically computed
        # function of an underlying parameter. During training this helps stabilize the model,
        # but it isn't necessary for downstream use. Disabling it here means that rather than
        # recomputing `w' = f(w)` on every forward pass, we compute it once, during this function
        # call, and replace the parametrization with the realized weights. This makes the model
        # run faster and use less memory.
disable_spectral_reparam(mod)
chk['args'].spectral_reparam = False
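
    # Restore the input conditioner (image preprocessing/normalization) stored with the checkpoint.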
conditioner = get_default_conditioner()
conditioner.load_state_dict(get_prefix_state_dict(state_dict, "input_conditioner."))
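    # Cast the model and conditioner to the dtype recorded in the checkpoint (float32 by default).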
dtype = getattr(chk['args'], 'dtype', torch.float32)
mod.to(dtype=dtype)
conditioner.dtype = dtype
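
    # When each teacher has its own CLS token, keep one summary-token index per unique teacher
    # that contributes a summary; otherwise there is only a single summary token at index 0.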
cls_token_per_teacher = getattr(chk['args'], 'cls_token_per_teacher', True)
if cls_token_per_teacher:
name_to_idx_map = dict()
for i, t in enumerate(chk['args'].teachers):
if t.get('use_summary', True):
name = t['name']
if name not in name_to_idx_map:
name_to_idx_map[name] = i
summary_idxs = torch.tensor(sorted(name_to_idx_map.values()), dtype=torch.int64)
else:
summary_idxs = torch.tensor([0], dtype=torch.int64)
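
    # Optionally build adaptors (teacher-specific heads) for the requested teacher names.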
if adaptor_names is None:
adaptor_names = []
elif isinstance(adaptor_names, str):
adaptor_names = [adaptor_names]
teachers = chk["args"].teachers
adaptors = dict()
for adaptor_name in adaptor_names:
for tidx, tconf in enumerate(teachers):
if tconf["name"] == adaptor_name:
break
else:
            raise ValueError(f'Unable to find adaptor "{adaptor_name}". Known names: {[t["name"] for t in teachers]}')
ttype = tconf["type"]
pf_idx_head = f'_heads.{tidx}'
pf_name_head = f'_heads.{adaptor_name}'
pf_idx_feat = f'_feature_projections.{tidx}'
pf_name_feat = f'_feature_projections.{adaptor_name}'
adaptor_state = dict()
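        # Head and feature-projection weights may be keyed by teacher index or by teacher name;
        # remap both forms onto the "summary." / "feature." key prefixes used for the adaptor state.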
for k, v in state_dict.items():
if k.startswith(pf_idx_head):
adaptor_state['summary' + k[len(pf_idx_head):]] = v
elif k.startswith(pf_name_head):
adaptor_state['summary' + k[len(pf_name_head):]] = v
elif k.startswith(pf_idx_feat):
adaptor_state['feature' + k[len(pf_idx_feat):]] = v
elif k.startswith(pf_name_feat):
adaptor_state['feature' + k[len(pf_name_feat):]] = v
adaptor = adaptor_registry.create_adaptor(ttype, chk["args"], tconf, adaptor_state)
adaptor.head_idx = tidx if cls_token_per_teacher else 0
adaptors[adaptor_name] = adaptor
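
    # Optional output feature normalizers, restored only if present in the checkpoint.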
feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
feature_normalizer = None
if feat_norm_sd:
feature_normalizer = FeatureNormalizer(feat_norm_sd['mean'].shape[0], dtype=dtype)
feature_normalizer.load_state_dict(feat_norm_sd)
inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')
inter_feature_normalizer = None
if inter_feat_norm_sd:
inter_feature_normalizer = IntermediateFeatureNormalizer(
*inter_feat_norm_sd['means'].shape[:2],
rot_per_layer=inter_feat_norm_sd['rotation'].ndim == 3,
dtype=dtype
)
inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)
radio = RADIOModel(
mod,
conditioner,
summary_idxs=summary_idxs,
patch_size=resource.patch_size,
max_resolution=resource.max_resolution,
window_size=vitdet_window_size,
preferred_resolution=resource.preferred_resolution,
adaptors=adaptors,
feature_normalizer=feature_normalizer,
inter_feature_normalizer=inter_feature_normalizer,
)
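
    # Optionally rewrite the backbone to use ViTDet-style windowed attention.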
if vitdet_window_size is not None:
apply_vitdet_arch(
mod,
VitDetArgs(
vitdet_window_size,
radio.num_summary_tokens,
num_windowed=resource.vitdet_num_windowed,
num_global=resource.vitdet_num_global,
),
)
if return_checkpoint:
return radio, chk
return radio
def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):
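    """Return the entries of `state_dict` whose keys start with `prefix`, with the prefix stripped."""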
mod_state_dict = {
k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)
}
return mod_state_dict
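

# Example usage (illustrative sketch, assuming this file serves as the repository's torch.hub
# `hubconf.py`; the repository path and version string below are assumptions):
#
#   model = torch.hub.load("NVlabs/RADIO", "radio_model", version="radio_v2.5-l", progress=True)
#   model.eval()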