import torch
from tqdm import tqdm

# training function
def train(model, dataloader, optimizer, criterion, train_data, device):
    print('Training')
    model.train()
    counter = 0
    train_running_loss = 0.0
    for i, data in tqdm(enumerate(dataloader), total=int(len(train_data)/dataloader.batch_size)):
        counter += 1
        data, target = data['image'].to(device), data['label'].to(device)
        optimizer.zero_grad()
        outputs = model(data)
        # apply sigmoid activation to get all the outputs between 0 and 1
        outputs = torch.sigmoid(outputs)
        loss = criterion(outputs, target)
        train_running_loss += loss.item()
        # backpropagation
        loss.backward()
        # update optimizer parameters
        optimizer.step()
    # average the loss over all batches seen this epoch
    train_loss = train_running_loss / counter
    return train_loss

# validation function
def validate(model, dataloader, criterion, val_data, device):
    print('Validating')
    model.eval()
    counter = 0
    val_running_loss = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=int(len(val_data)/dataloader.batch_size)):
            counter += 1
            data, target = data['image'].to(device), data['label'].to(device)
            outputs = model(data)
            # apply sigmoid activation to get all the outputs between 0 and 1
            outputs = torch.sigmoid(outputs)
            loss = criterion(outputs, target)
            val_running_loss += loss.item()
        # average the loss over all validation batches
        val_loss = val_running_loss / counter
        return val_loss
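
# --- Usage sketch (an assumption, not part of the original Space file) ---
# A minimal, self-contained example of how train() and validate() above could be
# wired into an epoch loop, using a tiny toy model and random multi-label data.
# All names below (ToyMultiLabelDataset, train_loader, val_loader, the stand-in
# model) are hypothetical. Because both functions already apply torch.sigmoid to
# the model outputs, the matching loss is nn.BCELoss, which expects probabilities
# rather than raw logits.
if __name__ == '__main__':
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader

    class ToyMultiLabelDataset(Dataset):
        """Random 3x32x32 'images' with 5 independent binary labels (toy data)."""
        def __init__(self, n=64):
            self.images = torch.randn(n, 3, 32, 32)
            self.labels = torch.randint(0, 2, (n, 5)).float()
        def __len__(self):
            return len(self.images)
        def __getitem__(self, idx):
            # same dict format the train()/validate() loops above expect
            return {'image': self.images[idx], 'label': self.labels[idx]}

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_dataset, val_dataset = ToyMultiLabelDataset(64), ToyMultiLabelDataset(16)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    # tiny stand-in model: flatten each image and map it to 5 label scores
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 5)).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.BCELoss()  # pairs with the sigmoid applied inside train()/validate()

    for epoch in range(2):
        train_loss = train(model, train_loader, optimizer, criterion, train_dataset, device)
        val_loss = validate(model, val_loader, criterion, val_dataset, device)
        print(f"Epoch {epoch+1}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")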