from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
def apply_checkpointing(model, block, p):
    """
    Apply selective activation checkpointing (ac).

    Selectivity is defined as a percentage p, which means we apply ac
    on p of the total blocks. p is a floating-point number in the range
    [0, 1].

    Some examples:
    p = 0: no ac on any block. Same as `fsdp_activation_checkpointing=False`.
    p = 1: apply ac on every block, i.e. "full ac".
    p = 1/2: [ac, no-ac, ac, no-ac, ...]
    p = 1/3: [no-ac, ac, no-ac, no-ac, ac, no-ac, ...]
    p = 2/3: [ac, no-ac, ac, ac, no-ac, ac, ...]

    Since blocks are homogeneous, we make ac blocks evenly spaced among
    all blocks.

    Implementation:
    For a given ac ratio p, we should apply ac on every "1/p" blocks.
    The first ac block can be as early as the 0th block, or as late as
    the "1/p"th block, and we pick the middle one: the (0.5/p)th block.
    Therefore, we apply ac on the (0.5/p)th, (1.5/p)th, (2.5/p)th, ...
    blocks, with these values rounded to integers. (A small standalone
    sketch of this selection pattern follows the function.)

    Since ac is applied recursively, we can simply use the following math
    in the code to apply ac on the corresponding blocks.
    """
    block_idx = 0
    cut_off = 1 / 2
    # when passing p as a fraction (e.g. 1/3), it will be interpreted
    # as a string in argv, thus we need eval("1/3") here for fractions.
    p = eval(p) if isinstance(p, str) else p

    def selective_checkpointing(submodule):
        nonlocal block_idx
        nonlocal cut_off

        if isinstance(submodule, block):
            block_idx += 1
            # checkpoint the block whose (1-based) index first crosses the
            # next threshold (k + 0.5) / p, for k = 0, 1, 2, ...
            if block_idx * p >= cut_off:
                cut_off += 1
                return True
        return False
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=selective_checkpointing,
    )
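

# Standalone sketch of the selection pattern described in the docstring.
# `show_pattern` is only an illustrative helper (it is not used by
# `apply_checkpointing`); it replays the same block_idx / cut_off math as
# `selective_checkpointing`, so you can see which blocks get ac for a given p.
def show_pattern(p, n_blocks):
    cut_off = 1 / 2
    pattern = []
    for block_idx in range(1, n_blocks + 1):
        if block_idx * p >= cut_off:
            cut_off += 1
            pattern.append("ac")
        else:
            pattern.append("no-ac")
    return pattern


# show_pattern(1 / 3, 6) -> ['no-ac', 'ac', 'no-ac', 'no-ac', 'ac', 'no-ac']
# show_pattern(2 / 3, 6) -> ['ac', 'no-ac', 'ac', 'ac', 'no-ac', 'ac']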
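

# Illustrative usage sketch. `ToyBlock` and `ToyModel` are placeholder modules
# made up for this example (any homogeneous stack of blocks works); only
# `apply_checkpointing` itself comes from above. Printing the model afterwards
# should show roughly p of the blocks wrapped by the checkpoint wrapper (the
# exact wrapper class name is an internal torch detail).
if __name__ == "__main__":
    import torch
    import torch.nn as nn

    class ToyBlock(nn.Module):
        def __init__(self, dim=16):
            super().__init__()
            self.ff = nn.Linear(dim, dim)

        def forward(self, x):
            return torch.relu(self.ff(x))

    class ToyModel(nn.Module):
        def __init__(self, n_blocks=6, dim=16):
            super().__init__()
            self.blocks = nn.Sequential(*[ToyBlock(dim) for _ in range(n_blocks)])

        def forward(self, x):
            return self.blocks(x)

    model = ToyModel()
    # p may be passed as a fraction string (e.g. straight from argv).
    apply_checkpointing(model, ToyBlock, "1/3")
    print(model)  # ~1/3 of the ToyBlock instances appear inside checkpoint wrappers

    # The checkpointed model trains as usual.
    out = model(torch.randn(2, 16))
    out.sum().backward()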