File size: 674 Bytes
6434535 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from config import config
from torchvision import transforms
import cv2 as cv
class myTransformMethod:
    """Resize an image to a square of side ``config.image_size`` and reduce it
    to a single grayscale channel.

    Used as the first step of the ``transforms.Compose`` pipelines in this
    module, ahead of ``transforms.ToTensor()``.
    """

    def __call__(self, img):
        """Return *img* resized and, when it has colour channels, grayscaled.

        Args:
            img: an image array as accepted by ``cv.resize``. Presumably
                OpenCV channel order (BGR / BGRA) or already single-channel;
                NOTE(review): confirm against the dataset loader.

        Returns:
            The image resized to ``(config.image_size, config.image_size)``.
            3-channel inputs are converted with ``COLOR_BGR2GRAY`` and
            4-channel inputs with ``COLOR_BGRA2GRAY``; any other input is
            returned after resizing only.
        """
        side = config.image_size
        img = cv.resize(img, (side, side))
        # Check the rank before the channel count: a 2-D grayscale image whose
        # last dimension happens to be 3 or 4 (i.e. image_size == 3 or 4) must
        # not be passed to cvtColor as if it were a colour image.
        if img.ndim == 3:
            channels = img.shape[-1]
            if channels == 3:  # BGR, OpenCV's default colour layout
                img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
            elif channels == 4:  # BGRA; reduce too so the output stays 1-channel
                img = cv.cvtColor(img, cv.COLOR_BGRA2GRAY)
        return img
# Preprocessing pipelines keyed by split. Train and test currently use the same
# steps: myTransformMethod (square resize + grayscale), ToTensor, then
# Normalize with mean 0.5 / std 0.5. Each split still gets its own Compose
# instance, so one pipeline can later diverge (e.g. training-only augmentation)
# without touching the other.
myTransform = {
    split_name: transforms.Compose([
        myTransformMethod(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    for split_name in ('trainTransform', 'testTransform')
}
|