Spaces:
Runtime error
Runtime error
import os

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import models
from dataset import ImageDataset
from engine import train, validate
# Use the ggplot look for every figure this script draws.
matplotlib.style.use('ggplot')

# Pick the computation device: the GPU when CUDA is available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Learning parameters.
lr = 0.0001
epochs = 10
batch_size = 32

# Build the network through the project's `models.model` factory and move it
# onto the chosen device.
# NOTE(review): presumably `pretrained=True, requires_grad=False` loads
# pretrained weights and freezes them -- confirm in models.py.
model = models.model(pretrained=True, requires_grad=False)
model = model.to(device)

# Adam over the model's parameters, scored with binary cross-entropy loss.
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCELoss()
# Load the CSV that describes the training images.
train_csv = pd.read_csv('../input/movie-classifier/Multi_Label_dataset/train.csv')

# Both datasets are built from the same dataframe; only the `train` flag differs.
# NOTE(review): presumably ImageDataset does the train/validation split
# internally based on that flag -- confirm in dataset.py.
train_data = ImageDataset(train_csv, train=True, test=False)
valid_data = ImageDataset(train_csv, train=False, test=False)

# Batch loaders: training batches are shuffled, validation order is kept fixed.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)
# Run training and validation, collecting one loss value per epoch so the
# curves can be plotted afterwards.
train_loss, valid_loss = [], []
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")

    # `train` and `validate` come from the project's engine module; the value
    # each returns is stored and printed as that epoch's loss.
    train_epoch_loss = train(model, train_loader, optimizer, criterion, train_data, device)
    valid_epoch_loss = validate(model, valid_loader, criterion, valid_data, device)

    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f'Val Loss: {valid_epoch_loss:.4f}')
# Save the trained model to disk.
# torch.save does not create missing directories, so make sure the output
# folder exists first; otherwise the finished training run would be lost to
# an error at the very last step.
os.makedirs('../outputs', exist_ok=True)
torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # NOTE(review): this stores the criterion object (nn.BCELoss), not a
    # numeric loss value. Kept as-is so existing checkpoint readers still work.
    'loss': criterion,
}, '../outputs/model.pth')
# Plot the per-epoch train and validation losses as line graphs, save the
# figure to disk, then display it.
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
# Legend label corrected from the misspelled 'validataion loss'.
plt.plot(valid_loss, color='red', label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('../outputs/loss.png')
plt.show()