from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
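
# Checkpoint with the non-reentrant implementation, the variant PyTorch
# recommends (equivalent to use_reentrant=False in torch.utils.checkpoint).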
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)


def apply_checkpointing(model, block, p):
    """
    Apply selective activation checkpointing.

    Selectivity is defined as a fraction p, which means we apply ac
    on p of the total blocks. p is a floating-point number in the
    range [0, 1].

    Some examples:
    p = 0: no ac for any block, same as `fsdp_activation_checkpointing=False`
    p = 1: apply ac on every block, i.e. "full ac"
    p = 1/2: [ac, no-ac, ac, no-ac, ...]
    p = 1/3: [no-ac, ac, no-ac, no-ac, ac, no-ac, ...]
    p = 2/3: [ac, no-ac, ac, ac, no-ac, ac, ...]
    Since blocks are homogeneous, we space the ac blocks evenly among
    all blocks.

    Implementation:
    For a given ac ratio p, we should apply ac roughly once every 1/p
    blocks. The first ac block can be as early as the 0th block, or as
    late as the (1/p)th block, and we pick the middle one: the (0.5/p)th
    block. Therefore, we essentially apply ac on the (0.5/p)th,
    (1.5/p)th, (2.5/p)th, ... blocks, with these values rounded to
    integers.
    Since ac is applied recursively, we can simply use the math below
    to apply ac on the corresponding blocks.
    """
    block_idx = 0
    cut_off = 1 / 2
    # When p is passed as a fraction on the command line (e.g. 1/3), it
    # arrives as a string in argv, so we eval() it here to get its value.
    p = eval(p) if isinstance(p, str) else p

    def selective_checkpointing(submodule):
        nonlocal block_idx
        nonlocal cut_off

        if isinstance(submodule, block):
            block_idx += 1
            if block_idx * p >= cut_off:
                cut_off += 1
                return True
        return False
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=selective_checkpointing,
    )
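

# Minimal usage sketch (not part of the original file): `Block` and
# `ToyModel` are hypothetical stand-ins for a real transformer block class
# and model; they only exist to show how `apply_checkpointing` is called.
if __name__ == "__main__":
    import torch.nn as nn

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.ff = nn.Linear(16, 16)

        def forward(self, x):
            return self.ff(x)

    class ToyModel(nn.Module):
        def __init__(self, n_blocks=6):
            super().__init__()
            self.blocks = nn.ModuleList([Block() for _ in range(n_blocks)])

        def forward(self, x):
            for blk in self.blocks:
                x = blk(x)
            return x

    model = ToyModel()
    # Passing p as the string "1/2" exercises the eval() path above;
    # every other Block should end up wrapped.
    apply_checkpointing(model, Block, "1/2")
    print(model)  # checkpointed blocks appear as CheckpointWrapper modules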