Spaces:
Runtime error
Runtime error
| import hashlib | |
| import itertools | |
| import json | |
| import logging | |
| import math | |
| import random | |
| import os | |
| import tempfile | |
| import time | |
| import einops | |
| from torch.nn.utils.rnn import pad_sequence | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.nn.parallel.distributed import DistributedDataParallel | |
| from .data import ImageRewardDataset, RankingDataset | |
| from open_clip import get_cast_dtype, CLIP, CustomTextCLIP | |
| from .distributed import is_master, barrier | |
| from .zero_shot import zero_shot_eval | |
| from .precision import get_autocast | |
| from ..open_clip.loss import PreferenceLoss, RankingLoss, HPSLoss | |
class AverageMeter:
    """Tracks the most recent value plus a running sum, count, and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
def postprocess_clip_output(model_out):
    """Convert a positional CLIP forward output into a named dict.

    Expects ``model_out`` indexable with at least three entries:
    image features, text features, and the logit scale, in that order.
    """
    image_features = model_out[0]
    text_features = model_out[1]
    logit_scale = model_out[2]
    return {
        "image_features": image_features,
        "text_features": text_features,
        "logit_scale": logit_scale,
    }
def unwrap_model(model):
    """Return the underlying module of a DDP/DataParallel-wrapped model.

    Falls back to the model itself when no ``.module`` attribute exists.
    """
    return getattr(model, 'module', model)
def backward(total_loss, scaler):
    """Backpropagate ``total_loss``, routing through the AMP GradScaler when one is given."""
    if scaler is None:
        total_loss.backward()
    else:
        scaler.scale(total_loss).backward()
def random_sampling_iterator(iterators, sampling_ratios, data_types, num_iters):
    """Yield ``(batch, data_type)`` pairs, picking a source iterator at random each step.

    The RNG is re-seeded with the step index for every draw, so the sequence
    of choices is reproducible across workers; the global random state is
    saved and restored around each draw to avoid perturbing other users of
    the ``random`` module.
    """
    active = [iter(it) for it in iterators]
    indices = range(len(active))
    for step in range(num_iters):
        saved_state = random.getstate()
        random.seed(step)
        picked = random.choices(indices, sampling_ratios)[0]
        random.setstate(saved_state)
        yield next(active[picked]), data_types[picked]
def train_iters(model, data, iterations, optimizer, scaler, scheduler, dist_model, args, tb_writer=None):
    """Train ``model`` for a fixed number of iterations over a mixture of datasets.

    Batches are drawn from ``data['train']`` via ``random_sampling_iterator``
    using ``args.train_data_sample_ratio`` (minus entries flagged by
    ``args.ignore_in_train``). Each batch's ``data_type`` selects its loss:
    preference (PreferenceLoss), rating (MSE), regional (MSE on patch masks),
    ranking (RankingLoss), or HPD (HPSLoss). Supports AMP via ``scaler`` and
    optional gradient clipping; logs/TensorBoard-writes on the master rank.

    NOTE(review): only ``args.accum_freq == 1`` is handled (see TODO below);
    with other values no forward/backward runs and ``losses`` is unbound.
    """
    device = torch.device(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)
    model.train()
    # One loss object per data type; reused across all steps.
    ce_loss = PreferenceLoss()
    mse_loss = torch.nn.MSELoss()
    rk_loss = RankingLoss()
    hps_loss = HPSLoss()
    if args.distill:
        dist_model.eval()
    for train_set in data['train']:
        train_set.set_epoch(0)  # set epoch in process safe manner via sampler or shared_epoch
    data_types = [d.data_type for d in data['train']]
    # Keep only the sampling ratios of datasets that are not ignored in training.
    train_data_sample_ratios = [sample_ratio for sample_ratio, ignore in zip(args.train_data_sample_ratio, args.ignore_in_train) if not ignore]
    dataloader = random_sampling_iterator([dataset.dataloader for dataset in data['train']], train_data_sample_ratios, data_types, iterations)
    # Digit width for pretty-printing the iteration counter.
    sample_digits = math.ceil(math.log(sum([dataset.dataloader.num_samples for dataset in data['train']]) + 1, 10))
    losses_m = {}
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()
    for step, (batch, data_type) in enumerate(dataloader):
        # TODO: currently only test on accum_freq==1
        if not args.skip_scheduler:
            scheduler(step)
        # Unpack the batch according to its data type; only text-conditioned
        # types carry a ``texts`` tensor.
        if data_type == 'preference':
            images, num_images, labels, texts = batch
            texts = texts.to(device=device, non_blocking=True)
        elif data_type == 'rating':
            images, labels = batch
        elif data_type == 'regional':
            images, labels = batch
        elif data_type == 'ranking':
            images, num_images, labels, texts = batch
            texts = texts.to(device=device, non_blocking=True)
        elif data_type == 'HPD':
            images, labels, texts = batch
            # num_per_prompts = num_per_prompts.to(device=device, non_blocking=True)
            texts = texts.to(device=device, non_blocking=True)
        images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
        labels = labels.to(device=device, non_blocking=True)
        data_time_m.update(time.time() - end)
        optimizer.zero_grad()
        if args.accum_freq == 1:
            with autocast():
                if data_type == 'rating' or args.no_text_condition:
                    # Score images directly from visual features, no text tower.
                    image_features = unwrap_model(model).visual(images)
                    scores = unwrap_model(model).score_predictor(image_features)
                    if args.no_text_condition:
                        # NOTE(review): ``num_images`` is only bound for
                        # preference/ranking batches — a 'rating' batch reaching
                        # this branch would raise NameError; confirm intent.
                        paired_logits_list = [logit[:,0] for i, logit in enumerate(scores.split(num_images.tolist()))]
                        # Pad variable-size groups so cross_entropy sees a rectangular tensor;
                        # -999 keeps padded slots out of contention after softmax.
                        paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
                        total_loss = F.cross_entropy(paired_logits, labels)
                    else:
                        total_loss = mse_loss(scores.squeeze(), labels.to(scores.dtype))
                elif data_type == 'preference':
                    output = model(images, texts)
                    image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
                    # total_loss = loss(image_features, text_features, logit_scale)
                    logits_per_image = logit_scale * image_features @ text_features.T
                    total_loss = ce_loss(logits_per_image, num_images, labels)
                elif data_type == 'HPD':
                    output = model(images, texts)
                    image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
                    logits_per_text = logit_scale * text_features @ image_features.T
                    total_loss = hps_loss(logits_per_text, labels)
                elif data_type == 'ranking':
                    output = model(images, texts)
                    image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
                    # logits_per_image = logit_scale * image_features @ text_features.T
                    score = logit_scale * image_features @ text_features.T
                    total_loss = rk_loss(score, num_images, labels, args.margin)
                elif data_type == 'regional':
                    # logit_scale = model.logit_scale
                    # Drop the CLS token; keep only patch tokens.
                    feature_map = unwrap_model(model).visual(images, skip_pool=True)[:, 1:]
                    logits = unwrap_model(model).region_predictor(feature_map)
                    wh = int(math.sqrt(feature_map.size(1)))  # patch grid side (assumed square)
                    ps = images.size(2) // wh  # pixels per patch side
                    logits = logits.unflatten(1, (wh, wh))[:,:,:,0]
                    # downsample the labels to match the feature map size
                    patches = einops.reduce(labels, 'b (h s1) (w s2) -> b h w', 'mean', s1=ps, s2=ps)
                    patches = (patches > 0).float()
                    # NOTE(review): patches.to(patches.dtype) is a no-op —
                    # presumably logits.dtype was intended; confirm.
                    total_loss = mse_loss(logits.sigmoid(), patches.to(patches.dtype))
            backward(total_loss, scaler)
            losses = dict(total_loss=total_loss)
        if scaler is not None:
            if args.horovod:
                # Horovod: sync grads manually, then skip the implicit sync in step().
                optimizer.synchronize()
                scaler.unscale_(optimizer)
                if args.grad_clip_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
                with optimizer.skip_synchronize():
                    scaler.step(optimizer)
            else:
                if args.grad_clip_norm is not None:
                    # Unscale before clipping so the norm is measured in true gradient units.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
                scaler.step(optimizer)
            scaler.update()
        else:
            if args.grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            optimizer.step()
        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            unwrap_model(model).logit_scale.clamp_(0, math.log(100))
        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = step + 1
        if is_master(args) and (step % args.log_every_n_steps == 0 or batch_count == iterations):
            batch_size = len(images)
            num_samples = batch_count * args.accum_freq
            percent_complete = 100.0 * batch_count / iterations
            # NOTE loss is coarsely sampled, just master node and per log update
            for key, val in losses.items():
                if key not in losses_m:
                    losses_m[key] = AverageMeter()
                losses_m[key].update(val.item(), batch_size)
            logit_scale_scalar = unwrap_model(model).logit_scale.item()
            loss_log = " ".join(
                [
                    f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})"
                    for loss_name, loss_m in losses_m.items()
                ]
            )
            samples_per_second = args.accum_freq * args.world_size / batch_time_m.val
            samples_per_second_per_gpu = args.accum_freq / batch_time_m.val
            logging.info(
                f"Train iterations: [{num_samples:>{sample_digits}}/{iterations} ({percent_complete:.0f}%)] "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
                f"LR: {optimizer.param_groups[0]['lr']:5f} "
                f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log
            )
            # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
            log_data = {
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "samples_per_second": samples_per_second,
                "samples_per_second_per_gpu": samples_per_second_per_gpu,
                "scale": logit_scale_scalar,
                "lr": optimizer.param_groups[0]["lr"]
            }
            log_data.update({name: val.val for name, val in losses_m.items()})
            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()
def evaluate_preference(model, data, args):
    """Measure preference-prediction accuracy on a paired-image dataset.

    Each batch holds groups of candidate images per prompt (``num_images``
    gives the group sizes); the model's highest-scoring image in each group is
    compared against ``labels``. Work is sharded across ranks by batch index,
    and per-rank counts are merged through small JSON temp files plus a barrier.

    Returns:
        Accuracy on the master rank; ~0.0 on non-master ranks (their counters
        are zeroed before the return).
    """
    model = unwrap_model(model)
    model.eval()
    dataloader = data.dataloader
    samples_per_val = dataloader.num_samples
    device = torch.device(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)
    total = 0
    correct = 0
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            # Round-robin sharding: each rank handles every world_size-th batch.
            if i % args.world_size != args.rank:
                continue
            images, num_images, labels, texts = batch
            images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
            texts = texts.to(device=device, non_blocking=True)
            with autocast():
                if args.no_text_condition:
                    # Text-free scoring head: score each image directly.
                    image_features = model.visual(images)
                    logit_scale = model.logit_scale
                    scores = model.score_predictor(image_features)
                    paired_logits_list = [logit[:,0] for i, logit in enumerate(scores.split(num_images.tolist()))]
                else:
                    outputs = model(images, texts)
                    image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
                    logits_per_image = logit_scale * image_features @ text_features.T
                    # Group i is scored against its own prompt's text (column i).
                    paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
            # Winner per group = index of the highest logit.
            predicted = torch.tensor([k.argmax().item() for k in paired_logits_list])
            correct += (predicted == labels).int().sum().item()
            total += predicted.numel()
    # write to a temp file
    # Cross-rank reduction via the filesystem instead of a collective op.
    file_name = hashlib.md5(str(args.name).encode()).hexdigest()
    with open(f"{file_name}_{args.rank}.json", "w") as f:
        json.dump(dict(
            correct=correct,
            total=total,
        ), f)
    time.sleep(0.1)
    barrier(args)
    correct = 0
    total = 0
    if is_master(args):
        for i in range(args.world_size):
            with open(f"{file_name}_{i}.json", "r") as f:
                data = json.load(f)
            correct += data["correct"]
            total += data["total"]
            os.remove(f"{file_name}_{i}.json")
        logging.info(
            f"Final Acc: {correct / total:.4f}\t")
    # Non-master ranks return 0 / 1e-6 == 0.0 here.
    return correct / (total + 1e-6)
def evaluate_regional(model, data, args):
    """Evaluate region prediction quality as mean IoU against binary masks.

    The model's ``region_predictor`` produces per-patch logits; ground-truth
    pixel masks are average-pooled to the patch grid and binarized, then IoU
    is computed per batch and averaged over batches.

    Returns:
        Mean IoU across batches (count smoothed by +0.001 to avoid div-by-zero).
    """
    dataloader = data.dataloader
    samples_per_val = dataloader.num_samples
    device = torch.device(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)
    num_samples = len(dataloader)
    threshold = 0.5  # sigmoid cutoff for predicting a positive patch
    with torch.no_grad():
        score = 0
        total = 0
        for i, batch in enumerate(dataloader):
            images, labels = batch
            images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
            labels = labels.to(device=device, non_blocking=True)
            with autocast():
                # Drop the CLS token; keep only patch tokens.
                feature_map = model.visual(images, skip_pool=True)[:, 1:]
                logits = model.region_predictor(feature_map)
                wh = int(math.sqrt(feature_map.size(1)))  # patch grid side (assumed square)
                ps = images.size(2) // wh  # pixels per patch side
                logits = logits.unflatten(1, (wh, wh))[:,:,:,0]
                # downsample the labels to match the feature map size
                patches = einops.reduce(labels, 'b (h s1) (w s2) -> b h w', 'mean', s1=ps, s2=ps)
                patches = (patches > 0).float()
                pred_mask = (logits.sigmoid() > threshold).float()
                #calc IOU
                intersection = (pred_mask * patches).sum()
                union = pred_mask.sum() + patches.sum() - intersection
                iou_score = intersection / union
            score += iou_score
            total += 1
            if is_master(args) and (i % 100) == 0:
                logging.info(
                    # f"[{i} / {samples_per_val}]\t"
                    f"[{i} / {len(dataloader)}]\t"
                    f"Current IoU: {score / (total + 0.001):.4f}\t")
    if is_master(args):
        logging.info(
            f"Final IoU: {score / (total + 0.001):.4f}\t")
    return score / (total + 0.001)
def inversion_score(p1, p2):
    """Return 1 minus the normalized count of discordant pairs between two rankings.

    A pair (i, j) is discordant when p1 and p2 order it in opposite directions;
    ties in either ranking do not count. 1.0 means the orderings agree on every
    strictly-ordered pair, 0.0 means they are fully reversed.

    Args:
        p1, p2: equal-length sequences of comparable rank values.

    Returns:
        float in [0, 1].

    Raises:
        AssertionError: if the two sequences differ in length.
    """
    assert len(p1) == len(p2), f'{len(p1)}, {len(p2)}'
    n = len(p1)
    if n < 2:
        # No pairs to compare; the original divided by n*(n-1)/2 == 0 here.
        return 1.0
    discordant = 0
    # Iterate every unordered index pair exactly once.
    for (a1, a2), (b1, b2) in itertools.combinations(zip(p1, p2), 2):
        if (a1 > b1 and a2 < b2) or (a1 < b1 and a2 > b2):
            discordant += 1
    return 1 - discordant / (n * (n - 1) / 2)
def model_pair_score(score: dict, p1, p2, num_image):
    """Accumulate, per unordered model pair (i, j), how often two score lists agree.

    For every pair of distinct model indices below ``num_image``, increments
    ``score[i][j]`` when ``p1`` and ``p2`` order the pair the same way (strict
    agreement — equal deltas do not count). Missing cells are initialized to 0.
    The nested dict is mutated in place and also returned.
    """
    seen_pairs = set()
    for i in range(num_image):
        row = score.setdefault(i, {})
        for j in range(num_image):
            row.setdefault(j, 0)
            if i == j or (i, j) in seen_pairs or (j, i) in seen_pairs:
                continue
            seen_pairs.add((i, j))
            # Same-signed deltas mean both rankings prefer the same member.
            if (p1[i] - p1[j]) * (p2[i] - p2[j]) > 0:
                row[j] += 1
    return score
def all_gather(tensor):
    """Gather ``tensor`` from every distributed rank and concatenate along dim 0.

    Requires an initialized ``torch.distributed`` process group.
    """
    buckets = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(buckets, tensor, async_op=False)
    return torch.cat(buckets, dim=0)
def evaluate_ranking(model, data, args):
    """Evaluate ranking quality over grouped candidate images.

    For a RankingDataset the metric is the inversion score between predicted
    and ground-truth rankings; for an ImageRewardDataset it is
    ``calc_ImageReward``'s pairwise accuracy. Batches are sharded across
    ranks, and per-rank sums are merged via temp files guarded by a barrier.

    Returns:
        Mean per-sample score on the master rank; 0.0 on other ranks.
    """
    model = unwrap_model(model)
    model.eval()
    dataloader = data.dataloader
    samples_per_val = dataloader.num_samples
    device = torch.device(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)
    score = 0
    # pair_score = {}
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            # NOTE(review): sharding keys on args.local_rank while the temp
            # file below is written under args.rank; these only coincide on
            # single-node runs — confirm multi-node behavior.
            if i % args.world_size != args.local_rank:
                continue
            images, num_images, labels, texts = batch
            images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
            texts = texts.to(device=device, non_blocking=True)
            num_images = num_images.to(device=device, non_blocking=True)
            labels = labels.to(device=device, non_blocking=True)
            with autocast():
                if args.no_text_condition:
                    # Text-free scoring head: score each image directly.
                    image_features = model.visual(images)
                    logit_scale = model.logit_scale
                    scores = model.score_predictor(image_features)
                    paired_logits_list = [logit[:,0] for i, logit in enumerate(scores.split(num_images.tolist()))]
                else:
                    outputs = model(images, texts)
                    image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
                    logits_per_image = logit_scale * image_features @ text_features.T
                    # Group i is scored against its own prompt's text (column i).
                    paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
                # Sort descending by logit to obtain the predicted order.
                predicted = [torch.argsort(-k) for k in paired_logits_list]
                # hps_ranking[g][j] = predicted rank position of image j within group g.
                hps_ranking = [[predicted[i].tolist().index(j) for j in range(n)] for i, n in enumerate(num_images)]
                labels = [label for label in labels.split(num_images.tolist())]
            if isinstance(dataloader.dataset, RankingDataset):
                score += sum([inversion_score(hps_ranking[i], labels[i]) for i in range(len(hps_ranking))])
            elif isinstance(dataloader.dataset, ImageRewardDataset):
                score += sum([calc_ImageReward(paired_logits_list[i].tolist(), labels[i]) for i in range(len(hps_ranking))])
    # write score to a tempfile, file name is a hash string
    file_name = hashlib.md5(str(args.name).encode()).hexdigest()
    with open(f"{file_name}_{args.rank}.tmp", "w") as f:
        f.write(str(score))
    time.sleep(0.1)
    barrier(args)
    score = 0
    if is_master(args):
        for i in range(args.world_size):
            with open(f"{file_name}_{i}.tmp", "r") as f:
                score += float(f.read())
            os.remove(f"{file_name}_{i}.tmp")
        # Normalize by the total number of samples, not per-rank counts.
        score = score / samples_per_val
        logging.info(
            f"Final Acc: {score:.4f}\t")
    # return score, pair_score
    return score
def calc_ImageReward(pred, gt):
    """Pairwise ranking accuracy in the style of the ImageReward benchmark.

    ``gt`` holds ground-truth rank positions (larger value = ranked lower);
    ``pred`` holds model scores (larger value = ranked higher). For every
    strictly-ordered ground-truth pair, the prediction is correct when the
    lower-ranked item received the lower score. Ties in ``gt`` are skipped;
    ties in ``pred`` count as incorrect (matching the benchmark's tie handling).

    Returns:
        true_cnt / tol_cnt, or 0.0 when there are no comparable pairs
        (all-tie or fewer than two items — the original raised
        ZeroDivisionError there).
    """
    tol_cnt = 0.
    true_cnt = 0.
    # NOTE(fix): the original wrapped this double loop in a redundant
    # ``for idx in range(len(gt))`` that recomputed identical counts len(gt)
    # times (O(n^3)); both counters scaled equally, so running it once
    # yields the same ratio.
    for i in range(len(gt)):
        for j in range(i + 1, len(gt)):
            if gt[i] > gt[j]:
                tol_cnt += 1
                if pred[i] < pred[j]:
                    true_cnt += 1
            elif gt[i] < gt[j]:
                tol_cnt += 1
                if pred[i] > pred[j]:
                    true_cnt += 1
    if tol_cnt == 0:
        return 0.0
    return true_cnt / tol_cnt
def get_clip_metrics(image_features, text_features, logit_scale):
    """Compute retrieval metrics in both directions (image→text and text→image).

    For each direction, reports mean rank, floored median rank (both 1-based),
    and recall@{1, 5, 10}, assuming the i-th image matches the i-th text.
    """
    i2t_logits = (logit_scale * image_features @ text_features.t()).detach().cpu()
    t2i_logits = i2t_logits.t().detach().cpu()
    ground_truth = torch.arange(len(text_features)).view(-1, 1)
    metrics = {}
    for name, logit in (("image_to_text", i2t_logits), ("text_to_image", t2i_logits)):
        ranking = torch.argsort(logit, descending=True)
        # Position of each row's true match within its sorted candidate list.
        preds = torch.where(ranking == ground_truth)[1].detach().cpu().numpy()
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in (1, 5, 10):
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)
    return metrics
def maybe_compute_generative_loss(model_out):
    """Return the captioning cross-entropy when token logits/labels are present, else None.

    Expects ``model_out["logits"]`` shaped (batch, seq, classes) and
    ``model_out["labels"]`` shaped (batch, seq).
    """
    if "logits" not in model_out or "labels" not in model_out:
        return None
    token_logits = model_out["logits"]
    token_labels = model_out["labels"]
    # cross_entropy wants (batch, classes, seq), hence the permute.
    return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels)