import os
import shutil
import threading
from pathlib import Path
from typing import Optional

import torch

def import_from_transformers_modules(
    pretrained_model_name_or_path, file_name, class_name
):
    """Load `class_name` from a Python module file shipped with a Hub checkpoint."""
    import transformers

    module_path = transformers.dynamic_module_utils.get_cached_module_file(
        pretrained_model_name_or_path, file_name
    )
    return transformers.dynamic_module_utils.get_class_in_module(
        class_name, module_path
    )
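
# Usage sketch (illustrative; the repo id, file name, and class name below are
# placeholders, not taken from this file):
#
#     MyModel = import_from_transformers_modules(
#         "some-org/some-model", "modeling_custom.py", "MyModel"
#     )
#     model = MyModel.from_pretrained("some-org/some-model")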

def deepspeed_zero_init_disabled_context_manager():
    """
    Return a list containing a single context manager that disables zero.Init,
    or an empty list if no DeepSpeed plugin is active.
    """
    import accelerate

    deepspeed_plugin = (
        accelerate.state.AcceleratorState().deepspeed_plugin
        if accelerate.state.is_initialized()
        else None
    )
    if deepspeed_plugin is None:
        return []
    return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
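
# Usage sketch (assumption, mirroring how diffusers training scripts use this
# kind of helper): load frozen sub-models under the returned context so
# DeepSpeed ZeRO-3 does not shard weights that are never trained, e.g.
#
#     from transformers.utils import ContextManagers
#
#     with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
#         text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")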

def remove_excess_checkpoints(
    save_directory,
    checkpoints_total_limit: Optional[int] = None,
    checkpoint_prefix: str = "checkpoint",
    is_main_process: bool = True,
):
    # After saving state, check whether we are now over `checkpoints_total_limit`
    # and, if so, prune the oldest checkpoints.
    if is_main_process and checkpoints_total_limit is not None:
        checkpoints = os.listdir(save_directory)
        checkpoints = [d for d in checkpoints if d.startswith(checkpoint_prefix)]
        # Sort by step number, which the sort key assumes is the third
        # dash-separated field of the checkpoint directory name.
        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[2]))

        # After the new checkpoint is saved we want at most
        # `checkpoints_total_limit` checkpoints left on disk.
        if len(checkpoints) > checkpoints_total_limit:
            num_to_remove = len(checkpoints) - checkpoints_total_limit
            removing_checkpoints = checkpoints[0:num_to_remove]
            print(
                f"{len(checkpoints)} checkpoints already exist, "
                f"removing {len(removing_checkpoints)} checkpoints"
            )
            print(f"removing checkpoints: {', '.join(removing_checkpoints)}")

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint = os.path.join(save_directory, removing_checkpoint)
                shutil.rmtree(removing_checkpoint)
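
# Usage sketch (illustrative; `accelerator`, `output_dir`, and `global_step`
# are assumed to exist in the caller's training loop, and the directory name
# must keep the step in the third dash-separated field for the sort key above):
#
#     accelerator.save_state(os.path.join(output_dir, f"checkpoint-run0-{global_step}"))
#     remove_excess_checkpoints(
#         output_dir,
#         checkpoints_total_limit=5,
#         is_main_process=accelerator.is_main_process,
#     )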

def is_distributed_training():
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return True
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return world_size > 1

def contain_invalid_grad(optimizer):
    """Return True if any parameter tracked by `optimizer` has a NaN/Inf gradient."""
    invalid_grad = False
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            if param.grad is not None:
                invalid_grad = invalid_grad or (
                    torch.isnan(param.grad).any()
                    or torch.isinf(param.grad).any()
                    or torch.isneginf(param.grad).any()
                )
    if is_distributed_training():
        # Make the decision globally consistent: if any rank saw an invalid
        # gradient, every rank reports True.
        invalid_grad_flag = torch.tensor(
            [1.0 if invalid_grad else 0.0],
            dtype=torch.float32,
            requires_grad=False,
        ).cuda()
        torch.distributed.all_reduce(
            invalid_grad_flag, op=torch.distributed.ReduceOp.MAX
        )
        invalid_grad = invalid_grad_flag.item() > 0
    return invalid_grad
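
# Usage sketch (illustrative): skip the update when any rank saw a NaN/Inf
# gradient, so all ranks stay in lockstep.
#
#     if contain_invalid_grad(optimizer):
#         optimizer.zero_grad(set_to_none=True)
#     else:
#         optimizer.step()
#         optimizer.zero_grad(set_to_none=True)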

def patch_npu_record_stream():
    # On Ascend NPU, force a full device synchronization after every
    # `Tensor.record_stream` call as a conservative workaround for
    # stream-recording issues.
    torch.utils.rename_privateuse1_backend("npu")
    record_stream = torch.Tensor.record_stream

    def _func(*args, **kwargs):
        ret = record_stream(*args, **kwargs)
        torch.cuda.synchronize()
        return ret

    torch.Tensor.record_stream = _func

def patch_npu_diffusers_get_1d_rotary_pos_embed():
    # Replace diffusers' `get_1d_rotary_pos_embed` with a variant that behaves
    # well on Ascend NPU.
    from typing import Union

    import numpy as np
    import diffusers

    def __get_1d_rotary_pos_embed(
        dim: int,
        pos: Union[np.ndarray, int],
        theta: float = 10000.0,
        use_real=False,
        linear_factor=1.0,
        ntk_factor=1.0,
        repeat_interleave_real=True,
        freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
    ):
        assert dim % 2 == 0

        if isinstance(pos, int):
            pos = torch.arange(pos)
        if isinstance(pos, np.ndarray):
            pos = torch.from_numpy(pos)  # type: ignore  # [S]

        theta = theta * ntk_factor
        freqs = (
            1.0
            / (
                theta
                ** (
                    torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[
                        : (dim // 2)
                    ]
                    / dim
                )
            )
            / linear_factor
        )  # [D/2]
        freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
        if use_real and repeat_interleave_real:
            # flux, hunyuan-dit, cogvideox
            freqs_cos = freqs.cos().float().repeat_interleave(2, dim=1)  # [S, D]
            freqs_sin = freqs.sin().float().repeat_interleave(2, dim=1)  # [S, D]
            return freqs_cos, freqs_sin
        elif use_real:
            # stable audio
            freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
            freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
            return freqs_cos, freqs_sin
        else:
            # lumina
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
            return freqs_cis

    diffusers.models.embeddings.get_1d_rotary_pos_embed = __get_1d_rotary_pos_embed
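
# Usage sketch (assumption): when running on Ascend NPU (with `torch_npu`
# imported), apply both patches once at startup, before models are built or
# rotary embeddings are computed:
#
#     import torch_npu  # noqa: F401
#
#     if torch.npu.is_available():
#         patch_npu_record_stream()
#         patch_npu_diffusers_get_1d_rotary_pos_embed()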