Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from unittest.mock import patch | |
| import torch | |
| import torch.nn as nn | |
| from mmengine.model import BaseModel | |
| from mmengine.registry import MODELS | |
@MODELS.register_module()
class ToyModel(BaseModel):
    """Minimal placeholder model that always reports a zero loss.

    The class is registered in ``MODELS`` so that a config entry of the
    form ``dict(type='ToyModel')`` can be resolved by name.
    """

    def __init__(self, *args, **kwargs):
        """Build the model; every positional/keyword argument is ignored."""
        # Arguments are accepted for call compatibility but are not
        # forwarded to ``BaseModel``.
        super().__init__()
        # Single 1x1 convolution (1 in-channel, 1 out-channel).
        # NOTE(review): presumably only here so the model owns at least
        # one parameter for an optimizer to manage -- confirm.
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, *args, **kwargs):
        """Return ``{'loss': tensor(0.0)}`` regardless of the inputs."""
        return {'loss': torch.tensor(0.0)}
def update_params_step(self, loss):
    """Ignore ``loss`` and do nothing.

    The signature is method-shaped (it takes ``self``).
    NOTE(review): presumably intended as a no-op stand-in for an
    optimizer-wrapper parameter-update method -- confirm against usage.

    Args:
        self: Unused.
        loss: Unused.

    Returns:
        None.
    """
    return None
def runtimeinfo_step(self, runner, batch_idx, data_batch=None):
    """Publish the iteration counter, learning rates and momentums.

    Writes ``runner.iter`` to the runner's message hub under ``'iter'``,
    then logs the first entry of every list returned by
    ``runner.optim_wrapper.get_lr()`` and ``get_momentum()`` as a scalar
    named ``train/<name>``.

    Args:
        self: Unused; present so the function has a method-style signature.
        runner: Object exposing ``iter``, ``message_hub`` and
            ``optim_wrapper``.
        batch_idx: Unused.
        data_batch: Unused.
    """
    hub = runner.message_hub
    hub.update_info('iter', runner.iter)

    wrapper = runner.optim_wrapper
    # Learning rates are reported first, then momentums; only element 0 of
    # each per-name sequence is recorded.
    for query in (wrapper.get_lr, wrapper.get_momentum):
        for key, values in query().items():
            hub.update_scalar(f'train/{key}', values[0])
@patch('mmengine.optim.optimizer.OptimWrapper.update_params',
       update_params_step)
@patch('mmengine.hooks.RuntimeInfoHook.before_train_iter', runtimeinfo_step)
def fake_run(cfg):
    """Run a training loop on ``cfg`` with the model swapped for a toy one.

    For the duration of the call, ``OptimWrapper.update_params`` is replaced
    by ``update_params_step`` and ``RuntimeInfoHook.before_train_iter`` by
    ``runtimeinfo_step``. Because replacement objects are passed to
    ``patch`` explicitly, no extra mock arguments are injected.

    ``cfg`` is modified in place: the model, visualizer and all
    validation/test entries are removed, then a ``ToyModel`` model and a
    ``Visualizer`` with a ``TensorboardVisBackend`` writing to ``temp_dir``
    are merged in.

    Args:
        cfg: Config object supporting ``pop`` and ``merge_from_dict`` and
            accepted by ``Runner.from_cfg``.
    """
    # Kept as a function-local import, as in the original.
    # NOTE(review): presumably avoids an import cycle -- confirm.
    from mmengine.runner import Runner

    # Drop the real model/visualizer and every validation/test component.
    for key in ('model', 'visualizer', 'val_dataloader', 'val_evaluator',
                'val_cfg', 'test_dataloader', 'test_evaluator', 'test_cfg'):
        cfg.pop(key)

    extra_cfg = dict(
        model=dict(type='ToyModel'),
        visualizer=dict(
            type='Visualizer',
            vis_backends=[
                dict(type='TensorboardVisBackend', save_dir='temp_dir')
            ]),
    )
    cfg.merge_from_dict(extra_cfg)

    # build the runner from config
    runner = Runner.from_cfg(cfg)

    # start training
    runner.train()