"""Data transformation pipeline for model preprocessing.

Every transform takes ``(feature, label)`` and returns ``(feature, label)``;
only ``feature`` is modified. Transforms are registered by class name in
``transform_factory`` and chained by name with :class:`Composer`.
"""
import warnings

import numpy as np
import cv2


def half_down(feature, label):
    """Downsample a batch of images by 2x using ``cv2.pyrDown``.

    Args:
        feature: batch array indexed as ``(N, H, W)`` or ``(N, H, W, C)``.
        label: passed through unchanged.

    Returns:
        ``(down, label)`` where ``down`` has shape
        ``(N, H // 2, W // 2[, C])`` and the same dtype as ``feature``.
    """
    n, h, w = feature.shape[:3]
    out_h, out_w = h // 2, w // 2
    trailing = tuple(feature.shape[3:])
    down = np.empty((n, out_h, out_w) + trailing, dtype=feature.dtype)
    for i, img in enumerate(feature):
        # pyrDown's default output is ((W + 1) // 2, (H + 1) // 2), which does
        # not fit the floor-sized buffer when H or W is odd, so request the size
        # explicitly (cv2 accepts dst sizes within 2 px of half the source).
        small = cv2.pyrDown(img, dstsize=(out_w, out_h))
        # cv2 drops the channel axis for single-channel images; restore it.
        down[i] = small.reshape((out_h, out_w) + trailing)
    return down, label


def mono(feature, label):
    """Convert each image to grayscale and scale uint8 data to [0, 1].

    * ``(H, W, 3)`` images are converted with ``cv2.COLOR_RGB2GRAY``, given a
      trailing channel axis (``(H, W, 1)``) and divided by 255.
    * Any other uint8 image keeps its shape and is divided by 255.
    * Remaining images pass through unchanged.

    Args:
        feature: iterable of images (e.g. a batch array indexed by image).
        label: passed through unchanged.

    Returns:
        ``(np.stack(images), label)``.
    """
    monos = []
    for img in feature:
        gray = img
        if len(img.shape) == 3 and img.shape[2] == 3:
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            gray = np.expand_dims(gray, -1)
            # NOTE(review): this path rescales regardless of dtype, so float
            # RGB input already in [0, 1] would be divided by 255 twice.
            gray = (gray / 255.0).astype(np.float32)
        elif img.dtype == np.uint8:
            gray = (gray / 255.0).astype(np.float32)
        monos.append(gray)
    return np.stack(monos), label


def normalize(feature, label):
    """Divide every image by 255 and cast to float32.

    Returns:
        ``(np.stack(scaled_images), label)``.
    """
    scaled = [(img / 255.0).astype(np.float32) for img in feature]
    return np.stack(scaled), label


def invert(feature, label):
    """Invert intensities as ``1 - feature``.

    Presumably expects data already scaled to [0, 1] (e.g. after Normalize or
    Mono); on unsigned integer input the subtraction would wrap around.
    """
    return 1 - feature, label


def img_std_nor(feature, label):
    """Apply ImageNet mean/std normalization.

    ``feature`` is cast to float32, divided by 255, then standardized per
    channel. The 3-element mean/std broadcast against the LAST axis, so this
    must run while data is channels-last (before HWC2CHW).
    """
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    feature = feature.astype(np.float32)
    feature = (feature / 255.0 - mean) / std
    return feature, label


def hwc2chw(feature, label):
    """Move the channel axis: ``(N, H, W, C)`` -> ``(N, C, H, W)``.

    ``np.moveaxis`` returns a view; no data is copied.
    """
    feature = np.moveaxis(feature, -1, 1)
    return feature, label


def to_float32(feature, label):
    """Cast ``feature`` to float32."""
    return feature.astype(np.float32), label


class TransWrapper:
    """Callable wrapper around a ``fn(feature, label) -> (feature, label)``."""

    def __init__(self, fn):
        self.trans_fn = fn

    def __call__(self, feature, label):
        return self.trans_fn(feature, label)


class Half_Down(TransWrapper):
    """Transform wrapping :func:`half_down`."""

    def __init__(self):
        super().__init__(half_down)


class Img_std_N(TransWrapper):
    """Transform wrapping :func:`img_std_nor`."""

    def __init__(self):
        super().__init__(img_std_nor)


class Mono(TransWrapper):
    """Transform wrapping :func:`mono`."""

    def __init__(self):
        super().__init__(mono)


class Normalize(TransWrapper):
    """Transform wrapping :func:`normalize`."""

    def __init__(self):
        super().__init__(normalize)


class Invert(TransWrapper):
    """Transform wrapping :func:`invert`."""

    def __init__(self):
        super().__init__(invert)


class HWC2CHW(TransWrapper):
    """Transform wrapping :func:`hwc2chw`."""

    def __init__(self):
        super().__init__(hwc2chw)


class To_Float32(TransWrapper):
    """Transform wrapping :func:`to_float32`."""

    def __init__(self):
        super().__init__(to_float32)


class TransformFactory:
    """Registry mapping transform class names to shared transform instances."""

    def __init__(self):
        trans_classes = [
            Half_Down,
            Img_std_N,
            HWC2CHW,
            To_Float32,
            Mono,
            Normalize,
            Invert,
        ]
        # Keyed by class name, e.g. "Mono" -> Mono().
        self.trans_dict = {c.__name__: c() for c in trans_classes}

    def get_transfn(self, name):
        """Return the transform registered as *name*, or ``None`` if unknown."""
        return self.trans_dict.get(name)


transform_factory = TransformFactory()


class Composer:
    """Compose multiple named transforms into a pipeline.

    Args:
        trans: iterable of names registered in ``transform_factory``
            (e.g. ``["Mono", "HWC2CHW"]``), applied in order.
    """

    def __init__(self, trans):
        self.trans_name = trans

    def __call__(self, feature, target):
        """Apply each named transform in order and return ``(feature, target)``."""
        newf, newl = feature, target
        for name in self.trans_name:
            trans = transform_factory.get_transfn(name)
            if trans is None:
                # Unknown names are still skipped, but no longer silently: a
                # typo in a config would otherwise drop a step unnoticed.
                warnings.warn(
                    f"Composer: unknown transform {name!r} skipped",
                    stacklevel=2,
                )
                continue
            newf, newl = trans(newf, newl)
        return newf, newl