# Spaces:
# Runtime error
# Runtime error
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
class ImageDataset(Dataset):
    """Multi-label movie-poster dataset backed by a CSV DataFrame.

    Splits the data three ways:
      * train      — first 85% of rows, with augmentation transforms
      * validation — remaining 15% minus the last 10 rows
      * test       — last 10 rows, reserved for a separate inference script

    Each item is a dict with an ``'image'`` float tensor and a ``'label'``
    float tensor of binary genre indicators.
    """

    def __init__(self, csv, train, test):
        """
        Args:
            csv: DataFrame with an ``'Id'`` column (image filename stem,
                without the ``.jpg`` extension), a ``'Genre'`` text column,
                and one binary column per label.
            train (bool): build the training split.
            test (bool): build the test split (only honored when
                ``train`` is False).
        """
        self.csv = csv
        self.train = train
        self.test = test
        self.all_image_names = self.csv[:]['Id']
        # Drop the identifier/text columns; the remaining columns are the
        # binary multi-hot label matrix.
        self.all_labels = np.array(self.csv.drop(['Id', 'Genre'], axis=1))
        # 85/15 train/validation split by row position.
        self.train_ratio = int(0.85 * len(self.csv))
        self.valid_ratio = len(self.csv) - self.train_ratio

        if self.train:
            # Training split: first train_ratio rows, with augmentation.
            print(f"Number of training images: {self.train_ratio}")
            self.image_names = list(self.all_image_names[:self.train_ratio])
            self.labels = list(self.all_labels[:self.train_ratio])
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((400, 400)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=45),
                transforms.ToTensor(),
            ])
        elif not self.test:
            # Validation split: tail of the data minus the last 10 rows,
            # which are reserved for the test split. NOTE: labels now use
            # the same [-valid_ratio:-10] slice as image_names (the
            # original sliced labels with [-valid_ratio:], leaving 10
            # dangling, never-indexed label rows).
            print(f"Number of validation images: {self.valid_ratio}")
            self.image_names = list(self.all_image_names[-self.valid_ratio:-10])
            self.labels = list(self.all_labels[-self.valid_ratio:-10])
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((400, 400)),
                transforms.ToTensor(),
            ])
        else:
            # Test split: the last 10 rows, used by a separate inference
            # script. No resize — inference sees the original resolution.
            self.image_names = list(self.all_image_names[-10:])
            self.labels = list(self.all_labels[-10:])
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
            ])

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.image_names)

    def __getitem__(self, index):
        """Load, color-convert, and transform the image at ``index``.

        Returns:
            dict: ``{'image': FloatTensor, 'label': FloatTensor}``.

        Raises:
            FileNotFoundError: if the image file is missing or unreadable
                (cv2.imread returns None in that case).
        """
        path = (
            "../input/movie-classifier/Multi_Label_dataset/Images/"
            f"{self.image_names[index]}.jpg"
        )
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None; fail loudly
            # instead of crashing inside cvtColor with a cryptic message.
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV loads BGR; convert to the RGB order the transforms expect.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(image)
        targets = self.labels[index]
        return {
            # ToTensor() already yields a float tensor; avoid re-wrapping
            # with torch.tensor(), which warns and copies.
            'image': image.float(),
            'label': torch.tensor(targets, dtype=torch.float32),
        }