Spaces:
Configuration error
Configuration error
| import torch | |
| import torch.nn as nn | |
| from utils.misc import get_rank | |
| class BaseModel(nn.Module): | |
| def __init__(self, config): | |
| super().__init__() | |
| self.config = config | |
| self.rank = get_rank() | |
| self.setup() | |
| if self.config.get('weights', None): | |
| self.load_state_dict(torch.load(self.config.weights)) | |
| def setup(self): | |
| raise NotImplementedError | |
| def update_step(self, epoch, global_step): | |
| pass | |
| def train(self, mode=True): | |
| return super().train(mode=mode) | |
| def eval(self): | |
| return super().eval() | |
| def regularizations(self, out): | |
| return {} | |
| def export(self, export_config): | |
| return {} | |