import os
from abc import ABC, abstractmethod
import torch
import torchvision.transforms as transforms
from einops import rearrange
from torchvision.datasets.utils import download_url
from typing import Optional, Tuple
# All reward models.
__all__ = ["AestheticReward", "HPSReward", "PickScoreReward", "MPSReward"]
class BaseReward(ABC):
"""An base class for reward models. A custom Reward class must implement two functions below.
"""
def __init__(self):
"""Define your reward model and image transformations (optional) here.
"""
pass
@abstractmethod
def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]:
"""Given batch frames with shape `[B, C, T, H, W]` extracted from a list of videos and a list of prompts
(optional) correspondingly, return the loss and reward computed by your reward model (reduction by mean).
"""
pass
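# Illustrative sketch only: a minimal BaseReward subclass. The name `BrightnessReward` and the
# brightness criterion are hypothetical and not part of this module's API; the class merely
# demonstrates the expected `[B, C, T, H, W]` input and the `(loss, reward)` return contract,
# e.g. `loss, reward = BrightnessReward()(batch_frames)` for a float tensor of that shape.
class BrightnessReward(BaseReward):
    """Toy reward that scores the mean pixel intensity of each video."""
    def __init__(self, max_reward=1.0, loss_scale=1.0):
        self.max_reward = max_reward
        self.loss_scale = loss_scale
    def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]:
        # Average over the channel, time, height, and width dimensions, keeping the batch dimension.
        reward = batch_frames.float().mean(dim=(1, 2, 3, 4))
        loss = abs(reward - self.max_reward) * self.loss_scale
        return loss.mean(), reward.mean()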
class AestheticReward(BaseReward):
"""Aesthetic Predictor [V2](https://github.com/christophschuhmann/improved-aesthetic-predictor)
and [V2.5](https://github.com/discus0434/aesthetic-predictor-v2-5) reward model.
"""
def __init__(
self,
encoder_path="openai/clip-vit-large-patch14",
predictor_path=None,
version="v2",
device="cpu",
dtype=torch.float16,
max_reward=10,
loss_scale=0.1,
):
from .improved_aesthetic_predictor import ImprovedAestheticPredictor
from ..video_caption.utils.siglip_v2_5 import convert_v2_5_from_siglip
self.encoder_path = encoder_path
self.predictor_path = predictor_path
self.version = version
self.device = device
self.dtype = dtype
self.max_reward = max_reward
self.loss_scale = loss_scale
if self.version != "v2" and self.version != "v2.5":
raise ValueError("Only v2 and v2.5 are supported.")
if self.version == "v2":
assert "clip-vit-large-patch14" in encoder_path.lower()
self.model = ImprovedAestheticPredictor(encoder_path=self.encoder_path, predictor_path=self.predictor_path)
# https://huggingface.co/openai/clip-vit-large-patch14/blob/main/preprocessor_config.json
            # TODO: use [transforms.Resize(224), transforms.CenterCrop(224)] to support arbitrary aspect ratios.
self.transform = transforms.Compose([
transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])
elif self.version == "v2.5":
assert "siglip-so400m-patch14-384" in encoder_path.lower()
self.model, _ = convert_v2_5_from_siglip(encoder_model_name=self.encoder_path)
# https://huggingface.co/google/siglip-so400m-patch14-384/blob/main/preprocessor_config.json
self.transform = transforms.Compose([
transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
self.model.to(device=self.device, dtype=self.dtype)
self.model.requires_grad_(False)
def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]:
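        # Compute the batch reward and loss frame by frame.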
batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
batch_loss, batch_reward = 0, 0
for frames in batch_frames:
pixel_values = torch.stack([self.transform(frame) for frame in frames])
pixel_values = pixel_values.to(self.device, dtype=self.dtype)
if self.version == "v2":
reward = self.model(pixel_values)
elif self.version == "v2.5":
reward = self.model(pixel_values).logits.squeeze()
# Convert reward to loss in [0, 1].
if self.max_reward is None:
loss = (-1 * reward) * self.loss_scale
else:
loss = abs(reward - self.max_reward) * self.loss_scale
batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()
return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
class HPSReward(BaseReward):
"""[HPS](https://github.com/tgxs002/HPSv2) v2 and v2.1 reward model.
"""
def __init__(
self,
model_path=None,
version="v2.0",
device="cpu",
dtype=torch.float16,
max_reward=1,
loss_scale=1,
):
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
self.model_path = model_path
self.version = version
self.device = device
self.dtype = dtype
self.max_reward = max_reward
self.loss_scale = loss_scale
self.model, _, _ = create_model_and_transforms(
"ViT-H-14",
"laion2B-s32B-b79K",
precision=self.dtype,
device=self.device,
jit=False,
force_quick_gelu=False,
force_custom_text=False,
force_patch_dropout=False,
force_image_size=None,
pretrained_image=False,
image_mean=None,
image_std=None,
light_augmentation=True,
aug_cfg={},
output_dict=True,
with_score_predictor=False,
with_region_predictor=False,
)
self.tokenizer = get_tokenizer("ViT-H-14")
# https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json
        # TODO: use [transforms.Resize(224), transforms.CenterCrop(224)] to support arbitrary aspect ratios.
self.transform = transforms.Compose([
transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])
if version == "v2.0":
url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2_compressed.pt"
filename = "HPS_v2_compressed.pt"
md5 = "fd9180de357abf01fdb4eaad64631db4"
elif version == "v2.1":
url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2.1_compressed.pt"
filename = "HPS_v2.1_compressed.pt"
md5 = "4067542e34ba2553a738c5ac6c1d75c0"
else:
raise ValueError("Only v2.0 and v2.1 are supported.")
if self.model_path is None or not os.path.exists(self.model_path):
download_url(url, torch.hub.get_dir(), md5=md5)
model_path = os.path.join(torch.hub.get_dir(), filename)
state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
self.model.load_state_dict(state_dict)
self.model.to(device=self.device, dtype=self.dtype)
self.model.requires_grad_(False)
self.model.eval()
def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
assert batch_frames.shape[0] == len(batch_prompt)
        # Compute the batch reward and loss frame by frame.
batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
batch_loss, batch_reward = 0, 0
for frames in batch_frames:
image_inputs = torch.stack([self.transform(frame) for frame in frames])
image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
text_inputs = self.tokenizer(batch_prompt).to(device=self.device)
outputs = self.model(image_inputs, text_inputs)
image_features, text_features = outputs["image_features"], outputs["text_features"]
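            # Score each frame against its own prompt via the diagonal of the image-text similarity matrix.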
logits = image_features @ text_features.T
reward = torch.diagonal(logits)
# Convert reward to loss in [0, 1].
if self.max_reward is None:
loss = (-1 * reward) * self.loss_scale
else:
loss = abs(reward - self.max_reward) * self.loss_scale
batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()
return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
class PickScoreReward(BaseReward):
"""[PickScore](https://github.com/yuvalkirstain/PickScore) reward model.
"""
def __init__(
self,
model_path="yuvalkirstain/PickScore_v1",
device="cpu",
dtype=torch.float16,
max_reward=1,
loss_scale=1,
):
from transformers import AutoProcessor, AutoModel
self.model_path = model_path
self.device = device
self.dtype = dtype
self.max_reward = max_reward
self.loss_scale = loss_scale
# https://huggingface.co/yuvalkirstain/PickScore_v1/blob/main/preprocessor_config.json
self.transform = transforms.Compose([
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])
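        # The processor is only used to tokenize prompts; frames are preprocessed by `self.transform` above.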
self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=self.dtype)
self.model = AutoModel.from_pretrained(model_path, torch_dtype=self.dtype).eval().to(device)
self.model.requires_grad_(False)
self.model.eval()
def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
assert batch_frames.shape[0] == len(batch_prompt)
        # Compute the batch reward and loss frame by frame.
batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
batch_loss, batch_reward = 0, 0
for frames in batch_frames:
image_inputs = torch.stack([self.transform(frame) for frame in frames])
image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
text_inputs = self.processor(
text=batch_prompt,
padding=True,
truncation=True,
max_length=77,
return_tensors="pt",
).to(self.device)
image_features = self.model.get_image_features(pixel_values=image_inputs)
text_features = self.model.get_text_features(**text_inputs)
image_features = image_features / torch.norm(image_features, dim=-1, keepdim=True)
text_features = text_features / torch.norm(text_features, dim=-1, keepdim=True)
logits = image_features @ text_features.T
reward = torch.diagonal(logits)
# Convert reward to loss in [0, 1].
if self.max_reward is None:
loss = (-1 * reward) * self.loss_scale
else:
loss = abs(reward - self.max_reward) * self.loss_scale
batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()
return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
class MPSReward(BaseReward):
"""[MPS](https://github.com/Kwai-Kolors/MPS) reward model.
"""
def __init__(
self,
model_path=None,
device="cpu",
dtype=torch.float16,
max_reward=1,
loss_scale=1,
):
from transformers import AutoTokenizer, AutoConfig
from .MPS.trainer.models.clip_model import CLIPModel
self.model_path = model_path
self.device = device
self.dtype = dtype
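        # Default condition text used when no `batch_condition` is passed to `__call__`.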
self.condition = "light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things."
self.max_reward = max_reward
self.loss_scale = loss_scale
processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
# https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json
        # TODO: use [transforms.Resize(224), transforms.CenterCrop(224)] to support arbitrary aspect ratios.
self.transform = transforms.Compose([
transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])
        # The original [ckpt](http://drive.google.com/file/d/17qrK_aJkVNM75ZEvMEePpLj6L867MLkN/view?usp=sharing)
        # (which contains the entire model) has been converted to a `state_dict`.
url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/MPS_overall.pth"
filename = "MPS_overall.pth"
md5 = "1491cbbbd20565747fe07e7572e2ac56"
if self.model_path is None or not os.path.exists(self.model_path):
download_url(url, torch.hub.get_dir(), md5=md5)
model_path = os.path.join(torch.hub.get_dir(), filename)
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(processor_name_or_path)
self.model = CLIPModel(config)
state_dict = torch.load(model_path, map_location="cpu")
self.model.load_state_dict(state_dict, strict=False)
self.model.to(device=self.device, dtype=self.dtype)
self.model.requires_grad_(False)
self.model.eval()
def _tokenize(self, caption):
input_ids = self.tokenizer(
caption,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
return input_ids
def __call__(
self,
batch_frames: torch.Tensor,
batch_prompt: list[str],
batch_condition: Optional[list[str]] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if batch_condition is None:
batch_condition = [self.condition] * len(batch_prompt)
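        # Compute the batch reward and loss frame by frame.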
batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
batch_loss, batch_reward = 0, 0
for frames in batch_frames:
image_inputs = torch.stack([self.transform(frame) for frame in frames])
image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
text_inputs = self._tokenize(batch_prompt).to(self.device)
condition_inputs = self._tokenize(batch_condition).to(device=self.device)
text_features, image_features = self.model(text_inputs, image_inputs, condition_inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# reward = self.model.logit_scale.exp() * torch.diag(torch.einsum('bd,cd->bc', text_features, image_features))
logits = image_features @ text_features.T
reward = torch.diagonal(logits)
# Convert reward to loss in [0, 1].
if self.max_reward is None:
loss = (-1 * reward) * self.loss_scale
else:
loss = abs(reward - self.max_reward) * self.loss_scale
batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()
return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
if __name__ == "__main__":
import numpy as np
from decord import VideoReader
video_path_list = ["your_video_path_1.mp4", "your_video_path_2.mp4"]
prompt_list = ["your_prompt_1", "your_prompt_2"]
num_sampled_frames = 8
to_tensor = transforms.ToTensor()
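    # Uniformly sample `num_sampled_frames` frames from each video and assemble a [B, C, T, H, W] tensor.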
sampled_frames_list = []
for video_path in video_path_list:
vr = VideoReader(video_path)
sampled_frame_indices = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int)
sampled_frames = vr.get_batch(sampled_frame_indices).asnumpy()
sampled_frames = torch.stack([to_tensor(frame) for frame in sampled_frames])
sampled_frames_list.append(sampled_frames)
sampled_frames = torch.stack(sampled_frames_list)
sampled_frames = rearrange(sampled_frames, "b t c h w -> b c t h w")
aesthetic_reward_v2 = AestheticReward(device="cuda", dtype=torch.bfloat16)
print(f"aesthetic_reward_v2: {aesthetic_reward_v2(sampled_frames)}")
aesthetic_reward_v2_5 = AestheticReward(
encoder_path="google/siglip-so400m-patch14-384", version="v2.5", device="cuda", dtype=torch.bfloat16
)
print(f"aesthetic_reward_v2_5: {aesthetic_reward_v2_5(sampled_frames)}")
hps_reward_v2 = HPSReward(device="cuda", dtype=torch.bfloat16)
print(f"hps_reward_v2: {hps_reward_v2(sampled_frames, prompt_list)}")
hps_reward_v2_1 = HPSReward(version="v2.1", device="cuda", dtype=torch.bfloat16)
print(f"hps_reward_v2_1: {hps_reward_v2_1(sampled_frames, prompt_list)}")
pick_score = PickScoreReward(device="cuda", dtype=torch.bfloat16)
print(f"pick_score_reward: {pick_score(sampled_frames, prompt_list)}")
mps_score = MPSReward(device="cuda", dtype=torch.bfloat16)
print(f"mps_reward: {mps_score(sampled_frames, prompt_list)}")