import os
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
import torchvision.transforms as transforms
from einops import rearrange
from torchvision.datasets.utils import download_url

# All reward models.
__all__ = ["AestheticReward", "HPSReward", "PickScoreReward", "MPSReward"]


class BaseReward(ABC):
    """A base class for reward models. A custom reward class must implement the two functions below.
    """
    @abstractmethod
    def __init__(self):
        """Define your reward model and image transformations (optional) here.
        """
        pass

    @abstractmethod
    def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Given batch frames with shape `[B, C, T, H, W]` extracted from a list of videos and, optionally, a list of
        corresponding prompts, return the loss and reward computed by your reward model (reduced by mean).
        """
        pass
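
# A minimal sketch of a custom reward following the `BaseReward` interface (illustrative
# only; `BrightnessReward` and its mean-intensity heuristic are hypothetical and not part
# of the exported reward models above):
#
#     class BrightnessReward(BaseReward):
#         def __init__(self, max_reward=1.0, loss_scale=1.0):
#             self.max_reward = max_reward
#             self.loss_scale = loss_scale
#
#         def __call__(self, batch_frames, batch_prompt=None):
#             # Mean pixel intensity per video serves as a toy reward in [0, 1].
#             reward = batch_frames.float().mean(dim=(1, 2, 3, 4))
#             loss = abs(reward - self.max_reward) * self.loss_scale
#             return loss.mean(), reward.mean()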


class AestheticReward(BaseReward):
    """Aesthetic Predictor [V2](https://github.com/christophschuhmann/improved-aesthetic-predictor)
    and [V2.5](https://github.com/discus0434/aesthetic-predictor-v2-5) reward model.
    """
    def __init__(
        self,
        encoder_path="openai/clip-vit-large-patch14",
        predictor_path=None,
        version="v2",
        device="cpu",
        dtype=torch.float16,
        max_reward=10,
        loss_scale=0.1,
    ):
        from .improved_aesthetic_predictor import ImprovedAestheticPredictor
        from ..video_caption.utils.siglip_v2_5 import convert_v2_5_from_siglip

        self.encoder_path = encoder_path
        self.predictor_path = predictor_path
        self.version = version
        self.device = device
        self.dtype = dtype
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        if self.version != "v2" and self.version != "v2.5":
            raise ValueError("Only v2 and v2.5 are supported.")
        if self.version == "v2":
            assert "clip-vit-large-patch14" in encoder_path.lower()
            self.model = ImprovedAestheticPredictor(encoder_path=self.encoder_path, predictor_path=self.predictor_path)
            # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/preprocessor_config.json
            # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
            self.transform = transforms.Compose([
                transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
            ])
        elif self.version == "v2.5":
            assert "siglip-so400m-patch14-384" in encoder_path.lower()
            self.model, _ = convert_v2_5_from_siglip(encoder_model_name=self.encoder_path)
            # https://huggingface.co/google/siglip-so400m-patch14-384/blob/main/preprocessor_config.json
            self.transform = transforms.Compose([
                transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ])

        self.model.to(device=self.device, dtype=self.dtype)
        self.model.requires_grad_(False)

    def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            pixel_values = torch.stack([self.transform(frame) for frame in frames])
            pixel_values = pixel_values.to(self.device, dtype=self.dtype)
            if self.version == "v2":
                reward = self.model(pixel_values)
            elif self.version == "v2.5":
                reward = self.model(pixel_values).logits.squeeze()
            # Convert reward to loss in [0, 1].
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale
            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]


class HPSReward(BaseReward):
    """[HPS](https://github.com/tgxs002/HPSv2) v2 and v2.1 reward model.
    """
    def __init__(
        self,
        model_path=None,
        version="v2.0",
        device="cpu",
        dtype=torch.float16,
        max_reward=1,
        loss_scale=1,
    ):
        from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer

        self.model_path = model_path
        self.version = version
        self.device = device
        self.dtype = dtype
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        self.model, _, _ = create_model_and_transforms(
            "ViT-H-14",
            "laion2B-s32B-b79K",
            precision=self.dtype,
            device=self.device,
            jit=False,
            force_quick_gelu=False,
            force_custom_text=False,
            force_patch_dropout=False,
            force_image_size=None,
            pretrained_image=False,
            image_mean=None,
            image_std=None,
            light_augmentation=True,
            aug_cfg={},
            output_dict=True,
            with_score_predictor=False,
            with_region_predictor=False,
        )
        self.tokenizer = get_tokenizer("ViT-H-14")
        # https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json
        # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])

        if version == "v2.0":
            url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2_compressed.pt"
            filename = "HPS_v2_compressed.pt"
            md5 = "fd9180de357abf01fdb4eaad64631db4"
        elif version == "v2.1":
            url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2.1_compressed.pt"
            filename = "HPS_v2.1_compressed.pt"
            md5 = "4067542e34ba2553a738c5ac6c1d75c0"
        else:
            raise ValueError("Only v2.0 and v2.1 are supported.")
        if self.model_path is None or not os.path.exists(self.model_path):
            download_url(url, torch.hub.get_dir(), md5=md5)
            model_path = os.path.join(torch.hub.get_dir(), filename)
        state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
        self.model.load_state_dict(state_dict)

        self.model.to(device=self.device, dtype=self.dtype)
        self.model.requires_grad_(False)
        self.model.eval()

    def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        assert batch_frames.shape[0] == len(batch_prompt)

        # Compute the batch reward and loss frame-wise.
        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            image_inputs = torch.stack([self.transform(frame) for frame in frames])
            image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
            text_inputs = self.tokenizer(batch_prompt).to(device=self.device)
            outputs = self.model(image_inputs, text_inputs)
            image_features, text_features = outputs["image_features"], outputs["text_features"]
            logits = image_features @ text_features.T
            reward = torch.diagonal(logits)
            # Convert reward to loss in [0, 1].
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale
            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]


class PickScoreReward(BaseReward):
    """[PickScore](https://github.com/yuvalkirstain/PickScore) reward model.
    """
    def __init__(
        self,
        model_path="yuvalkirstain/PickScore_v1",
        device="cpu",
        dtype=torch.float16,
        max_reward=1,
        loss_scale=1,
    ):
        from transformers import AutoProcessor, AutoModel

        self.model_path = model_path
        self.device = device
        self.dtype = dtype
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        # https://huggingface.co/yuvalkirstain/PickScore_v1/blob/main/preprocessor_config.json
        self.transform = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])
        self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=self.dtype)
        self.model = AutoModel.from_pretrained(model_path, torch_dtype=self.dtype).eval().to(device)
        self.model.requires_grad_(False)
        self.model.eval()

    def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        assert batch_frames.shape[0] == len(batch_prompt)

        # Compute the batch reward and loss frame-wise.
        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            image_inputs = torch.stack([self.transform(frame) for frame in frames])
            image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
            text_inputs = self.processor(
                text=batch_prompt,
                padding=True,
                truncation=True,
                max_length=77,
                return_tensors="pt",
            ).to(self.device)
            image_features = self.model.get_image_features(pixel_values=image_inputs)
            text_features = self.model.get_text_features(**text_inputs)
            image_features = image_features / torch.norm(image_features, dim=-1, keepdim=True)
            text_features = text_features / torch.norm(text_features, dim=-1, keepdim=True)
            logits = image_features @ text_features.T
            reward = torch.diagonal(logits)
            # Convert reward to loss in [0, 1].
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale
            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]


class MPSReward(BaseReward):
    """[MPS](https://github.com/Kwai-Kolors/MPS) reward model.
    """
    def __init__(
        self,
        model_path=None,
        device="cpu",
        dtype=torch.float16,
        max_reward=1,
        loss_scale=1,
    ):
        from transformers import AutoTokenizer, AutoConfig
        from .MPS.trainer.models.clip_model import CLIPModel

        self.model_path = model_path
        self.device = device
        self.dtype = dtype
        self.condition = "light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things."
        self.max_reward = max_reward
        self.loss_scale = loss_scale

        processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
        # https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json
        # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
        ])

        # We convert the original [ckpt](http://drive.google.com/file/d/17qrK_aJkVNM75ZEvMEePpLj6L867MLkN/view?usp=sharing)
        # (contains the entire model) to a `state_dict`.
        url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/MPS_overall.pth"
        filename = "MPS_overall.pth"
        md5 = "1491cbbbd20565747fe07e7572e2ac56"
        if self.model_path is None or not os.path.exists(self.model_path):
            download_url(url, torch.hub.get_dir(), md5=md5)
            model_path = os.path.join(torch.hub.get_dir(), filename)
        self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
        config = AutoConfig.from_pretrained(processor_name_or_path)
        self.model = CLIPModel(config)
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict, strict=False)

        self.model.to(device=self.device, dtype=self.dtype)
        self.model.requires_grad_(False)
        self.model.eval()

    def _tokenize(self, caption):
        input_ids = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        return input_ids

    def __call__(
        self,
        batch_frames: torch.Tensor,
        batch_prompt: list[str],
        batch_condition: Optional[list[str]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if batch_condition is None:
            batch_condition = [self.condition] * len(batch_prompt)

        batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
        batch_loss, batch_reward = 0, 0
        for frames in batch_frames:
            image_inputs = torch.stack([self.transform(frame) for frame in frames])
            image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
            text_inputs = self._tokenize(batch_prompt).to(self.device)
            condition_inputs = self._tokenize(batch_condition).to(device=self.device)
            text_features, image_features = self.model(text_inputs, image_inputs, condition_inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # reward = self.model.logit_scale.exp() * torch.diag(torch.einsum('bd,cd->bc', text_features, image_features))
            logits = image_features @ text_features.T
            reward = torch.diagonal(logits)
            # Convert reward to loss in [0, 1].
            if self.max_reward is None:
                loss = (-1 * reward) * self.loss_scale
            else:
                loss = abs(reward - self.max_reward) * self.loss_scale
            batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()

        return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]


if __name__ == "__main__":
    import numpy as np
    from decord import VideoReader

    video_path_list = ["your_video_path_1.mp4", "your_video_path_2.mp4"]
    prompt_list = ["your_prompt_1", "your_prompt_2"]
    num_sampled_frames = 8

    # Uniformly sample frames from each video and stack them into a [B, C, T, H, W] tensor.
    to_tensor = transforms.ToTensor()
    sampled_frames_list = []
    for video_path in video_path_list:
        vr = VideoReader(video_path)
        sampled_frame_indices = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int)
        sampled_frames = vr.get_batch(sampled_frame_indices).asnumpy()
        sampled_frames = torch.stack([to_tensor(frame) for frame in sampled_frames])
        sampled_frames_list.append(sampled_frames)
    sampled_frames = torch.stack(sampled_frames_list)
    sampled_frames = rearrange(sampled_frames, "b t c h w -> b c t h w")

    aesthetic_reward_v2 = AestheticReward(device="cuda", dtype=torch.bfloat16)
    print(f"aesthetic_reward_v2: {aesthetic_reward_v2(sampled_frames)}")

    aesthetic_reward_v2_5 = AestheticReward(
        encoder_path="google/siglip-so400m-patch14-384", version="v2.5", device="cuda", dtype=torch.bfloat16
    )
    print(f"aesthetic_reward_v2_5: {aesthetic_reward_v2_5(sampled_frames)}")

    hps_reward_v2 = HPSReward(device="cuda", dtype=torch.bfloat16)
    print(f"hps_reward_v2: {hps_reward_v2(sampled_frames, prompt_list)}")

    hps_reward_v2_1 = HPSReward(version="v2.1", device="cuda", dtype=torch.bfloat16)
    print(f"hps_reward_v2_1: {hps_reward_v2_1(sampled_frames, prompt_list)}")

    pick_score = PickScoreReward(device="cuda", dtype=torch.bfloat16)
    print(f"pick_score_reward: {pick_score(sampled_frames, prompt_list)}")

    mps_score = MPSReward(device="cuda", dtype=torch.bfloat16)
    print(f"mps_reward: {mps_score(sampled_frames, prompt_list)}")