#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import einops
import numpy as np
import torch
import torch.nn as nn

class Normalize(nn.Module):
    """L2-normalizes the input along a given dimension."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.nn.functional.normalize(x, dim=self.dim, p=2)
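
# Usage sketch (illustrative, not part of the original file): Normalize is
# typically applied to embedding vectors before computing cosine-similarity
# logits. The shapes below are assumptions for demonstration only.
def _demo_normalize() -> None:
    norm = Normalize(dim=-1)
    x = torch.randn(4, 512)  # hypothetical batch of 4 embeddings
    y = norm(x)
    # every row of y now has unit L2 norm
    assert torch.allclose(y.norm(dim=-1), torch.ones(4), atol=1e-5)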

class LearnableLogitScaling(nn.Module):
    """Scales inputs (e.g. cosine similarities) by a temperature factor,
    initialized to 1/0.07 as in CLIP and optionally learnable. The scale is
    stored in log space and clipped at max_logit_scale."""

    def __init__(
        self,
        logit_scale_init: float = 1 / 0.07,
        learnable: bool = True,
        max_logit_scale: float = 100,
    ) -> None:
        super().__init__()
        self.max_logit_scale = max_logit_scale
        self.logit_scale_init = logit_scale_init
        self.learnable = learnable
        log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
        if learnable:
            self.log_logit_scale = nn.Parameter(log_logit_scale)
        else:
            self.register_buffer("log_logit_scale", log_logit_scale)

    def forward(self, x):
        return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x

    def extra_repr(self):
        st = (
            f"logit_scale_init={self.logit_scale_init}, "
            f"learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
        )
        return st
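
# Usage sketch (illustrative, not part of the original file): with a frozen
# scale, the module multiplies similarities by exp(log(1/0.07)) = 1/0.07.
def _demo_logit_scaling() -> None:
    scale = LearnableLogitScaling(learnable=False)
    sims = torch.randn(4, 4)  # hypothetical cosine-similarity matrix
    logits = scale(sims)
    # 1/0.07 ~= 14.29, well below the clip value of 100, so no clipping occurs
    assert torch.allclose(logits, sims * (1 / 0.07))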

class EinOpsRearrange(nn.Module):
    """Applies einops.rearrange with a fixed pattern as an nn.Module."""

    def __init__(self, rearrange_expr: str, **kwargs) -> None:
        super().__init__()
        self.rearrange_expr = rearrange_expr
        self.kwargs = kwargs

    def forward(self, x):
        assert isinstance(x, torch.Tensor)
        return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
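
# Usage sketch (illustrative, not part of the original file): flatten a
# (B, C, H, W) feature map into (B, H*W, C) patch tokens.
def _demo_rearrange() -> None:
    to_tokens = EinOpsRearrange("b c h w -> b (h w) c")
    x = torch.randn(2, 64, 7, 7)
    assert to_tokens(x).shape == (2, 49, 64)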

class VerboseNNModule(nn.Module):
    """
    Wrapper around nn.Module that prints registered buffers and parameter names.
    """

    @staticmethod
    def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
        # NOTE: callers pass the full (name, tensor) tuple yielded by
        # named_parameters()/named_buffers(), hence the tensor[1] indexing.
        st = (
            "("
            + name
            + "): "
            + "tensor("
            + str(tuple(tensor[1].shape))
            + ", requires_grad="
            + str(tensor[1].requires_grad)
            + ")\n"
        )
        return st

    def extra_repr(self) -> str:
        named_modules = set()
        for p in self.named_modules():
            named_modules.update([p[0]])
        named_modules = list(named_modules)

        string_repr = ""
        for p in self.named_parameters():
            name = p[0].split(".")[0]
            if name not in named_modules:
                string_repr += self.get_readable_tensor_repr(name, p)

        for p in self.named_buffers():
            name = p[0].split(".")[0]
            string_repr += self.get_readable_tensor_repr(name, p)

        return string_repr
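
# Usage sketch (illustrative, not part of the original file): subclassing
# VerboseNNModule makes repr() list top-level parameters and buffers.
def _demo_verbose_module() -> None:
    class Toy(VerboseNNModule):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.zeros(3, 3))
            self.register_buffer("pos", torch.arange(4))

    # prints lines such as "(weight): tensor((3, 3), requires_grad=True)"
    print(Toy())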

def cast_if_src_dtype(
    tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
    """Casts `tensor` to `tgt_dtype` if it currently has `src_dtype`.
    Returns the (possibly cast) tensor and whether a cast happened."""
    updated = False
    if tensor.dtype == src_dtype:
        tensor = tensor.to(dtype=tgt_dtype)
        updated = True
    return tensor, updated
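
# Usage sketch (illustrative, not part of the original file): upcast bfloat16
# tensors to float32 while leaving tensors of any other dtype untouched.
def _demo_cast_if_src_dtype() -> None:
    x = torch.randn(2, 3, dtype=torch.bfloat16)
    y, updated = cast_if_src_dtype(x, torch.bfloat16, torch.float32)
    assert updated and y.dtype == torch.float32
    z, updated = cast_if_src_dtype(y, torch.bfloat16, torch.float32)
    assert not updated and z.dtype == torch.float32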

class QuickGELU(nn.Module):
    # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
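
# Usage sketch (illustrative, not part of the original file): the sigmoid form
# is a fast approximation of the exact GELU, agreeing to within roughly 0.02.
def _demo_quick_gelu() -> None:
    x = torch.linspace(-3, 3, 7)
    approx = QuickGELU()(x)
    exact = nn.functional.gelu(x)
    assert torch.allclose(approx, exact, atol=5e-2)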

class SelectElement(nn.Module):
    """Selects a single position along the sequence dimension, e.g. a CLS token."""

    def __init__(self, index) -> None:
        super().__init__()
        self.index = index

    def forward(self, x):
        assert x.ndim >= 3
        return x[:, self.index, ...]
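
# Usage sketch (illustrative, not part of the original file): pull out the
# first token (e.g. CLS) from a batch of token sequences.
def _demo_select_element() -> None:
    tokens = torch.randn(2, 50, 768)  # hypothetical (B, L, D) token batch
    cls_only = SelectElement(index=0)(tokens)
    assert cls_only.shape == (2, 768)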

class SelectEOSAndProject(nn.Module):
    """
    Text Pooling used in OpenCLIP
    """

    def __init__(self, proj: nn.Module) -> None:
        super().__init__()
        self.proj = proj

    def forward(self, x, seq_len):
        assert x.ndim == 3
        # x is of shape B x L x D
        # take features from the EOT embedding at position seq_len (in CLIP the
        # EOT token has the highest token id in each sequence)
        x = x[torch.arange(x.shape[0]), seq_len]
        x = self.proj(x)
        return x
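
# Usage sketch (illustrative, not part of the original file): pool each
# sequence at its EOT position and project to the joint embedding width.
def _demo_select_eos_and_project() -> None:
    B, L, D, out_dim = 2, 10, 16, 8
    x = torch.randn(B, L, D)
    seq_len = torch.tensor([3, 7])  # hypothetical EOT index per sequence
    head = SelectEOSAndProject(proj=nn.Linear(D, out_dim))
    assert head(x, seq_len).shape == (B, out_dim)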