import os
import shutil
import threading
from pathlib import Path
import torch


def import_from_transformers_modules(
    pretrained_model_name_or_path, file_name, class_name
):
    """Load a class from a remote-code module cached from a model repository."""
    import transformers

    module_path = transformers.dynamic_module_utils.get_cached_module_file(
        pretrained_model_name_or_path, file_name
    )
    return transformers.dynamic_module_utils.get_class_in_module(
        class_name, module_path
    )
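

# Usage sketch (the file and class names below are hypothetical): resolve a
# class shipped as remote code inside a model repository, similar in spirit
# to `trust_remote_code` loading.
def _example_load_remote_class(repo_id):
    return import_from_transformers_modules(
        repo_id, "modeling_custom.py", "CustomModel"
    )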


def deepspeed_zero_init_disabled_context_manager():
    """
    Return a list containing a single context manager that disables zero.Init,
    or an empty list when no DeepSpeed plugin is active.
    """
    import accelerate

    deepspeed_plugin = (
        accelerate.state.AcceleratorState().deepspeed_plugin
        if accelerate.state.is_initialized()
        else None
    )
    if deepspeed_plugin is None:
        return []
    return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
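

# Usage sketch, following the diffusers training-script pattern: construct
# frozen weights outside ZeRO-3 partitioning. `ContextManagers` comes from
# transformers.utils; the model and subfolder below are illustrative.
def _example_load_frozen_text_encoder(pretrained_path):
    from transformers import CLIPTextModel
    from transformers.utils import ContextManagers

    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
        # frozen weights do not need to be sharded by zero.Init
        return CLIPTextModel.from_pretrained(pretrained_path, subfolder="text_encoder")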


def remove_excess_checkpoints(
    save_directory,
    checkpoints_total_limit: int = None,
    checkpoint_prefix="checkpoint",
    is_main_process: bool = True,
):
    # _after_ saving state, check if this save would set us over the `checkpoints_total_limit`
    if is_main_process and checkpoints_total_limit is not None:
        checkpoints = os.listdir(save_directory)
        checkpoints = [d for d in checkpoints if d.startswith(checkpoint_prefix)]
        # sort by the trailing step number, e.g. "checkpoint-1000" -> 1000
        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))
        # _after_ we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit` checkpoints
        if len(checkpoints) > checkpoints_total_limit:
            num_to_remove = len(checkpoints) - checkpoints_total_limit
            removing_checkpoints = checkpoints[0:num_to_remove]
            print(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            print(f"removing checkpoints: {', '.join(removing_checkpoints)}")
            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint = os.path.join(save_directory, removing_checkpoint)
                shutil.rmtree(removing_checkpoint)
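

# Usage sketch (the `accelerator` object and paths are illustrative): save
# the training state first, then prune the oldest checkpoints down to the limit.
def _example_save_with_rotation(accelerator, output_dir, global_step, total_limit=3):
    accelerator.save_state(os.path.join(output_dir, f"checkpoint-{global_step}"))
    remove_excess_checkpoints(
        output_dir,
        checkpoints_total_limit=total_limit,
        is_main_process=accelerator.is_main_process,
    )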


def is_distributed_training():
    """Detect distributed training via torch.distributed or the WORLD_SIZE env var."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return True
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return world_size > 1


def contain_invalid_grad(optimizer):
    """Return True if any parameter gradient contains NaN or Inf on any rank."""
    invalid_grad = False
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            if param.grad is not None:
                # torch.isinf already covers both +inf and -inf
                invalid_grad = invalid_grad or bool(
                    torch.isnan(param.grad).any() or torch.isinf(param.grad).any()
                )
    if is_distributed_training():
        # take the max across ranks so every rank agrees on whether to skip the step
        invalid_grad_flag = torch.tensor(
            [1.0 if invalid_grad else 0.0],
            dtype=torch.float32,
            requires_grad=False,
        ).cuda()
        torch.distributed.all_reduce(
            invalid_grad_flag, op=torch.distributed.ReduceOp.MAX
        )
        invalid_grad = invalid_grad_flag.item() > 0
    return invalid_grad
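

# Usage sketch (hedged; `optimizer` and `lr_scheduler` are illustrative): skip
# the update when any rank saw NaN/Inf gradients, so all ranks stay in sync.
def _example_guarded_optimizer_step(optimizer, lr_scheduler):
    if not contain_invalid_grad(optimizer):
        optimizer.step()
        lr_scheduler.step()
    optimizer.zero_grad(set_to_none=True)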


def patch_npu_record_stream():
    torch.utils.rename_privateuse1_backend("npu")
    record_stream = torch.Tensor.record_stream

    def _func(*args, **kwargs):
        # Workaround: follow every record_stream call with a blocking
        # synchronize, since stream-ordered memory reuse via record_stream
        # is problematic on NPU.
        ret = record_stream(*args, **kwargs)
        torch.cuda.synchronize()
        return ret

    torch.Tensor.record_stream = _func


def patch_npu_diffusers_get_1d_rotary_pos_embed():
    """Patch diffusers' `get_1d_rotary_pos_embed` to cast the rotary
    frequencies to float32 before `repeat_interleave`, presumably because
    float64 ops on that path are problematic on NPU."""
    from typing import Union

    import numpy as np
    import diffusers

    def __get_1d_rotary_pos_embed(
        dim: int,
        pos: Union[np.ndarray, int],
        theta: float = 10000.0,
        use_real=False,
        linear_factor=1.0,
        ntk_factor=1.0,
        repeat_interleave_real=True,
        freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
    ):
        assert dim % 2 == 0

        if isinstance(pos, int):
            pos = torch.arange(pos)
        if isinstance(pos, np.ndarray):
            pos = torch.from_numpy(pos)  # type: ignore  # [S]

        theta = theta * ntk_factor
        freqs = (
            1.0
            / (
                theta
                ** (
                    torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[
                        : (dim // 2)
                    ]
                    / dim
                )
            )
            / linear_factor
        )  # [D/2]
        freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
        if use_real and repeat_interleave_real:
            # flux, hunyuan-dit, cogvideox: cast to float32 first so
            # repeat_interleave never runs on float64 tensors
            freqs_cos = freqs.cos().float().repeat_interleave(2, dim=1)  # [S, D]
            freqs_sin = freqs.sin().float().repeat_interleave(2, dim=1)  # [S, D]
            return freqs_cos, freqs_sin
        elif use_real:
            # stable audio
            freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
            freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
            return freqs_cos, freqs_sin
        else:
            # lumina
            freqs_cis = torch.polar(
                torch.ones_like(freqs), freqs
            )  # complex64  # [S, D/2]
            return freqs_cis

    diffusers.models.embeddings.get_1d_rotary_pos_embed = __get_1d_rotary_pos_embed
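

# Usage sketch: on Ascend NPU, apply both patches once at process startup,
# before any diffusers models are constructed (the exact call site depends
# on the training entrypoint).
def _example_apply_npu_patches():
    patch_npu_record_stream()
    patch_npu_diffusers_get_1d_rotary_pos_embed()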