import contextlib
import functools
import importlib
import io
import math
import os
import warnings
from functools import partial, wraps
from inspect import isfunction

import fsspec
import numpy as np
import torch
import torchaudio
import torchvision
from einops import rearrange
from PIL import Image, ImageDraw, ImageFont
from safetensors.torch import load_file as load_safetensors

try:
    import moviepy.editor as mpy
except ImportError:
    pass

def save_audio_video(
    video,
    audio=None,
    frame_rate=25,
    sample_rate=16000,
    save_path="temp.mp4",
    keep_intermediate=False,
):
    """Save audio and video to a single file.
    video: (t, c, h, w)
    audio: (channels, samples)
    """
    save_path = str(save_path)
    video_tensor = rearrange(video, "t c h w -> t h w c").astype(np.uint8)
    if audio is not None:
        # Assuming audio is a tensor of shape (channels, samples)
        audio_tensor = audio
        torchvision.io.write_video(
            save_path,
            video_tensor,
            fps=frame_rate,
            audio_array=audio_tensor,
            audio_fps=sample_rate,
            video_codec="h264",  # Specify a codec to avoid codec-selection errors
            audio_codec="aac",
        )
    else:
        torchvision.io.write_video(
            save_path,
            video_tensor,
            fps=frame_rate,
            video_codec="h264",  # Specify a codec to avoid codec-selection errors
            audio_codec="aac",
        )
    return 1

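# Usage sketch (illustrative, not part of the original module): the function expects a
# (T, C, H, W) video array that is rearranged to uint8 (T, H, W, C) frames, plus an
# optional (channels, samples) audio tensor, and muxes them with
# torchvision.io.write_video, e.g.
#   save_audio_video(frames, audio=waveform, frame_rate=25, sample_rate=16000,
#                    save_path="out.mp4")
# where `frames` and `waveform` are hypothetical arrays of the shapes above.
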
def get_raw_audio(audio_path, audio_rate, fps=25):
    audio, sr = torchaudio.load(audio_path, channels_first=True)
    if audio.shape[0] > 1:
        audio = audio.mean(0, keepdim=True)
    audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=audio_rate)[0]
    samples_per_frame = math.ceil(audio_rate / fps)
    n_frames = audio.shape[-1] / samples_per_frame
    if not n_frames.is_integer():
        print("Audio shape before trim_pad_audio: ", audio.shape)
        audio = trim_pad_audio(
            audio, audio_rate, max_len_raw=math.ceil(n_frames) * samples_per_frame
        )
        print("Audio shape after trim_pad_audio: ", audio.shape)
    audio = rearrange(audio, "(f s) -> f s", s=samples_per_frame)
    return audio

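# Example sketch (assumed values): with audio_rate=16000 and fps=25, samples_per_frame
# is ceil(16000 / 25) = 640, so a one-second mono clip comes back as a (25, 640)
# tensor of per-frame audio windows aligned to the video frame rate.
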
def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
    len_file = audio.shape[-1]
    if max_len_sec or max_len_raw:
        max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
        if len_file < int(max_len):
            # Right-pad with zeros up to the target length
            extended_wav = torch.nn.functional.pad(
                audio, (0, int(max_len) - len_file), "constant"
            )
        else:
            extended_wav = audio[:, : int(max_len)]
    else:
        extended_wav = audio
    return extended_wav

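# Example sketch (assumed values): a (1, 15000)-sample waveform with sr=16000 and
# max_len_sec=1.0 is right-padded with zeros to 16000 samples, while a (1, 20000)
# input is cut to the first 16000 samples.
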
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

def get_string_from_tuple(s):
    try:
        # Check if the string starts and ends with parentheses
        if s[0] == "(" and s[-1] == ")":
            # Convert the string to a tuple
            t = eval(s)
            # If it evaluates to a tuple, return its first element
            if isinstance(t, tuple):
                return t[0]
    except Exception:
        pass
    return s

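# Example sketch (illustrative inputs): get_string_from_tuple("('foo', 1)") evaluates
# the string to a tuple and returns "foo"; any string that is not a parenthesised
# tuple literal, e.g. "foo", is returned unchanged.
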
def is_power_of_two(n):
    """
    chat.openai.com/chat
    Return True if n is a power of 2, otherwise return False.

    The function first checks whether n is less than or equal to 0; such values cannot
    be powers of 2, so it returns False.

    For positive n it uses a bitwise AND between n and n-1. A power of 2 has exactly
    one bit set in its binary representation; subtracting 1 turns that bit into 0 and
    all lower bits into 1, so n & (n - 1) is 0 exactly when n is a power of 2 and
    non-zero otherwise.
    """
    if n <= 0:
        return False
    return (n & (n - 1)) == 0

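# Illustrative checks (not part of the original module):
#   is_power_of_two(8) -> True   (0b1000 & 0b0111 == 0)
#   is_power_of_two(6) -> False  (0b0110 & 0b0101 == 0b0100)
#   is_power_of_two(0) -> False  (non-positive inputs are rejected up front)
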
def autocast(f, enabled=True):
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast

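# Usage sketch (hypothetical function name): wrapping a forward pass with
#   fast_forward = autocast(model_forward)
# runs model_forward under torch.cuda.amp.autocast with the currently configured
# autocast dtype and cache settings; pass enabled=False to keep full precision.
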
def load_partial_from_config(config):
    return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))

def log_txt_as_img(wh, xc, size=10):
    # wh: a tuple of (width, height)
    # xc: a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
        nc = int(40 * (wh[0] / 256))
        if isinstance(xc[bi], list):
            text_seq = xc[bi][0]
        else:
            text_seq = xc[bi]
        lines = "\n".join(
            text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
        )
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Can't encode string for logging. Skipping.")
        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts

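# Example sketch (assumes the bundled font path exists): log_txt_as_img((256, 256),
# ["a caption"]) renders the caption as black text on a white 256x256 canvas and
# returns a (1, 3, 256, 256) tensor scaled to [-1, 1], wrapping roughly every 40
# characters per line.
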
def partialclass(cls, *args, **kwargs):
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return NewCls

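# Example sketch (illustrative): BiaslessLinear = partialclass(torch.nn.Linear, bias=False)
# creates a subclass whose __init__ pre-binds bias=False, so BiaslessLinear(4, 2)
# behaves like torch.nn.Linear(4, 2, bias=False).
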
def make_path_absolute(path):
    fs, p = fsspec.core.url_to_fs(path)
    if fs.protocol == "file":
        return os.path.abspath(p)
    return path


def ismap(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def isheatmap(x):
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 2


def isneighbors(x):
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)


def exists(x):
    return x is not None


def expand_dims_like(x, y):
    while x.dim() != y.dim():
        x = x.unsqueeze(-1)
    return x


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))

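# Example sketch: for a (B, C, H, W) loss tensor, mean_flat reduces over C, H and W
# and returns a (B,) tensor with one scalar per batch element.
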
def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.0e-6:.2f} M params.")
    return total_params

def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

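# Example sketch (hypothetical config dict):
#   instantiate_from_config({"target": "torch.nn.Linear",
#                            "params": {"in_features": 4, "out_features": 2}})
# imports the dotted path, then calls torch.nn.Linear(in_features=4, out_features=2);
# the sentinel strings "__is_first_stage__" / "__is_unconditional__" return None.
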
def get_obj_from_str(string, reload=False, invalidate_cache=True):
    module, cls = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def append_zero(x):
    return torch.cat([x, x.new_zeros([1])])


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]

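# Example sketch: for a (B,)-shaped tensor of per-sample values, append_dims(x, 4)
# returns a view of shape (B, 1, 1, 1) that broadcasts against a (B, C, H, W) tensor.
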
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
    print(f"Loading model from {ckpt}")
    if ckpt.endswith("ckpt"):
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
    elif ckpt.endswith("safetensors"):
        sd = load_safetensors(ckpt)
    else:
        raise NotImplementedError

    model = instantiate_from_config(config.model)

    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    model.eval()
    return model

def get_configs_path() -> str:
    """
    Get the `configs` directory.
    For a working copy, this is the one in the root of the repository,
    but for an installed copy, it's in the `sgm` package (see pyproject.toml).
    """
    this_dir = os.path.dirname(__file__)
    candidates = (
        os.path.join(this_dir, "configs"),
        os.path.join(this_dir, "..", "configs"),
    )
    for candidate in candidates:
        candidate = os.path.abspath(candidate)
        if os.path.isdir(candidate):
            return candidate
    raise FileNotFoundError(f"Could not find SGM configs in {candidates}")

def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
    """
    Will return the result of a recursive get attribute call.
    E.g.:
        a.b.c
        = getattr(getattr(a, "b"), "c")
        = get_nested_attribute(a, "b.c")
    If any part of the attribute call is an integer x with current obj a, will
    try to call a[x] instead of a.x first.
    """
    attributes = attribute_path.split(".")
    if depth is not None and depth > 0:
        attributes = attributes[:depth]
    assert len(attributes) > 0, "At least one attribute should be selected"
    current_attribute = obj
    current_key = None
    for level, attribute in enumerate(attributes):
        current_key = ".".join(attributes[: level + 1])
        try:
            id_ = int(attribute)
            current_attribute = current_attribute[id_]
        except ValueError:
            current_attribute = getattr(current_attribute, attribute)

    return (current_attribute, current_key) if return_key else current_attribute

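# Example sketch (hypothetical model): get_nested_attribute(model, "encoder.layers.0")
# resolves model.encoder.layers[0]; because "0" parses as an integer it is used as an
# index, while non-numeric segments fall back to getattr.
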
def suppress_output(f):
    def wrapper(*args, **kwargs):
        with (
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
            warnings.catch_warnings(),
        ):
            warnings.simplefilter("ignore")
            return f(*args, **kwargs)

    return wrapper

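# Usage sketch (hypothetical function): decorating a chatty loader with
#   @suppress_output
#   def load_noisy_model(...): ...
# silences its stdout, stderr and Python warnings for the duration of each call.
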
def calculate_splits(tensor, min_last_size, dim=1):
    # Total number of elements along the split dimension
    total_size = tensor.size(dim)

    # If the total size is at most the minimum size for the last split,
    # return the tensor as a single split
    if total_size <= min_last_size:
        return [tensor]

    # Calculate the number of splits and the base size of each split
    num_splits = (total_size - min_last_size) // min_last_size + 1
    base_size = (total_size - min_last_size) // num_splits

    # Build the split sizes; the last entry absorbs the remainder so the sizes sum to
    # total_size and the last split keeps at least min_last_size elements
    split_sizes = [base_size] * (num_splits - 1)
    split_sizes.append(total_size - sum(split_sizes))

    # Adjust sizes to ensure they sum exactly to total_size
    sum_sizes = sum(split_sizes)
    while sum_sizes != total_size:
        for i in range(num_splits):
            if sum_sizes < total_size:
                split_sizes[i] += 1
                sum_sizes += 1
            if sum_sizes >= total_size:
                break

    # Split the tensor
    splits = torch.split(tensor, split_sizes, dim=dim)

    return splits

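# Example sketch (assumed values): for a tensor with 10 elements along dim=1 and
# min_last_size=4, num_splits is 2 and base_size is 3, so the tensor is split into
# chunks of sizes [3, 7] along that dimension.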