Spaces: Runtime error
| import torch | |
# Registry of supported optimizers: config name -> torch optimizer class.
# get_optimizer() looks the configured name up in this mapping, so a new
# optimizer becomes available by adding an entry here.
OPTIMIZERS_POOL = dict(
    sgd=torch.optim.SGD,
)
def get_optimizer(model_params, optimizer_config, *, optimizers_pool=None):
    """Build an optimizer from a single-entry ``{name: kwargs}`` config.

    Args:
        model_params: Passed unchanged as the first positional argument to
            the optimizer constructor (e.g. ``model.parameters()``).
        optimizer_config: Mapping with exactly one entry ``{name: kwargs}``,
            e.g. ``{'sgd': {'lr': 0.1}}``. ``name`` selects the optimizer
            class; ``kwargs`` (a mapping, or ``None`` for no extra arguments)
            is forwarded to its constructor as keyword arguments.
        optimizers_pool: Optional mapping of name -> optimizer class used to
            resolve ``name``. Defaults to the module-level ``OPTIMIZERS_POOL``.

    Returns:
        The optimizer instance created by ``cls(model_params, **kwargs)``.

    Raises:
        ValueError: If ``optimizer_config`` does not have exactly one entry.
        KeyError: If ``name`` is not present in the optimizer pool.
    """
    pool = OPTIMIZERS_POOL if optimizers_pool is None else optimizers_pool

    # Require exactly one entry. Otherwise an empty config would fail with an
    # opaque IndexError, and any extra entries would be silently ignored.
    if len(optimizer_config) != 1:
        raise ValueError(
            "optimizer_config must contain exactly one {name: kwargs} entry, "
            f"got {len(optimizer_config)}: {list(optimizer_config)}"
        )
    name, params = next(iter(optimizer_config.items()))

    try:
        optimizer_cls = pool[name]
    except KeyError:
        # Keep KeyError for compatibility, but name the valid choices.
        raise KeyError(
            f"Unknown optimizer {name!r}; available: {list(pool)}"
        ) from None

    # A value of None (e.g. a YAML entry ``sgd:`` with nothing under it) means
    # "no extra constructor arguments" rather than a crash on ``**None``.
    kwargs = {} if params is None else params
    return optimizer_cls(model_params, **kwargs)