File size: 5,464 Bytes
626ec32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import torch
import math
import time
import numpy as np
class ModelWrapper(torch.nn.Module):
    """Image encoder followed by a linear classification head.

    ``forward`` calls ``model.encode_image`` and feeds the (optionally
    L2-normalised) features into ``classification_head``, a
    ``torch.nn.Linear(feature_dim, num_classes)``.

    Args:
        model: Backbone module exposing ``encode_image``. If it has a
            ``transformer`` attribute, that attribute is deleted.
        feature_dim: Size of the last dimension of ``encode_image``'s output.
        num_classes: Number of output logits.
        normalize: If True, divide features by their L2 norm along the last
            dimension before the head.
        initial_weights: Optional head weight of shape
            ``(num_classes, feature_dim)``; it is cloned. When omitted the
            weight is Kaiming-uniform initialised. The bias always starts at 0.
        checkpoint_path: Optional path to a state dict that is loaded into
            ``model`` (the backbone, not the wrapper) with ``strict=False``.
            Any ``classification_head.*`` entries in it are discarded.
    """

    def __init__(self, model, feature_dim, num_classes, normalize=False, initial_weights=None, checkpoint_path=None):
        super(ModelWrapper, self).__init__()
        self.model = model
        self.classification_head = torch.nn.Linear(feature_dim, num_classes)
        self.normalize = normalize
        if initial_weights is None:
            initial_weights = torch.zeros_like(self.classification_head.weight)
            torch.nn.init.kaiming_uniform_(initial_weights, a=math.sqrt(5))
        self.classification_head.weight = torch.nn.Parameter(initial_weights.clone())
        self.classification_head.bias = torch.nn.Parameter(torch.zeros_like(self.classification_head.bias))
        # Note: modified. Get rid of the language part (forward only uses encode_image).
        if hasattr(self.model, 'transformer'):
            delattr(self.model, 'transformer')
        if checkpoint_path:
            print("Loading checkpoint", checkpoint_path)
            # Load onto the CPU so the checkpoint does not require the device it
            # was saved from; load_state_dict copies values into the parameters
            # wherever they live.
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # trusted checkpoint files.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            # The head was freshly built above, so never take it from the
            # checkpoint; tolerate checkpoints that have no head entries at all.
            checkpoint.pop('classification_head.weight', None)
            checkpoint.pop('classification_head.bias', None)
            # NOTE(review): keys are matched against the *backbone*. A checkpoint
            # saved from a whole ModelWrapper would carry a 'model.' prefix and
            # match nothing — confirm the expected checkpoint layout.
            result = model.load_state_dict(checkpoint, strict=False)
            # strict=False hides mismatches; surface them instead of silently
            # loading nothing.
            if result.unexpected_keys:
                print(
                    "Warning:", len(result.unexpected_keys),
                    "checkpoint keys were not used by the model, e.g.",
                    result.unexpected_keys[:5],
                )

    def forward(self, images, return_features=False):
        """Return logits for ``images``, plus the head's input features if requested.

        Returns ``logits``, or ``(logits, features)`` when ``return_features``
        is True. The features are normalised when ``self.normalize`` is set.
        """
        features = self.model.encode_image(images)
        if self.normalize:
            features = features / features.norm(dim=-1, keepdim=True)
        logits = self.classification_head(features)
        if return_features:
            return logits, features
        return logits
def get_model_from_sd(state_dict, base_model):
    """Wrap ``base_model`` in a ``ModelWrapper`` and load ``state_dict`` into it.

    The head size is read from ``state_dict['classification_head.weight']``
    (``(num_classes, feature_dim)``, as for ``torch.nn.Linear``). The wrapper
    is built with ``normalize=True``, its parameters are cast to float32, and
    ``state_dict`` is then loaded strictly.

    Returns:
        A ``torch.nn.DataParallel`` over all visible CUDA devices. Without CUDA
        the model stays on the CPU and DataParallel simply forwards calls to it.
    """
    num_classes, feature_dim = state_dict['classification_head.weight'].shape
    model = ModelWrapper(base_model, feature_dim, num_classes, normalize=True)
    for p in model.parameters():
        p.data = p.data.float()
    model.load_state_dict(state_dict)
    # Previously an unconditional .cuda(), which crashed on CPU-only machines.
    if torch.cuda.is_available():
        model = model.cuda()
    devices = list(range(torch.cuda.device_count()))
    return torch.nn.DataParallel(model, device_ids=devices)
def maybe_dictionarize_batch(batch):
    """Normalise a loader batch into a dict keyed by field name.

    A dict is returned as-is (the same object). A 2-element sequence becomes
    ``{'images', 'labels'}``; a 3-element one also gets ``'metadata'``.
    Any other length raises ``ValueError``.
    """
    if isinstance(batch, dict):
        return batch
    size = len(batch)
    if size not in (2, 3):
        raise ValueError(f'Unexpected number of elements: {size}')
    field_names = ('images', 'labels', 'metadata')[:size]
    return {name: batch[idx] for idx, name in enumerate(field_names)}
def test_model_on_dataset(model, dataset, device='cuda'):
    """Evaluate ``model`` on ``dataset`` and return top-1 accuracy as a fraction.

    Batches come from ``dataset.test_loader``, except for ``ImageNet2p``
    datasets, which are evaluated on ``dataset.train_loader`` (see the sanity
    check below). Optional dataset hooks are used when present:

    * ``project_logits(logits, device)`` remaps the model outputs;
    * ``project_labels(labels, device)`` remaps the labels;
    * ``accuracy(logits, labels, image_paths, None)`` returns
      ``(num_correct, num_total)`` and replaces the default argmax comparison.

    Progress is printed every 20 batches.

    Args:
        model: Module mapping images to logits (or to a list whose first
            element is the logits). It is switched to eval mode.
        dataset: Object providing the loaders and optional hooks above.
        device: Device inputs/labels are moved to and passed to the hooks.
            Defaults to ``'cuda'``, the previously hard-coded value.

    Returns:
        ``correct / n`` over all evaluated examples (ZeroDivisionError if
        nothing was evaluated).
    """
    model.eval()
    with torch.no_grad():
        correct, n = 0., 0.
        loader = dataset.test_loader
        if type(dataset).__name__ == 'ImageNet2p':
            loader = dataset.train_loader
            # assert to make sure the imagenet held-out minival logic is consistent across machines.
            # tested on a few machines but if this fails for you please submit an issue and we will resolve.
            assert dataset.train_dataset.__getitem__(dataset.sampler.indices[1000])['image_paths'].endswith('n01675722_4108.JPEG')

        project_logits = getattr(dataset, 'project_logits', None)
        project_labels = getattr(dataset, 'project_labels', None)
        accuracy_fn = getattr(dataset, 'accuracy', None)
        num_batches = len(loader)

        end = time.time()
        for i, batch in enumerate(loader):
            batch = maybe_dictionarize_batch(batch)
            inputs = batch['images'].to(device)
            labels = batch['labels'].to(device)
            data_time = time.time() - end
            # Look this up per batch: previously it was only bound when present,
            # so a batch without paths hit NameError or reused stale paths.
            image_paths = batch.get('image_paths')

            logits = model(inputs)
            if project_logits is not None:
                logits = project_logits(logits, device)
            if project_labels is not None:
                labels = project_labels(labels, device)
            if isinstance(logits, list):
                logits = logits[0]

            if accuracy_fn is not None:
                acc1, num_total = accuracy_fn(logits, labels, image_paths, None)
                correct += acc1
                n += num_total
            else:
                pred = logits.argmax(dim=1, keepdim=True).to(device)
                correct += pred.eq(labels.view_as(pred)).sum().item()
                n += labels.size(0)

            batch_time = time.time() - end
            end = time.time()
            if i % 20 == 0:
                percent_complete = 100.0 * i / num_batches
                # Guard against a custom accuracy hook reporting zero examples.
                running_acc = 100 * (correct / n) if n else 0.0
                print(
                    f"[{percent_complete:.0f}% {i}/{num_batches}]\t"
                    f"Acc: {running_acc:.2f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}"
                )
    return correct / n
def assign_learning_rate(param_group, new_lr):
    """Overwrite the ``"lr"`` entry of an optimizer param group in place (returns None)."""
    param_group.update(lr=new_lr)
def _warmup_lr(base_lr, warmup_length, step):
return base_lr * (step + 1) / warmup_length
def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Build a per-step callback: linear warm-up followed by cosine decay to 0.

    Args:
        optimizer: Optimizer whose ``param_groups`` get their ``"lr"`` set.
        base_lrs: Peak learning rate, either one value shared by all param
            groups or a list/tuple with one value per group.
        warmup_length: Number of linear warm-up steps (see ``_warmup_lr``).
        steps: Total schedule length, warm-up included.

    Returns:
        ``_lr_adjuster(step)``, which sets every group's ``"lr"`` for ``step``
        and returns None.

    Raises:
        ValueError: if a per-group list/tuple does not match the number of
            param groups.
    """
    if not isinstance(base_lrs, (list, tuple)):
        base_lrs = [base_lrs for _ in optimizer.param_groups]
    # Explicit check instead of assert, which is stripped under `python -O`.
    if len(base_lrs) != len(optimizer.param_groups):
        raise ValueError(
            f"Expected {len(optimizer.param_groups)} base learning rates, got {len(base_lrs)}"
        )
    # Length of the cosine phase, clamped to 1 so a schedule with no decay
    # phase (steps <= warmup_length) no longer divides by zero.
    decay_steps = max(steps - warmup_length, 1)

    def _lr_adjuster(step):
        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
            if step < warmup_length:
                lr = _warmup_lr(base_lr, warmup_length, step)
            else:
                # Clamp past the end of the schedule so the LR stays at 0
                # instead of the cosine climbing back up.
                e = min(step - warmup_length, decay_steps)
                lr = 0.5 * (1 + np.cos(np.pi * e / decay_steps)) * base_lr
            assign_learning_rate(param_group, lr)
    return _lr_adjuster
|