|
|
from config import config
|
|
|
from torchvision import transforms
|
|
|
import cv2 as cv
|
|
|
|
|
|
|
|
|
|
|
|
class myTransformMethod():
    """Callable preprocessing step: square resize plus BGR-to-gray conversion.

    Instances take no configuration of their own; the target edge length is
    read from ``config.image_size`` on every call.
    """

    def __call__(self, img):
        """Resize ``img`` to a square and collapse three channels to one.

        Args:
            img: Image accepted by ``cv.resize``. Presumably a NumPy array in
                OpenCV's BGR channel order (TODO confirm against the loader).

        Returns:
            The resized image. When the resized image has an explicit
            channel axis of length 3 it is converted with
            ``cv.COLOR_BGR2GRAY``; any other layout is returned as resized.
        """
        side = config.image_size
        resized = cv.resize(img, (side, side))

        # Require a real channel axis before inspecting its length: for a 2-D
        # (single-channel) array ``shape[-1]`` is the image width, which would
        # be mistaken for "3 channels" whenever ``config.image_size == 3``.
        if resized.ndim == 3 and resized.shape[-1] == 3:
            return cv.cvtColor(resized, cv.COLOR_BGR2GRAY)
        return resized
|
|
|
|
|
|
|
|
|
# Preprocessing pipelines keyed by dataset split. Both splits currently use
# the same three stages, but each key gets its own Compose instance:
#   1. myTransformMethod() - project-specific OpenCV preprocessing
#   2. ToTensor()          - convert the result to a tensor
#   3. Normalize           - single-channel mean 0.5 / std 0.5
myTransform = {
    split_name: transforms.Compose([
        myTransformMethod(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    for split_name in ('trainTransform', 'testTransform')
}
|
|
|
|