| | import gc |
| | import io |
| | import logging |
| | import pickle |
| | import shutil |
| | import traceback |
| | from abc import ABCMeta, abstractmethod |
| | from collections import defaultdict |
| | from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed |
| | from contextlib import contextmanager |
| | from copy import deepcopy |
| | from dataclasses import dataclass, field, replace |
| | from functools import reduce |
| | from multiprocessing import shared_memory |
| | from pathlib import Path |
| | from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.distributed.checkpoint as dist_cp |
| | import torch.multiprocessing as mp |
| | from packaging import version |
| | from torch.distributed import _remote_device |
| | from torch.distributed._shard._utils import narrow_tensor_by_index |
| | from torch.distributed._shard.metadata import ShardMetadata |
| | from torch.distributed._shard.sharded_tensor import ShardedTensor |
| | from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo |
| | from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex |
| | from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict |
| | from torch.distributed.checkpoint.planner import LoadItemType, ReadItem |
| | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| | from torch.distributed.fsdp import StateDictType |
| | from torch.distributed.fsdp.api import ( |
| | FullOptimStateDictConfig, |
| | FullStateDictConfig, |
| | ShardedOptimStateDictConfig, |
| | ShardedStateDictConfig, |
| | ) |
| | from torch.futures import Future |
| |
|
| | try: |
| | from torch.distributed.fsdp.flat_param import FlatParamHandle |
| | except ModuleNotFoundError: |
| | from torch.distributed.fsdp._flat_param import FlatParamHandle |
| |
|
| | from . import util |
| |
|
| | from .aliases import PathOrStr |
| | from .config import BaseConfig, ShardedCheckpointerType, TrainConfig |
| | from .exceptions import OLMoCheckpointError |
| | from .optim import Optimizer, fix_optim_state_dict |
| | from .safetensors_util import safetensors_file_to_state_dict |
| | from .torch_util import ( |
| | barrier, |
| | gc_cuda, |
| | get_fs_local_rank, |
| | get_global_rank, |
| | get_world_size, |
| | ) |
| | from .util import ( |
| | _get_s3_client, |
| | default_thread_count, |
| | dir_is_empty, |
| | get_bytes_range, |
| | get_progress_bar, |
| | resource_path, |
| | upload, |
| | wait_for, |
| | ) |
| |
|
| | __all__ = [ |
| | "save_fsdp_model_and_optim_state", |
| | "load_fsdp_model_and_optim_state", |
| | "load_fsdp_optim_state", |
| | "save_state_dict", |
| | "load_state_dict", |
| | "load_model_state", |
| | "RemoteFileSystemWriter", |
| | "RemoteFileSystemReader", |
| | "Checkpointer", |
| | "FullCheckpointer", |
| | "TorchNewStyleShardedCheckpointer", |
| | "TorchLegacyShardedCheckpointer", |
| | "LocalShardedCheckpointer", |
| | "build_sharded_checkpointer", |
| | ] |
| |
|
| |
|
| | log = logging.getLogger(__name__) |
| |
|
| | MODEL_AND_OPTIM_FOLDER = "model_and_optim" |
| |
|
| |
|
def save_fsdp_model_and_optim_state(
    checkpoint_dir: PathOrStr,
    fsdp_model: FSDP,
    optim: Optimizer,
    *,
    upload_to: Optional[str] = None,
    save_overwrite: bool = False,
):
    """
    Use this to save a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
    functions. This should be used during distributed training and should be called by all ranks.

    :param checkpoint_dir: The directory to save to.
    :param fsdp_model: The FSDP model.
    :param optim: The FSDP model's optimizer.
    :param upload_to: Optional, a remote "directory" to upload the checkpoint files to.
    :param save_overwrite: Overwrite existing files.

    :raises FileExistsError: If a model and optim checkpoint already exists in ``checkpoint_dir`` and ``save_overwrite=False``.
    """
    checkpoint_dir = Path(checkpoint_dir)
    target_dir = checkpoint_dir / MODEL_AND_OPTIM_FOLDER
    if save_overwrite:
        # Only the first rank on each filesystem clears any existing checkpoint;
        # everyone else waits at the barrier below before writing.
        if get_fs_local_rank() == 0:
            shutil.rmtree(target_dir, ignore_errors=True)
    elif not dir_is_empty(target_dir):
        raise FileExistsError(target_dir)
    barrier()
    if get_fs_local_rank() == 0:
        target_dir.mkdir(exist_ok=True, parents=True)
    barrier()
    # Gather sharded (rank-local), CPU-offloaded state dicts and write them with
    # torch's distributed checkpoint machinery; every rank writes its own shards.
    with FSDP.state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        model_and_optim_state = {
            "model": fsdp_model.state_dict(),
            "optim": FSDP.optim_state_dict(fsdp_model, optim),
        }
        dist_cp.save_state_dict(
            model_and_optim_state,
            RemoteFileSystemWriter(
                target_dir,
                upload_to=None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}",
                save_overwrite=save_overwrite,
            ),
        )
| |
|
| |
|
def load_fsdp_model_and_optim_state(
    checkpoint_dir: PathOrStr,
    fsdp_model: FSDP,
    optim: Optimizer,
    *,
    local_cache: Optional[PathOrStr] = None,
    load_optimizer_state: bool = True,
):
    """
    Use this to load a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
    functions. This should be used during distributed training and should be called by all ranks.

    :param checkpoint_dir: The checkpoint directory to load from. This can be a local or remote directory.
    :param fsdp_model: The FSDP model.
    :param optim: The FSDP model's optimizer.
    :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
        remote "directory" but there might be a cached version of the same artifacts.
    :param load_optimizer_state: Set to ``False`` to skip loading the optimizer state.

    :raises FileNotFoundError: If the ``checkpoint_dir`` doesn't contain a model and optimizer checkpoint.
    """
    load_path = str(checkpoint_dir).rstrip("/")
    local_cache = None if local_cache is None else Path(local_cache)
    with FSDP.state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        # Load the model's sharded state dict in place; each rank reads only the
        # pieces of the checkpoint that belong to its own shards.
        log.info("Loading model state...")
        model_state = {"model": fsdp_model.state_dict()}
        dist_cp.load_state_dict(
            model_state,
            RemoteFileSystemReader(
                f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
                local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
            ),
        )
        fsdp_model.load_state_dict(model_state["model"])

        if not load_optimizer_state:
            return

        # Load the sharded optimizer state. The model state dict is required here so
        # torch can map optimizer entries back onto the model's sharded parameters.
        log.info("Loading sharded optimizer state...")
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=model_state["model"],
            optimizer_key="optim",
            storage_reader=RemoteFileSystemReader(
                f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
                local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
            ),
        )
        # Free the model state before flattening the optimizer state to keep peak memory down.
        del model_state
        gc_cuda()
        load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])
| |
|
| |
|
def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
    """
    Flatten a sharded FSDP optimizer state dict and load it into ``optim``.

    :param fsdp_model: The FSDP-wrapped model the optimizer belongs to.
    :param optim: The optimizer to load state into.
    :param optim_state: The (sharded) optimizer state dict to flatten and load.
    """
    log.info("Flattening sharded optimizer state...")
    # The argument order of ``FSDP.optim_state_dict_to_load`` changed in torch 2.1.
    if version.parse(torch.__version__) >= version.parse("2.1.0"):
        flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state)
    else:
        flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim)
    del optim_state
    gc.collect()
    log.info("Loading flattened optimizer state...")

    # Put all per-parameter tensors on CPU before loading so we don't spike GPU memory.
    for param_state in flattened_osd["state"].values():
        for state_key, state_value in list(param_state.items()):
            if isinstance(state_value, torch.Tensor):
                param_state[state_key] = state_value.to(device="cpu")
    gc_cuda()
    optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))
| |
|
| |
|
def save_state_dict(
    checkpoint_dir: PathOrStr,
    fname: str,
    state_dict: Dict[str, Any],
    *,
    upload_to: Optional[str] = None,
    save_overwrite: bool = False,
    synchronize: bool = True,
):
    """
    Save a regular state dict to the file ``fname`` within ``checkpoint_dir`` using :func:`torch.save()`.
    This can be used during distributed training or not. If during distributed training the ``fname`` should be unique
    for each rank.

    :param checkpoint_dir: The directory to save to.
    :param fname: The target file within ``checkpoint_dir`` to save to. This should be a path relative to the ``checkpoint_dir``.
    :param state_dict: The state dict to save.
    :param upload_to: Optional, a remote "directory" to upload the file to.
    :param save_overwrite: Overwrite existing files.
    :param synchronize: If ``False``, don't do any distributed synchronization. Use this when only calling
        this function from a single rank.

    :raises FileExistsError: If the ``fname`` already exists within ``checkpoint_dir`` and ``save_overwrite=False``.
    """
    target_path = Path(checkpoint_dir) / fname
    # Clear out (or refuse to clobber) any existing file before writing.
    if not save_overwrite:
        if target_path.is_file():
            raise FileExistsError(target_path)
    else:
        target_path.unlink(missing_ok=True)
    if synchronize:
        barrier()
    target_path.parent.mkdir(exist_ok=True, parents=True)
    if synchronize:
        barrier()
    torch.save(state_dict, target_path)
    # Optionally mirror the file to remote storage.
    if upload_to is not None:
        upload_target = f"{upload_to.rstrip('/')}/{fname}"
        log.info(f"Uploading {target_path} to {upload_target}...")
        upload(target_path, upload_target, save_overwrite=save_overwrite)
| |
|
| |
|
def load_state_dict(
    checkpoint_dir: PathOrStr,
    fname: str,
    *,
    local_cache: Optional[PathOrStr] = None,
    map_location: Optional[str] = None,
):
    """
    Load a regular state dict from the file ``fname`` within ``checkpoint_dir`` using :func:`torch.load()`.
    This can be used during distributed training or not.

    :param checkpoint_dir: A local or remote checkpoint directory.
    :param fname: The target file within the ``checkpoint_dir``. This should be a path relative to the ``checkpoint_dir``.
    :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
        remote "directory" but there might be a cached version of the same artifacts.

    :raises FileNotFoundError: If ``fname`` doesn't exist in the ``checkpoint_dir`` or the local cache.
    """
    base = str(checkpoint_dir).rstrip("/")
    # Prefer a safetensors version of a ``.pt`` file when one is available.
    if fname.endswith(".pt"):
        try:
            safetensors_path = resource_path(base, fname[:-2] + "safetensors", local_cache=local_cache)
            return safetensors_file_to_state_dict(safetensors_path, map_location=map_location)
        except FileNotFoundError:
            # No safetensors variant; fall through to the regular torch file.
            pass

    return torch.load(resource_path(base, fname, local_cache=local_cache), map_location=map_location)
| |
|
| |
|
def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
    """
    Load model state from a distributed FSDP model checkpoint created from :func:`save_fsdp_model_and_optim_state()`.
    Note that ``model`` should not be wrapped with FSDP.
    """
    state_dict = {"model": model.state_dict()}
    # ``no_dist=True`` lets a single (non-distributed) process read the sharded checkpoint.
    reader = RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}")
    dist_cp.load_state_dict(state_dict, reader, no_dist=True)
    model.load_state_dict(state_dict["model"])
| |
|
| |
|
class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
    """
    A :class:`~torch.distributed.checkpoint.FileSystemWriter` subclass that additionally
    uploads every file it writes to a cloud bucket when ``upload_to`` is given.
    """

    def __init__(
        self,
        path: PathOrStr,
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: Optional[int] = None,
        per_thread_copy_ahead: int = 10_000_000,
        upload_to: Optional[str] = None,
        save_overwrite: bool = False,
    ) -> None:
        if thread_count is not None and thread_count <= 0:
            raise ValueError("thread count must be at least 1")
        super().__init__(
            path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            # Fall back to a single writer thread when none was requested.
            thread_count=thread_count or 1,
            per_thread_copy_ahead=per_thread_copy_ahead,
        )
        # Normalize the remote prefix so we can join paths with a plain "/".
        self.upload_to = None if upload_to is None else upload_to.rstrip("/")
        self.save_overwrite = save_overwrite

    def write_data(
        self,
        plan: dist_cp.SavePlan,
        planner: dist_cp.SavePlanner,
    ) -> Future[List[WriteResult]]:
        fut = super().write_data(plan, planner)
        if self.upload_to is not None:
            files_to_upload = {write_result.storage_data.relative_path for write_result in fut.wait()}

            # NOTE(review): the client is created up-front, presumably so each worker
            # thread reuses one cached client instead of racing to construct it — confirm.
            if self.upload_to.startswith("s3://"):
                _get_s3_client("s3")
            elif self.upload_to.startswith("r2://"):
                _get_s3_client("r2")

            with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
                upload_futures = []
                for fname in files_to_upload:
                    source = self.path / fname
                    target = f"{self.upload_to}/{fname}"
                    log.info(f"Uploading {source} to {target}...")
                    upload_futures.append(
                        executor.submit(upload, source, target, save_overwrite=self.save_overwrite)
                    )
                for upload_future in as_completed(upload_futures):
                    try:
                        upload_future.result()
                    except BaseException:
                        # Wrap whatever went wrong in our own error type, carrying the
                        # full formatted traceback of the original exception.
                        raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
        return fut

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        super().finish(metadata, results)
        # The base class writes the ``.metadata`` file last; mirror it remotely too.
        if self.upload_to is not None:
            source = self.path / ".metadata"
            target = f"{self.upload_to}/.metadata"
            log.info(f"Uploading {source} to {target}...")
            upload(source, target, save_overwrite=self.save_overwrite)
| |
|
| |
|
class RemoteFileSystemReader(dist_cp.StorageReader):
    """
    A :class:`~torch.distributed.checkpoint.StorageReader` based on :class:`~torch.distributed.checkpoint.FileSystemReader`
    that can read data directly from cloud storage as well as a local directory.

    :param path: A local or remote checkpoint "directory".
    :param local_cache: Optional local directory that may contain cached copies of the
        checkpoint files; cached files are preferred over remote reads.
    :param thread_count: Number of worker threads for reads. Must be >= 1 if given.
    """

    def __init__(
        self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
    ):
        super().__init__()
        if thread_count is not None and thread_count <= 0:
            raise ValueError("thread count must be at least 1")
        # Always a plain string without a trailing slash (may be an s3://, r2://, or local path).
        self.path = str(path).rstrip("/")
        self.cache = None if local_cache is None else Path(local_cache)
        self.thread_count = thread_count or default_thread_count()
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
        self._metadata: Optional[Metadata] = None

    def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
        # Prefer a cached local copy of the file when it exists; otherwise read
        # (possibly remotely) relative to ``self.path``.
        if self.cache is not None and (path := self.cache / relative_path).is_file():
            return get_bytes_range(path, offset, length)
        else:
            return get_bytes_range(f"{self.path}/{relative_path}", offset, length)

    def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
        """Fetch the raw bytes backing a single ``ReadItem``."""
        sinfo = self.storage_data[read_item.storage_index]
        content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
        return (read_item, content)

    def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlan = None) -> Future[None]:  # type: ignore[assignment]
        raise NotImplementedError  # placeholder overwritten below

    def read_metadata(self) -> Metadata:
        # Lazily load and cache the checkpoint's ``.metadata`` file.
        if self._metadata is None:
            with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
                self._metadata = pickle.load(metadata_file)
        return self._metadata

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        del is_coordinator  # unused; all ranks behave identically here
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
        return plan

    def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
        return global_plan

    def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
        """
        Read all items in ``plan`` (concurrently) and hand them to ``planner``.

        Fix vs. original: the ``BytesIO`` buffer was previously bound to a local named
        ``bytes``, shadowing the builtin; it is now ``content_buf``. The redundant
        ``isinstance(self.path, str)`` guard was removed — ``__init__`` always coerces
        ``self.path`` to ``str``.
        """
        # NOTE(review): the S3/R2 client is created up-front, presumably so worker
        # threads share one cached client rather than racing to construct it — confirm.
        if self.path.startswith("s3://"):
            _get_s3_client("s3")
        elif self.path.startswith("r2://"):
            _get_s3_client("r2")

        with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
            read_item_content_futures = [
                executor.submit(self._get_content_for_read, read_item) for read_item in plan.items
            ]
            read_item_content_results = []
            for f in as_completed(read_item_content_futures):
                try:
                    read_item_content_results.append(f.result())
                except BaseException:
                    # Wrap in our own error type, carrying the full formatted traceback
                    # of the original exception.
                    raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")

        # Deserialize each fetched blob and commit it through the planner.
        for read_item, content in read_item_content_results:
            content_buf = io.BytesIO(content)
            content_buf.seek(0)
            if read_item.type == LoadItemType.BYTE_IO:
                planner.load_bytes(read_item, content_buf)
            else:
                tensor = cast(torch.Tensor, torch.load(content_buf, map_location="cpu"))
                tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
                target_tensor = planner.resolve_tensor(read_item).detach()

                assert (
                    target_tensor.size() == tensor.size()
                ), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                target_tensor.copy_(tensor)
                planner.commit_tensor(read_item, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut
| |
|
| |
|
class Checkpointer(metaclass=ABCMeta):
    """
    Abstract base class for checkpointers that save and restore model, optimizer,
    and trainer state, locally and/or to remote storage.
    """

    def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None):
        self.cfg = cfg
        # Number of I/O worker threads; falls back to a library-chosen default.
        self.thread_count = thread_count or default_thread_count()

    @abstractmethod
    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        train_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """
        Save model, optimizer, and trainer state to ``dir``, optionally uploading to ``upload_to``.
        """
        raise NotImplementedError

    @abstractmethod
    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
        """
        raise NotImplementedError

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Unshard a checkpoint.

        Note this is not marked abstract because child classes are not required to implemented this.
        """
        del load_path, local_cache, load_optimizer_state, load_trainer_state, device
        raise NotImplementedError

    @contextmanager
    def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
        """
        Yield a temporary working directory (``<dir>-tmp``) to write a checkpoint into.
        On clean exit, the temp directory is renamed into place at ``dir`` by the first
        local rank, so a partially-written checkpoint never appears at the final path.
        """
        # Make sure the final checkpoint directory doesn't already exist, unless it's
        # okay to overwrite it (in which case the first local rank removes it).
        checkpoint_dir = Path(dir)
        if not dir_is_empty(checkpoint_dir):
            if self.cfg.save_overwrite:
                if get_fs_local_rank() == 0:
                    shutil.rmtree(checkpoint_dir, ignore_errors=True)
            else:
                raise FileExistsError(checkpoint_dir)
        # Wait for the cleanup above before anyone starts writing.
        barrier()

        # Prepare a fresh temporary working directory next to the final one.
        checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
        if get_fs_local_rank() == 0:
            shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
            checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)

        barrier()

        # Let the caller write the checkpoint into the temp directory.
        yield checkpoint_dir_tmp

        # Make sure all ranks are done writing before the rename below.
        barrier()

        if get_fs_local_rank() == 0:
            try:
                checkpoint_dir_tmp.replace(checkpoint_dir)
            except FileNotFoundError:
                # NOTE(review): presumably another fs-local rank-0 process on a shared
                # filesystem already performed the rename; only re-raise if the final
                # directory really isn't there.
                if not checkpoint_dir.exists():
                    raise

        # On shared filesystems the rename may take a moment to become visible to
        # other ranks, so poll for it before declaring success.
        wait_for(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)

        barrier()

    def _save_config(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
        """Save (and optionally upload) the run config; only the global rank 0 writes it."""
        if get_global_rank() == 0:
            log.info("Saving config...")
            self.cfg.save(config_path := Path(dir) / "config.yaml")
            if upload_to is not None:
                upload_target = f"{upload_to}/config.yaml"
                log.info(f"Uploading {config_path} to {upload_target}")
                upload(config_path, upload_target, save_overwrite=self.cfg.save_overwrite)
| |
|
| |
|
class FullCheckpointer(Checkpointer):
    """
    A :class:`Checkpointer` that saves a single full model and optimizer state dictionary.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        with self._temporary_wd(dir) as checkpoint_dir:
            with FSDP.state_dict_type(
                fsdp_model,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
                optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
            ):
                # With ``rank0_only=True`` every rank must still participate in
                # ``state_dict()``, but only global rank 0 receives the full,
                # CPU-offloaded state dict and writes it out.
                model_state_dict = fsdp_model.state_dict()
                if get_global_rank() == 0:
                    log.info("Saving model state...")
                    save_state_dict(
                        checkpoint_dir,
                        "model.pt",
                        model_state_dict,
                        upload_to=upload_to,
                        save_overwrite=self.cfg.save_overwrite,
                        synchronize=False,
                    )
                # Free the full state dict promptly to limit peak host memory.
                del model_state_dict
                barrier()

                # Same pattern for the optimizer: gather the full state on rank 0, save, free.
                optim_state_dict = FSDP.optim_state_dict(fsdp_model, optim)
                if get_global_rank() == 0:
                    log.info("Saving optim state...")
                    save_state_dict(
                        checkpoint_dir,
                        "optim.pt",
                        optim_state_dict,
                        upload_to=upload_to,
                        save_overwrite=self.cfg.save_overwrite,
                        synchronize=False,
                    )
                del optim_state_dict
                barrier()

            # Save trainer state (rank 0 only).
            if get_global_rank() == 0:
                log.info("Saving trainer state...")
                save_state_dict(
                    checkpoint_dir,
                    "train.pt",
                    trainer_state,
                    upload_to=upload_to,
                    save_overwrite=self.cfg.save_overwrite,
                    synchronize=False,
                )
            # Save config.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        with FSDP.state_dict_type(
            fsdp_model,
            state_dict_type=StateDictType.FULL_STATE_DICT,
            state_dict_config=FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
            optim_state_dict_config=FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
        ):
            with torch.no_grad():
                # Fill every FSDP flat parameter with NaN so we can verify afterwards
                # that every shard was actually restored from the checkpoint.
                for module_name, module in fsdp_model.named_modules():
                    if not isinstance(module, FSDP):
                        continue
                    for param in module.params:
                        param.fill_(torch.nan)

                # Load the full (unsharded) model state dict on CPU and remap any
                # legacy key names to the current naming scheme.
                state_dict_to_load = load_state_dict(
                    load_path, "model.pt", local_cache=local_cache, map_location="cpu"
                )
                (
                    state_dict_to_load,
                    og_keys_to_new,
                ) = fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(state_dict_to_load)

                # Copy each rank's slice of every parameter out of the full state dict
                # directly into the corresponding region of the FSDP flat parameter.
                for module_name, module in fsdp_model.named_modules():
                    if not isinstance(module, FSDP):
                        continue
                    for param in module.params:
                        assert param._is_flat_param
                        for fqn, spi in zip(param._fqns, param._shard_param_infos):
                            if not spi.in_shard:
                                continue
                            # Build the state-dict key for this original parameter.
                            key = f"{module_name}.{fqn}"
                            key = key.replace("_fsdp_wrapped_module.", "")
                            key = key.lstrip(".")
                            t = state_dict_to_load[key]
                            t = t.flatten()
                            param[spi.offset_in_shard : spi.offset_in_shard + spi.numel_in_shard].copy_(
                                t[spi.intra_param_start_idx : spi.intra_param_end_idx + 1]
                            )

                # Make sure every parameter was restored (no NaNs left from the fill above).
                for module_name, module in fsdp_model.named_modules():
                    if not isinstance(module, FSDP):
                        continue
                    for param in module.params:
                        if torch.isnan(param).any():
                            raise ValueError(
                                f"Module '{module_name}' contains NaNs, this is likely a bug restoring from full checkpoints"
                            )

            # Optionally load the optimizer state, remapping its keys the same way.
            if load_optimizer_state:
                optim_state_dict_to_load = load_state_dict(
                    load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
                )
                optim_state_dict_to_load = self._make_optim_state_dict_compatible(
                    optim_state_dict_to_load,
                    og_keys_to_new,
                )
                load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)
                del optim_state_dict_to_load

            # Load trainer state.
            try:
                trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
            except FileNotFoundError:
                # Fall back to the legacy file name used by older checkpoints.
                trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
            barrier()
            return trainer_state

    def _make_optim_state_dict_compatible(
        self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
    ) -> Dict[str, Any]:
        """
        Normalize an optimizer state dict so it matches the current model's parameter names.

        Handles two on-disk formats: one where ``state``/``params`` are keyed by integer
        parameter IDs, and one where they are keyed by fully qualified parameter names.
        In both cases the ``_fsdp_wrapped_module.`` prefix is stripped, and original keys
        are fanned out to their new names per ``og_keys_to_new``.
        """
        if isinstance(optim_state_dict["param_groups"][0]["params"][0], int):
            # Integer-keyed format: rewrite IDs to FQNs using the saved ``param_names``.
            id_to_fqn: Dict[int, str] = {}
            for group in optim_state_dict["param_groups"]:
                new_param_names = []
                for fqn, id in zip(group["param_names"], group["params"]):
                    fqn = fqn.replace("_fsdp_wrapped_module.", "")
                    id_to_fqn[id] = fqn
                    new_param_names.append(fqn)
                group["param_names"] = new_param_names
                group["params"] = new_param_names
            for id in list(optim_state_dict["state"].keys()):
                optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
        else:
            # FQN-keyed format: just strip the FSDP wrapper prefix everywhere.
            for group in optim_state_dict["param_groups"]:
                group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
                group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
                assert group["param_names"] == group["params"]
            for key in list(optim_state_dict["state"].keys()):
                optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
                    "state"
                ].pop(key)

        # Fan each original key's state out to its new key(s). The last new key reuses
        # the original state object; earlier ones get deep copies.
        for og_key, new_keys in og_keys_to_new.items():
            og_state = optim_state_dict["state"].pop(og_key, None)
            if og_state is None:
                continue
            for i, new_key in enumerate(new_keys):
                if i == len(new_keys) - 1:
                    optim_state_dict["state"][new_key] = og_state
                else:
                    optim_state_dict["state"][new_key] = deepcopy(og_state)
        # Rewrite each param group's name lists to the new keys as well.
        for group in optim_state_dict["param_groups"]:
            og_names = group["params"]
            new_names = []
            for og_key in og_names:
                for new_key in og_keys_to_new[og_key]:
                    new_names.append(new_key)
            group["params"] = new_names
            group["param_names"] = new_names

        return optim_state_dict

    def load_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
        """
        Load the raw (unsharded) model and, optionally, optimizer state dicts from a
        full checkpoint without touching any live model/optimizer objects.
        """
        device = device if device is not None else torch.device("cpu")
        model_state = load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location=device)
        optim_state = None
        if load_optimizer_state:
            optim_state = load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location=device)
        return model_state, optim_state
| |
|
| |
|
class TorchNewStyleShardedCheckpointer(Checkpointer):
    """
    A sharded :class:`Checkpointer` that uses PyTorch's new distributed checkpointing functionality.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        with self._temporary_wd(dir) as wd:
            # Model and optimizer shards go through torch.distributed.checkpoint.
            save_fsdp_model_and_optim_state(
                wd,
                fsdp_model,
                optim,
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Each rank persists its own trainer state file.
            log.info("Saving trainer state...")
            save_state_dict(
                wd,
                f"train/rank{get_global_rank()}.pt",
                trainer_state,
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Finally, record the config used for this run.
            self._save_config(wd, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        # Restore the sharded model/optimizer state in place.
        log.info("Loading model and optimizer state...")
        load_fsdp_model_and_optim_state(
            load_path,
            fsdp_model,
            optim,
            local_cache=local_cache,
            load_optimizer_state=load_optimizer_state,
        )

        # Restore this rank's trainer state, falling back to rank 0's copy when this
        # rank has no file of its own.
        log.info("Loading trainer state...")
        try:
            trainer_state = load_state_dict(
                load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
            )
        except FileNotFoundError:
            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
        barrier()
        return trainer_state
| |
|
| |
|
| | class TorchLegacyShardedCheckpointer(Checkpointer): |
| | """ |
| | A sharded :class:`Checkpointer` that just uses `torch.save()` with extra logic for handling FSDP model |
| | and optim state. |
| | |
| | The world size must be kept consistent when using this checkpointer. |
| | """ |
| |
|
    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """
        Save model, optimizer, and trainer state with a single ``torch.save()`` per rank.

        Each rank writes its own shard to ``rank{global_rank}.pt``, so restoring requires
        the same world size the checkpoint was written with.
        """
        with self._temporary_wd(dir) as checkpoint_dir:
            with FSDP.state_dict_type(
                fsdp_model,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
                optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
            ):
                # Model shard, optimizer shard, and trainer state all go into one file per rank.
                state_dict = {
                    "model": fsdp_model.state_dict(),
                    "optim": FSDP.optim_state_dict(fsdp_model, optim),
                    **trainer_state,
                }
                save_state_dict(
                    checkpoint_dir,
                    f"rank{get_global_rank()}.pt",
                    state_dict,
                    upload_to=upload_to,
                    save_overwrite=self.cfg.save_overwrite,
                )

            # Save config.
            self._save_config(checkpoint_dir, upload_to=upload_to)
| |
|
    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Load this rank's ``rank{global_rank}.pt`` shard and restore model, optimizer,
        and trainer state from it. Returns the remaining (trainer) state.
        """
        with FSDP.state_dict_type(
            fsdp_model,
            state_dict_type=StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
            optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
        ):
            # Each rank reads only its own shard file.
            state_dict = load_state_dict(
                load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
            )

            # Restore the model shard, dropping it from the dict as soon as it's loaded
            # to keep memory down; same for the optimizer shard below.
            log.info("Loading model state...")
            fsdp_model.load_state_dict(state_dict["model"])
            del state_dict["model"]
            if load_optimizer_state:
                log.info("Loading optimizer state...")
                load_fsdp_optim_state(fsdp_model, optim, state_dict["optim"])
            del state_dict["optim"]

            barrier()
            # Whatever remains in the dict is the trainer state.
            return state_dict
| |
|
| | def unshard_checkpoint( |
| | self, |
| | load_path: PathOrStr, |
| | *, |
| | local_cache: Optional[PathOrStr] = None, |
| | load_optimizer_state: bool = True, |
| | load_trainer_state: bool = True, |
| | device: Optional[torch.device] = None, |
| | ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]: |
| | assert local_cache is None, "this method currently only supports local files" |
| | full_state_dict = self._unshard(load_path, device or torch.device("cpu"), skip_keys={"rng"}) |
| | model_state = full_state_dict.pop("model") |
| | optim_state = full_state_dict.pop("optim") |
| | return ( |
| | model_state, |
| | optim_state if load_optimizer_state else None, |
| | full_state_dict if load_trainer_state else None, |
| | ) |
| |
|
| | def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Tuple): |
| | key = tuple() if key is None else key |
| | if isinstance(state, (list, tuple, set)): |
| | for i, sub_state in enumerate(state): |
| | self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,)) |
| | elif isinstance(state, dict): |
| | for name in state.keys(): |
| | self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,)) |
| | elif isinstance(state, ShardedTensor): |
| | self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key) |
| | return |
| | else: |
| | return |
| |
|
| | def _get_shard_placement_and_rank_sizes( |
| | self, shards_metadata: List[ShardMetadata], world_size: int |
| | ) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]: |
| | def shard_size(shard_md): |
| | return reduce((lambda x, y: x * y), shard_md.shard_sizes) |
| |
|
| | rank_sizes = [0 for _ in range(world_size)] |
| | shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {} |
| | for shard_md in shards_metadata: |
| | shard_rank = cast(_remote_device, shard_md.placement).rank() |
| | assert shard_rank is not None |
| | if shard_rank >= world_size: |
| | raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}") |
| |
|
| | shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank]) |
| | rank_sizes[shard_rank] += shard_size(shard_md) |
| |
|
| | return shard_placement, rank_sizes |
| |
|
| | def _copy_sharded_tensor_to_shared_mem( |
| | self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple |
| | ) -> Any: |
| | shard0_md = sharded_tensor.metadata() |
| | shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes( |
| | shard0_md.shards_metadata, world_size |
| | ) |
| |
|
| | rank_size = rank_sizes[rank] |
| | assert rank_size >= 0 |
| | if rank_size == 0: |
| | return |
| |
|
| | assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32" |
| | numpy_type = np.float32 |
| |
|
| | sharded_memory_name = "-".join(key + (str(rank),)) |
| |
|
| | shm = shared_memory.SharedMemory( |
| | create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=sharded_memory_name |
| | ) |
| | np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf) |
| |
|
| | for local_shard in sharded_tensor.local_shards(): |
| | shard_rank = cast(_remote_device, local_shard.metadata.placement).rank() |
| | assert shard_rank == rank |
| |
|
| | src = local_shard.tensor.flatten() |
| | shard_offset = shard_placement[local_shard.metadata][1] |
| |
|
| | np_arr[shard_offset : shard_offset + src.numel()] = src.numpy() |
| |
|
| | shm.close() |
| |
|
| | def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path): |
| | shard_number = int(shard_filepath.name[4:-3]) |
| | log.info("Starting unsharding shard number %d to shared memory", shard_number) |
| |
|
| | with self._patch_sharded_tensor_load(): |
| | shard = torch.load(shard_filepath, map_location="cpu") |
| | log.debug("Done loading shard number %d", shard_number) |
| |
|
| | self._copy_sharded_tensors_to_shared_mem( |
| | shard, world_size, shard_number, (str(shard_filepath.parent).replace("/", "_"),) |
| | ) |
| | log.info("Done unsharding shard number %d to shared memory", shard_number) |
| |
|
| | def _unshard_using_sharded_mem( |
| | self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr |
| | ) -> Any: |
| | return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),)) |
| |
|
| | def _unshard_state_using_shared_mem( |
| | self, state: Any, world_size: int, device: torch.device, key: Tuple |
| | ) -> Any: |
| | if isinstance(state, (list, tuple, set)): |
| | return state.__class__( |
| | self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,)) |
| | for i, sub_state in enumerate(state) |
| | ) |
| | elif isinstance(state, dict): |
| | return { |
| | name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,)) |
| | for name in state.keys() |
| | } |
| | elif isinstance(state, ShardedTensor): |
| | return self._unshard_tensor_using_shared_mem(state, world_size, device, key) |
| | elif isinstance(state, torch.Tensor): |
| | return state.to(device=device) |
| | else: |
| | return state |
| |
|
| | def _unshard_tensor_using_shared_mem( |
| | self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple |
| | ) -> torch.Tensor: |
| | shard0_md = sharded_tensor.metadata() |
| |
|
| | def shard_size(shard_md): |
| | return reduce((lambda x, y: x * y), shard_md.shard_sizes) |
| |
|
| | shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes( |
| | shard0_md.shards_metadata, world_size |
| | ) |
| |
|
| | assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32" |
| | numpy_type = np.float32 |
| |
|
| | out = torch.empty( |
| | *sharded_tensor.metadata().size, dtype=sharded_tensor.metadata().tensor_properties.dtype, device=device |
| | ) |
| | dims = len(sharded_tensor.metadata().size) |
| | for shard_md, (rank, rank_offset) in shard_placement.items(): |
| | if rank >= world_size: |
| | raise RuntimeError(f"Shard rank {rank} exceeds world size {world_size}") |
| |
|
| | sharded_memory_name = "-".join(key + (str(rank),)) |
| | shm = shared_memory.SharedMemory(name=sharded_memory_name) |
| |
|
| | rank_size = rank_sizes[rank] |
| | assert rank_size >= 0 |
| | if rank_size == 0: |
| | continue |
| |
|
| | np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf) |
| |
|
| | tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_size(shard_md)] |
| | tensor = tensor.view(shard_md.shard_sizes) |
| |
|
| | out_narrow_view = out |
| | for dim in range(dims): |
| | out_narrow_view = out_narrow_view.narrow( |
| | dim, |
| | shard_md.shard_offsets[dim], |
| | shard_md.shard_sizes[dim], |
| | ) |
| |
|
| | out_narrow_view.copy_(tensor) |
| |
|
| | shm.close() |
| | shm.unlink() |
| |
|
| | return out |
| |
|
| | @contextmanager |
| | def _patch_sharded_tensor_load(self): |
| | """ |
| | Monkeypatch for torch's ShardedTensor, so we can unpickle without having torch.distributed set up. |
| | """ |
| |
|
| | def _rebuild_from_type_v2_monkey(func, new_type, args, state): |
| | ret = func(*args) |
| | if type(ret) is not new_type: |
| | ret = ret.as_subclass(new_type) |
| |
|
| | |
| | |
| | if isinstance(ret, ShardedTensor): |
| | ret._local_shards, ret._metadata, _, ret._sharding_spec, ret._init_rrefs = state |
| | return ret |
| |
|
| | |
| | |
| | |
| | |
| | if getattr(ret.__class__, "__setstate__", torch.Tensor.__setstate__) is not torch.Tensor.__setstate__: |
| | ret.__setstate__(state) |
| | else: |
| | ret = torch._utils._set_obj_state(ret, state) |
| | return ret |
| |
|
| | original_rebuild_from_type_v2 = torch._tensor._rebuild_from_type_v2 |
| | try: |
| | torch._tensor._rebuild_from_type_v2 = _rebuild_from_type_v2_monkey |
| | yield |
| | finally: |
| | torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2 |
| |
|
| | def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None): |
| | """ |
| | The current unsharding implementation consists of: |
| | |
| | 1. Loading each shard on a separate process and copying their sharded tensors to shared memory. |
| | 2. Loading 1 shard on the main process as a base unsharded object. |
| | 3. Using the sharded tensors in shared memory to populate the base unsharded object. |
| | |
| | This implementation replaced a prior implementation that instead loaded |
| | all shards using threads, because that implementation turned out to |
| | be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024. |
| | The current implementation is slower than the old one in many scenarios, |
| | but is significantly faster in the above mentioned case (e.g. 30 minutes) |
| | if there are enough CPUs. |
| | """ |
| |
|
| | input_dir = Path(input_dir) |
| | skip_keys = skip_keys or set() |
| |
|
| | shard_filepaths = list(input_dir.glob("rank*.pt")) |
| | world_size = len(shard_filepaths) |
| | if world_size == 0: |
| | raise RuntimeError("No shards found for unsharding") |
| |
|
| | log.info("Number of shards: %d", world_size) |
| | shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024) |
| | min_ram_required_estimate_gb = shard_size_gb * world_size |
| | log.info( |
| | "Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb |
| | ) |
| |
|
| | log.info("Copying sharded tensors to shared memory using multiple processes") |
| | |
| | |
| | |
| | executor = ProcessPoolExecutor( |
| | mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment |
| | ) |
| | futures = [] |
| | for shard_filepath in shard_filepaths: |
| | shard_rank = int(shard_filepath.name[4:-3]) |
| |
|
| | if shard_rank >= world_size: |
| | raise RuntimeError( |
| | f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}" |
| | ) |
| |
|
| | futures.append(executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath)) |
| |
|
| | for f in as_completed(futures): |
| | f.result() |
| | executor.shutdown() |
| |
|
| | log.info("Loading a shard on the main process to be unsharded state") |
| | with self._patch_sharded_tensor_load(): |
| | state = torch.load(shard_filepaths[0], map_location="cpu") |
| |
|
| | for key in skip_keys: |
| | if key in state: |
| | del state[key] |
| |
|
| | log.info("Unsharding from %d shards ...", world_size) |
| | return self._unshard_using_sharded_mem(state, world_size, device, input_dir) |
| |
|
| |
|
@dataclass
class _LocalShardedCheckpointerMetadata(BaseConfig):
    """Metadata stored with a :class:`LocalShardedCheckpointer` checkpoint (``metadata.yaml``)."""

    # World size at save time; `restore_checkpoint` asserts it matches the current world size.
    world_size: int = field(default_factory=get_world_size)
| |
|
| |
|
| | @dataclass |
| | class _FlatParamShard: |
| | full_shape: torch.Size |
| | shard_offsets: Tuple[int, int] |
| | shard_data: Optional[torch.Tensor] |
| |
|
| | def copy_into(self, full_tensor: torch.Tensor) -> None: |
| | assert self.shard_data is not None |
| | full_tensor_shard_view = full_tensor.view(-1)[self.shard_offsets[0] : self.shard_offsets[1] + 1] |
| | assert self.shard_data.shape == full_tensor_shard_view.shape |
| | full_tensor_shard_view.copy_(self.shard_data) |
| |
|
| |
|
class LocalShardedCheckpointer(Checkpointer):
    """
    A sharded :class:`Checkpointer` that directly saves the local FSDP flat params data.
    The optimizer state is saved directly with `torch.save()` without reformatting via FSDP methods.

    The world size must be kept consistent when using this checkpointer. However, you can easily
    reconstruct a full unsharded model and/or optimizer state dictionary from a single Python process
    using :meth:`unshard_checkpoint()` (no distributed initialization required).
    """

    # Metadata attributes of FSDP's `FlatParameter` that are persisted with each flat param,
    # used to validate sharding on load and to place shards when unsharding. Not all of these
    # exist on every torch version, hence the `hasattr` checks below.
    _FLAT_PARAM_METADATA_TO_SAVE = (
        "_fqns",
        "_shard_param_offsets",
        "_shard_indices",
        "_numels",
        "_numels_with_padding",
        "_shapes",
        "_shard_numel_padded",
        "_shard_param_infos",
    )

    def _fsdp_modules(self, fsdp_model: FSDP) -> List[Tuple[str, FSDP]]:
        """
        Returns a list of FSDP modules with their FQN.
        """
        modules = []
        for name, module in fsdp_model.named_modules():
            if isinstance(module, FSDP):
                modules.append((name, module))
        return modules

    def _prepare_fsdp_model(self, fsdp_model: FSDP) -> None:
        """Force FSDP lazy initialization so flat-param handles exist before they're accessed."""
        from torch.distributed.fsdp._runtime_utils import _lazy_init

        # Let any in-flight CUDA work finish before reading/writing flat param data.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        _lazy_init(fsdp_model, fsdp_model)

    def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
        """Return the FSDP module's flat-param handle(s), accounting for torch version differences."""
        if version.parse(torch.__version__) < version.parse("2.1.0"):
            # torch < 2.1 keeps a list of handles on the module.
            return fsdp_model._handles
        elif version.parse(torch.__version__) < version.parse("2.3.0"):
            # torch 2.1–2.2 keep at most a single handle per FSDP module.
            if hasattr(fsdp_model, "_handle") and fsdp_model._handle is not None:
                return [fsdp_model._handle]
            else:
                return []
        else:
            # FSDP internals need to be verified for torch >= 2.3 before supporting it here.
            raise NotImplementedError

    @torch.no_grad()
    def _get_flat_param_state_to_save(self, fsdp_model: FSDP) -> Dict[str, Any]:
        """Collect every FSDP module's flat param data and metadata into a picklable dict."""
        self._prepare_fsdp_model(fsdp_model)
        module_data = []
        for module_fqn, fsdp_module in self._fsdp_modules(fsdp_model):
            handle_data = []
            for handle in self._fsdp_handles(fsdp_module):
                data: Dict[str, Any] = {}
                # `flat_param` is FSDP's `FlatParameter` holding this module's parameters
                # flattened into one (sharded) tensor.
                flat_param = handle.flat_param
                data["flat_param.data"] = flat_param.detach()
                for key in self._FLAT_PARAM_METADATA_TO_SAVE:
                    # Only save metadata attributes that exist on this torch version.
                    if hasattr(flat_param, key):
                        data[f"flat_param.{key}"] = getattr(flat_param, key)
                handle_data.append(data)
            module_data.append({"handles": handle_data, "name": module_fqn})
        return {"modules": module_data}

    @torch.no_grad()
    def _load_flat_param_state(self, fsdp_model: FSDP, model_state: Dict[str, Any]):
        """Load the state produced from `self._get_flat_param_state_to_save()`."""
        self._prepare_fsdp_model(fsdp_model)
        fsdp_modules = self._fsdp_modules(fsdp_model)
        assert len(model_state["modules"]) == len(fsdp_modules)
        for (_, fsdp_module), module_data in zip(fsdp_modules, model_state["modules"]):
            handles = self._fsdp_handles(fsdp_module)
            assert len(handles) == len(module_data["handles"])
            for handle, data in zip(handles, module_data["handles"]):
                flat_param = handle.flat_param
                # Validate that the saved flat-param metadata matches the live model's,
                # i.e. the sharding hasn't changed since the checkpoint was written.
                for key in self._FLAT_PARAM_METADATA_TO_SAVE:
                    if hasattr(flat_param, key):
                        assert getattr(flat_param, key) == data[f"flat_param.{key}"]
                # Copy the saved shard data into the live flat param in place.
                flat_param.copy_(data["flat_param.data"])

    def _save_metadata(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
        """Write ``metadata.yaml`` from each node's local rank 0; global rank 0 uploads it."""
        if get_fs_local_rank() == 0:
            log.info("Saving metadata...")
            metadata = _LocalShardedCheckpointerMetadata()
            metadata.save(metadata_path := Path(dir) / "metadata.yaml")
            # Only global rank 0 uploads, so the remote file is written exactly once.
            if upload_to is not None and get_global_rank() == 0:
                upload_target = f"{upload_to}/metadata.yaml"
                log.info(f"Uploading {metadata_path} to {upload_target}")
                upload(metadata_path, upload_target, save_overwrite=self.cfg.save_overwrite)

    def _load_metadata(
        self, load_path: PathOrStr, *, local_cache: Optional[PathOrStr] = None
    ) -> _LocalShardedCheckpointerMetadata:
        """Load the checkpoint's ``metadata.yaml``."""
        metadata_path = resource_path(load_path, "metadata.yaml", local_cache=local_cache)
        return _LocalShardedCheckpointerMetadata.load(metadata_path)

    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """
        Save per-rank model (flat param), optimizer, and trainer state files, plus shared
        metadata and config, optionally uploading everything to ``upload_to``.
        """
        with self._temporary_wd(dir) as checkpoint_dir:
            # Gather local FSDP flat params data to save. Flat-param metadata (FQNs, shapes,
            # shard offsets) is saved alongside the data so sharding can be validated when
            # one of these checkpoints is loaded.
            log.info("Saving local FSDP flat params data...")
            save_state_dict(
                checkpoint_dir,
                f"model/rank{get_global_rank()}.pt",
                self._get_flat_param_state_to_save(fsdp_model),
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Save optimizer state as-is, with no FSDP reformatting.
            log.info("Saving local optimizer state...")
            save_state_dict(
                checkpoint_dir,
                f"optim/rank{get_global_rank()}.pt",
                optim.state_dict(),
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Save trainer state.
            log.info("Saving trainer state...")
            save_state_dict(
                checkpoint_dir,
                f"train/rank{get_global_rank()}.pt",
                trainer_state,
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Save metadata (the world size) so loads can validate compatibility.
            self._save_metadata(checkpoint_dir, upload_to=upload_to)

            # Save the config alongside the checkpoint.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Restore this rank's model (and optionally optimizer) state and return the trainer state.
        The checkpoint must have been written with the same world size.
        """
        # Flat-param shards are world-size specific, so the saved world size must match ours.
        metadata = self._load_metadata(load_path, local_cache=local_cache)
        assert metadata.world_size == get_world_size()

        # Load model state directly into the live FSDP flat params.
        log.info("Loading local FSDP flat params data...")
        model_state = load_state_dict(
            load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
        )
        self._load_flat_param_state(fsdp_model, model_state)
        del model_state

        # Load optimizer state.
        if load_optimizer_state:
            log.info("Loading local optimizer state...")
            optim_state = load_state_dict(
                load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
            )
            # Strip scalar 'grad_norm_exp_avg' entries and drop params whose state becomes
            # empty as a result, before handing the dict to `Optimizer.load_state_dict`.
            # NOTE(review): presumably these extra scalar entries make torch expect the
            # remaining state tensors for otherwise-stateless params — confirm against the
            # trainer's adaptive-clipping code.
            for param_id in list(optim_state["state"].keys()):
                state = optim_state["state"][param_id]
                if "grad_norm_exp_avg" in state:
                    del state["grad_norm_exp_avg"]
                if len(state) == 0:
                    del optim_state["state"][param_id]
            optim.load_state_dict(optim_state)
            del optim_state

        # Load trainer state.
        log.info("Loading local trainer state...")
        trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
        barrier()
        return trainer_state

    def _iter_flat_param_shards(
        self, model_state: Dict[str, Any]
    ) -> Generator[Tuple[str, _FlatParamShard], None, None]:
        """Yield ``(root_fqn, shard)`` for each original parameter's slice of the saved flat params."""
        for module_data in model_state["modules"]:
            module_prefix = module_data["name"].replace("_fsdp_wrapped_module.", "")
            for handle in module_data["handles"]:
                flat_data = handle["flat_param.data"]
                if (num_padding := handle["flat_param._shard_numel_padded"]) > 0:
                    # Any padding in the flat param shard should be zeros at the right end.
                    assert (flat_data[-num_padding:] == 0).all()
                # The metadata layout depends on the torch version that *wrote* the
                # checkpoint, so branch on which keys are present rather than on the
                # running torch version.
                if "flat_param._shard_indices" in handle:
                    # Older `FlatParameter` layout: parallel fqn/shape lists plus per-param
                    # (start, end) offsets into this rank's shard.
                    param_start = handle["flat_param._shard_indices"][0]
                    current_flat_index = 0
                    for relative_fqn, full_shape, (offset_start, offset_end) in zip(
                        handle["flat_param._fqns"][param_start:],
                        handle["flat_param._shapes"][param_start:],
                        handle["flat_param._shard_param_offsets"],
                    ):
                        root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
                        # Offsets are inclusive, hence the +1.
                        numel_shard = offset_end - offset_start + 1
                        flat_param_shard = _FlatParamShard(
                            full_shape=full_shape,
                            shard_offsets=(offset_start, offset_end),
                            shard_data=flat_data[current_flat_index : current_flat_index + numel_shard],
                        )
                        current_flat_index += numel_shard
                        yield root_fqn, flat_param_shard
                else:
                    # Newer `FlatParameter` layout: one `_shard_param_infos` record per param.
                    for relative_fqn, full_shape, shard_param_info in zip(
                        handle["flat_param._fqns"],
                        handle["flat_param._shapes"],
                        handle["flat_param._shard_param_infos"],
                    ):
                        # Skip params with no data in this rank's shard.
                        if not shard_param_info.in_shard:
                            continue
                        root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
                        flat_param_shard = _FlatParamShard(
                            full_shape=full_shape,
                            shard_offsets=(
                                shard_param_info.intra_param_start_idx,
                                shard_param_info.intra_param_end_idx,
                            ),
                            shard_data=flat_data[
                                shard_param_info.offset_in_shard : shard_param_info.offset_in_shard
                                + shard_param_info.numel_in_shard
                            ],
                        )
                        yield root_fqn, flat_param_shard

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Reconstruct full (unsharded) model, optimizer, and trainer state from all ranks' files
        in a single process — no distributed initialization required.
        """
        device = device or torch.device("cpu")
        metadata = self._load_metadata(load_path, local_cache=local_cache)

        # Resolve (and download if needed) each rank's model state file.
        log.info("Gathering model state dicts...")
        model_state_paths = self._gather_state_dict_paths(
            load_path, "model", metadata.world_size, local_cache=local_cache
        )

        # Load the model state files one at a time, copying each rank's shards into the
        # full parameters as we go.
        log.info("Materializing full parameters...")
        full_model_state: Dict[str, torch.Tensor] = {}
        # Keep each rank's shard placement metadata (minus the data itself) so the optimizer
        # state below can be unsharded with the same offsets without reloading model files.
        flat_params_data: Dict[int, Dict[str, _FlatParamShard]] = defaultdict(dict)
        for rank, path in enumerate(model_state_paths):
            log.info(f"Loading shards from rank {rank}...")
            model_state = torch.load(path, map_location="cpu")
            for root_fqn, flat_param_shard in self._iter_flat_param_shards(model_state):
                if root_fqn not in full_model_state:
                    log.info(
                        f"Materializing full parameter '{root_fqn}' with shape {flat_param_shard.full_shape}..."
                    )
                    assert flat_param_shard.shard_data is not None
                    full_model_state[root_fqn] = torch.empty(
                        flat_param_shard.full_shape, dtype=flat_param_shard.shard_data.dtype, device=device
                    )
                    # Fill with NaNs so any region no shard writes to can be detected below.
                    full_model_state[root_fqn].fill_(torch.nan)
                # Copy this rank's slice into place.
                full_param = full_model_state[root_fqn]
                log.info(f"Loading rank {rank} shard for '{root_fqn}'...")
                flat_param_shard.copy_into(full_param)
                flat_params_data[rank][root_fqn] = replace(flat_param_shard, shard_data=None)

        log.info("Validating full parameters...")
        for key, tensor in full_model_state.items():
            # A surviving NaN means some slice of the parameter was never written.
            if torch.isnan(tensor).any():
                raise ValueError(f"Parameter '{key}' contains NaNs, this is likely a bug with the unsharder")

        trainer_state: Optional[Dict[str, Any]] = None
        if load_trainer_state:
            # Trainer state is only read from rank 0.
            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)

        if not load_optimizer_state:
            return full_model_state, None, trainer_state

        log.info("Gathering optim state dicts...")
        optim_state_paths = self._gather_state_dict_paths(
            load_path, "optim", metadata.world_size, local_cache=local_cache
        )

        log.info("Materializing full optim state...")
        full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
        fqn_to_id: Dict[str, int] = {}
        id_to_fqn: Dict[int, str] = {}
        for rank, path in enumerate(optim_state_paths):
            log.info(f"Loading sharded optim state from rank {rank}...")
            optim_state = torch.load(path, map_location="cpu")

            # Initialize param group settings from the first rank; every rank must agree.
            if "param_groups" not in full_optim_state:
                full_optim_state["param_groups"] = optim_state["param_groups"]
            else:
                assert full_optim_state["param_groups"] == optim_state["param_groups"]

            # Build the param-ID <-> FQN mappings once, from the param groups.
            if not fqn_to_id or not id_to_fqn:
                for group in full_optim_state["param_groups"]:
                    for fqn, id in zip(group["param_names"], group["params"]):
                        fqn = fqn.replace("_fsdp_wrapped_module.", "")
                        fqn_to_id[fqn] = id
                        id_to_fqn[id] = fqn

            # Copy this rank's shard of each param's optimizer state into the full state.
            for id, shard_state in optim_state["state"].items():
                fqn = id_to_fqn[id]
                flat_param_shard = flat_params_data[rank].get(fqn)
                full_state = full_optim_state["state"][id]
                for key, shard_value in shard_state.items():
                    assert isinstance(shard_value, torch.Tensor)
                    if shard_value.shape == torch.Size([]):
                        # Scalar state ('step'/'grad_norm_exp_avg') is replicated: keep one
                        # copy and check all other ranks agree.
                        assert key in ("step", "grad_norm_exp_avg")
                        if key not in full_state:
                            full_state[key] = shard_value.to(device)
                        else:
                            assert full_state[key] == shard_value
                    else:
                        # Tensor state is sharded exactly like the flat param, so reuse the
                        # model shard's offsets to place this rank's slice.
                        assert flat_param_shard is not None, f"missing flat_params_data for {fqn} from rank {rank}"
                        if key not in full_state:
                            log.info(
                                f"Materializing full state '{key}' for '{fqn}' with shape {flat_param_shard.full_shape}..."
                            )
                            full_state[key] = torch.empty(
                                flat_param_shard.full_shape, dtype=shard_value.dtype, device=device
                            )
                        full_state_value = full_state[key]

                        # Copy the shard via a temporary shard record carrying this value.
                        log.info(f"Loading rank {rank} shard state of '{key}' for '{fqn}'...")
                        replace(flat_param_shard, shard_data=shard_value).copy_into(full_state_value)

        # Lastly, strip the FSDP wrapper prefix from the param names in the param groups.
        for group in full_optim_state["param_groups"]:
            group["param_names"] = [n.replace("_fsdp_wrapped_module.", "") for n in group["param_names"]]

        return full_model_state, full_optim_state, trainer_state

    def _get_state_dict_path(
        self,
        load_path: PathOrStr,
        state_dict_type: str,
        rank: int,
        *,
        local_cache: Optional[PathOrStr] = None,
        progress=None,
    ) -> Tuple[int, Path]:
        """Resolve (downloading/caching as needed) one rank's state file; returns ``(rank, path)``."""
        fname = f"{state_dict_type}/rank{rank}.pt"
        return rank, resource_path(str(load_path).rstrip("/"), fname, local_cache=local_cache, progress=progress)

    def _gather_state_dict_paths(
        self,
        load_path: PathOrStr,
        state_dict_type: str,
        world_size: int,
        *,
        local_cache: Optional[PathOrStr] = None,
    ) -> List[Path]:
        """Resolve every rank's state file concurrently; returns the paths ordered by rank."""
        progress = get_progress_bar()
        with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
            futures = []
            for rank in range(world_size):
                future = executor.submit(
                    self._get_state_dict_path,
                    load_path,
                    state_dict_type,
                    rank,
                    local_cache=local_cache,
                    progress=progress,
                )
                futures.append(future)

            # Collect as they complete, then re-order by rank.
            results: Dict[int, Path] = {}
            for future in as_completed(futures):
                rank, path = future.result()
                results[rank] = path

            return [results[rank] for rank in range(world_size)]
| |
|
| |
|
def build_sharded_checkpointer(
    cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
) -> Checkpointer:
    """
    Construct the sharded checkpointer selected by ``name``, falling back to
    ``cfg.sharded_checkpointer`` when ``name`` is not given.

    :raises NotImplementedError: if the selected type has no implementation.
    """
    name = name or cfg.sharded_checkpointer
    # Dispatch table: checkpointer type -> concrete class.
    constructors = {
        ShardedCheckpointerType.torch_new: TorchNewStyleShardedCheckpointer,
        ShardedCheckpointerType.torch_legacy: TorchLegacyShardedCheckpointer,
        ShardedCheckpointerType.local: LocalShardedCheckpointer,
    }
    try:
        constructor = constructors[name]
    except KeyError:
        raise NotImplementedError(name) from None
    return constructor(cfg)
| |
|