import logging
import re
from functools import cache, partial
from typing import Callable, TypeVar

import deepspeed
import pandas as pd
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.runtime.utils import clip_grad_norm_
from torch import nn

from .distributed import fix_unset_envs

logger = logging.getLogger(__name__)

T = TypeVar("T")


def flatten_dict(d):
    """Flatten a nested dict into a single level, joining keys with "/"."""
    records = pd.json_normalize(d, sep="/").to_dict(orient="records")
    return records[0] if records else {}
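
# A quick sketch of the behavior (keys are produced by pandas' json_normalize):
#   flatten_dict({"model": {"loss": 0.1, "acc": 0.9}})
#   # -> {"model/loss": 0.1, "model/acc": 0.9}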


def _get_named_modules(module, attrname, sep="/"):
    """Yield (name, submodule) for every submodule that has attribute `attrname`."""
    for name, submodule in module.named_modules():
        name = name.replace(".", sep)
        if hasattr(submodule, attrname):
            yield name, submodule


def gather_attribute(module, attrname, delete=True, prefix=None):
    """Collect `attrname` from every submodule that defines it, keyed by module path."""
    ret = {}
    for name, submodule in _get_named_modules(module, attrname):
        ret[name] = getattr(submodule, attrname)
        if delete:
            try:
                delattr(submodule, attrname)
            except Exception as e:
                raise RuntimeError(f"Failed to delete {attrname} from {name} ({submodule})") from e
    if prefix:
        ret = {prefix: ret}
    ret = flatten_dict(ret)
    # collapse consecutive separators, e.g. "a//b" -> "a/b"
    ret = {re.sub(r"/+", "/", k): v for k, v in ret.items()}
    return ret
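
# Sketch of a typical use (names are illustrative): a submodule stashes a stat
# on itself during forward, and the trainer gathers (and clears) it afterwards.
#
#   class Head(nn.Module):
#       def forward(self, x):
#           self.loss = x.square().mean()  # stashed for the trainer to collect
#           return x
#
#   stats = gather_attribute(model, "loss", prefix="loss")
#   # -> e.g. {"loss/head": tensor(0.4200)}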


def dispatch_attribute(module, attrname, value, filter_fn: Callable[[nn.Module], bool] | None = None):
    """Set `attrname` to `value` on every submodule that already has that attribute."""
    for _, submodule in _get_named_modules(module, attrname):
        if filter_fn is None or filter_fn(submodule):
            setattr(submodule, attrname, value)
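
# For example (attribute name is illustrative): push a sampling temperature to
# every submodule that declares one, optionally narrowing the set via filter_fn.
#   dispatch_attribute(model, "temperature", 0.8)
#   dispatch_attribute(model, "temperature", 0.8, filter_fn=lambda m: m.training)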


def update_deepspeed_logger():
    # Silence DeepSpeed's chatty INFO-level output.
    ds_logger = logging.getLogger("DeepSpeed")
    ds_logger.setLevel(logging.WARNING)


@cache
def init_distributed():
    # Cached so that distributed init runs only once, no matter how many
    # engines are constructed.
    update_deepspeed_logger()
    fix_unset_envs()
    deepspeed.init_distributed(get_accelerator().communication_backend_name())


def _try_each(*fns):
    """Call each function in turn and return the first result that doesn't raise."""
    if len(fns) == 0:
        raise RuntimeError("All functions failed")
    head, *tails = fns
    try:
        return head()
    except Exception as e:
        logger.warning(f"Tried {head} but failed: {e}, trying next")
        return _try_each(*tails)
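
# Sketch: try progressively cheaper fallbacks until one succeeds (loader names
# are hypothetical).
#   value = _try_each(
#       lambda: load_fast(),
#       lambda: load_slow(),
#   )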


class Engine(DeepSpeedEngine):
    def __init__(self, *args, ckpt_dir, **kwargs):
        init_distributed()
        # `args` is DeepSpeedEngine's legacy first parameter; pass None for it
        # and forward everything else.
        super().__init__(None, *args, **kwargs)
        self._ckpt_dir = ckpt_dir
        self._frozen_params = set()
        self._fp32_grad_norm = None

    def path(self):
        return self._ckpt_dir

    def freeze_(self):
        # Freeze every currently-trainable parameter, remembering which ones
        # were trainable so unfreeze_() can restore exactly that set.
        for p in self.module.parameters():
            if p.requires_grad:
                p.requires_grad_(False)
                self._frozen_params.add(p)

    def unfreeze_(self):
        for p in self._frozen_params:
            p.requires_grad_(True)
        self._frozen_params.clear()
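
    # Sketch: temporarily freeze the whole module, then restore exactly the
    # parameters that were trainable before.
    #   engine.freeze_()
    #   ...
    #   engine.unfreeze_()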

    def global_step(self):
        return self.global_steps

    def gather_attribute(self, *args, **kwargs):
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        return dispatch_attribute(self.module, *args, **kwargs)

    def clip_fp32_gradients(self):
        self._fp32_grad_norm = clip_grad_norm_(
            parameters=self.module.parameters(),
            max_norm=self.gradient_clipping(),
            mpu=self.mpu,
        )

    def get_grad_norm(self):
        # get_global_grad_norm() can return None (e.g. plain fp32 training);
        # fall back to the norm recorded by clip_fp32_gradients().
        grad_norm = self.get_global_grad_norm()
        if grad_norm is None:
            grad_norm = self._fp32_grad_norm
        return grad_norm

    def save_checkpoint(self, *args, **kwargs):
        self._ckpt_dir.mkdir(parents=True, exist_ok=True)
        super().save_checkpoint(*args, save_dir=self._ckpt_dir, **kwargs)
        logger.info(f"Saved checkpoint to {self._ckpt_dir}")

    def load_checkpoint(self, *args, **kwargs):
        # Retry with progressively looser loading: drop optimizer state, then
        # LR-scheduler state, then both, and finally relax strict matching of
        # the module weights.
        fn = partial(super().load_checkpoint, *args, load_dir=self._ckpt_dir, **kwargs)
        return _try_each(
            lambda: fn(),
            lambda: fn(load_optimizer_states=False),
            lambda: fn(load_lr_scheduler_states=False),
            lambda: fn(load_optimizer_states=False, load_lr_scheduler_states=False),
            lambda: fn(
                load_optimizer_states=False,
                load_lr_scheduler_states=False,
                load_module_strict=False,
            ),
        )
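

# Minimal usage sketch (illustrative only: every name below except Engine is a
# placeholder, and constructor kwargs are forwarded verbatim to DeepSpeedEngine):
#
#   from pathlib import Path
#
#   engine = Engine(
#       model=my_model,                          # hypothetical nn.Module
#       model_parameters=my_model.parameters(),
#       config=ds_config,                        # hypothetical DeepSpeed config dict
#       ckpt_dir=Path("ckpts/run1"),
#   )
#   engine.load_checkpoint()        # falls back to looser loading on failure
#   engine.save_checkpoint(tag="latest")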