import sys
import warnings
from bisect import bisect_right

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

from pytorch_lightning.utilities.rank_zero import rank_zero_debug


class ChainedScheduler(lr_scheduler._LRScheduler):
    """Chains a list of learning rate schedulers. It takes a list of chainable learning
    rate schedulers and performs the step() of each of them with a single call.

    Args:
        schedulers (list): List of chained schedulers.

    Example:
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.09     if epoch == 0
        >>> # lr = 0.081    if epoch == 1
        >>> # lr = 0.729    if epoch == 2
        >>> # lr = 0.6561   if epoch == 3
        >>> # lr = 0.59049  if epoch >= 4
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, schedulers):
        for scheduler_idx in range(1, len(schedulers)):
            if schedulers[scheduler_idx].optimizer != schedulers[0].optimizer:
                raise ValueError(
                    "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        self._schedulers = list(schedulers)
        self.optimizer = schedulers[0].optimizer

    def step(self):
        for scheduler in self._schedulers:
            scheduler.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [None] * len(self._schedulers)

        for idx, s in enumerate(self._schedulers):
            state_dict['_schedulers'][idx] = s.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)


class SequentialLR(lr_scheduler._LRScheduler):
    """Receives a list of schedulers that are expected to be called sequentially during the
    optimization process, and milestone points that provide the exact intervals at which each
    scheduler is supposed to be active.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        schedulers (list): List of chained schedulers.
        milestones (list): List of integers reflecting the milestone points.

    Example:
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1     if epoch == 0
        >>> # lr = 0.1     if epoch == 1
        >>> # lr = 0.9     if epoch == 2
        >>> # lr = 0.81    if epoch == 3
        >>> # lr = 0.729   if epoch == 4
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
        for scheduler_idx in range(1, len(schedulers)):
            if schedulers[scheduler_idx].optimizer != schedulers[0].optimizer:
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        if len(milestones) != len(schedulers) - 1:
            raise ValueError(
                "Sequential Schedulers expects number of schedulers provided to be one more "
                "than the number of milestone points, but got number of schedulers {} and the "
                "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
            )
        self._schedulers = schedulers
        self._milestones = milestones
        self.last_epoch = last_epoch + 1
        self.optimizer = optimizer

    def step(self):
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
            self._schedulers[idx].step(0)
        else:
            self._schedulers[idx].step()
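
    # Illustrative note (not executed): with milestones=[2], bisect_right([2], last_epoch)
    # evaluates to 0 for last_epoch in {0, 1} and to 1 for last_epoch >= 2, so the first
    # scheduler drives epochs 0-1 and the second one takes over from epoch 2 on. Exactly at
    # a milestone, the incoming scheduler is stepped with epoch 0 so it restarts its own schedule.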

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [None] * len(self._schedulers)

        for idx, s in enumerate(self._schedulers):
            state_dict['_schedulers'][idx] = s.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)


class ConstantLR(lr_scheduler._LRScheduler):
    """Decays the learning rate of each parameter group by a small constant factor until the
    number of epochs reaches a pre-defined milestone: total_iters. Notice that such decay can
    happen simultaneously with other changes to the learning rate from outside this scheduler.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        factor (float): The factor by which the learning rate is multiplied until the milestone.
            Default: 1./3.
        total_iters (int): The number of steps for which the scheduler decays the learning rate.
            Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025   if epoch == 0
        >>> # lr = 0.025   if epoch == 1
        >>> # lr = 0.025   if epoch == 2
        >>> # lr = 0.025   if epoch == 3
        >>> # lr = 0.05    if epoch >= 4
        >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False):
        if factor > 1.0 or factor < 0:
            raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')

        self.factor = factor
        self.total_iters = total_iters
        super(ConstantLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] * self.factor for group in self.optimizer.param_groups]

        if self.last_epoch != self.total_iters:
            return [group['lr'] for group in self.optimizer.param_groups]

        return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
                for base_lr in self.base_lrs]


class LinearLR(lr_scheduler._LRScheduler):
    """Decays the learning rate of each parameter group by linearly changing a small
    multiplicative factor until the number of epochs reaches a pre-defined milestone: total_iters.
    Notice that such decay can happen simultaneously with other changes to the learning rate
    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply learning rate by in the first epoch.
            The multiplication factor changes towards end_factor in the following epochs.
            Default: 1./3.
        end_factor (float): The number we multiply learning rate by at the end of the linear
            changing process. Default: 1.0.
        total_iters (int): The number of iterations after which the multiplicative factor
            reaches end_factor. Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025    if epoch == 0
        >>> # lr = 0.03125  if epoch == 1
        >>> # lr = 0.0375   if epoch == 2
        >>> # lr = 0.04375  if epoch == 3
        >>> # lr = 0.05     if epoch >= 4
        >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
                 verbose=False):
        if start_factor > 1.0 or start_factor < 0:
            raise ValueError('Starting multiplicative factor expected to be between 0 and 1.')

        if end_factor > 1.0 or end_factor < 0:
            raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')

        self.start_factor = start_factor
        self.end_factor = end_factor
        self.total_iters = total_iters
        super(LinearLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]

        if self.last_epoch > self.total_iters:
            return [group['lr'] for group in self.optimizer.param_groups]

        return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
                (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)))
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * (self.start_factor +
                (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
                for base_lr in self.base_lrs]


custom_schedulers = ['ConstantLR', 'LinearLR']


def get_scheduler(name):
    if hasattr(lr_scheduler, name):
        return getattr(lr_scheduler, name)
    elif name in custom_schedulers:
        return getattr(sys.modules[__name__], name)
    else:
        raise NotImplementedError(f'Unknown scheduler: {name}')
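
# Name resolution sketch (illustrative): get_scheduler('MultiStepLR') resolves to
# torch.optim.lr_scheduler.MultiStepLR, while get_scheduler('ConstantLR') prefers the torch
# implementation when the installed version provides one and otherwise falls back to the
# ConstantLR defined above.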


def getattr_recursive(m, attr):
    for name in attr.split('.'):
        m = getattr(m, name)
    return m


def get_parameters(model, name):
    module = getattr_recursive(model, name)
    if isinstance(module, nn.Module):
        return module.parameters()
    elif isinstance(module, nn.Parameter):
        return module
    return []
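
# Illustrative example (hypothetical attribute names): getattr_recursive(model, 'encoder.mlp')
# returns model.encoder.mlp, so get_parameters(model, 'encoder.mlp') yields that submodule's
# parameters. If the dotted path points at a single nn.Parameter, the parameter itself is
# returned; anything else results in an empty parameter list.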


def parse_optimizer(config, model):
    if hasattr(config, 'params'):
        params = [{'params': get_parameters(model, name), 'name': name, **args} for name, args in config.params.items()]
        rank_zero_debug(f'Specified optimizer params: {config.params}')
    else:
        params = model.parameters()
    if config.name in ['FusedAdam']:
        import apex
        optim = getattr(apex.optimizers, config.name)(params, **config.args)
    else:
        optim = getattr(torch.optim, config.name)(params, **config.args)
    return optim
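
# Hypothetical config sketch for parse_optimizer (OmegaConf-style; the names and values below
# are illustrative, not taken from a real experiment config):
#
#   optimizer:
#     name: Adam
#     args: {lr: 0.01, betas: [0.9, 0.99], eps: 1.e-15}
#     params:
#       geometry: {lr: 0.01}
#       texture: {lr: 0.01}
#
# Each key under `params` is resolved against the model with getattr_recursive, so `geometry`
# and `texture` must be attribute paths of submodules (or parameters) of the model; their
# entries are merged into the corresponding per-group optimizer arguments.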


def parse_scheduler(config, optimizer):
    interval = config.get('interval', 'epoch')
    assert interval in ['epoch', 'step']
    if config.name == 'SequentialLR':
        scheduler = {
            'scheduler': SequentialLR(optimizer, [parse_scheduler(conf, optimizer)['scheduler'] for conf in config.schedulers], milestones=config.milestones),
            'interval': interval
        }
    elif config.name == 'Chained':
        scheduler = {
            'scheduler': ChainedScheduler([parse_scheduler(conf, optimizer)['scheduler'] for conf in config.schedulers]),
            'interval': interval
        }
    else:
        scheduler = {
            'scheduler': get_scheduler(config.name)(optimizer, **config.args),
            'interval': interval
        }
    return scheduler
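
# Hypothetical config sketch for parse_scheduler (illustrative values only):
#
#   scheduler:
#     name: SequentialLR
#     interval: step
#     milestones: [500]
#     schedulers:
#       - {name: LinearLR, args: {start_factor: 0.01, end_factor: 1.0, total_iters: 500}}
#       - {name: ExponentialLR, args: {gamma: 0.999}}
#
# 'SequentialLR' and 'Chained' are expanded recursively from their `schedulers` list; any other
# name is looked up through get_scheduler and instantiated with the optimizer plus `args`.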


def update_module_step(m, epoch, global_step):
    if hasattr(m, 'update_step'):
        m.update_step(epoch, global_step)
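
# Minimal sketch of a module that update_module_step can drive (hypothetical class, shown only
# to document the expected hook signature):
#
#   class ProgressiveEncoding(nn.Module):
#       def update_step(self, epoch, global_step):
#           self.active_levels = min(self.max_levels, global_step // 1000 + 1)
#
# A training loop can then call update_module_step(module, epoch, global_step) each step for
# every submodule that exposes the hook.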