# Example code for running inference on a pre-trained model import os import json import numpy as np import cv2 import torch from models import build_model # os.environ['CUDA_VISIBLE_DEVICES'] = "0,1" device = torch.device('cuda' if torch.cuda.is_available() else "cpu") def sigmoid(arr): return 1. / (1 + np.exp(-arr)) class Inference(object): def __init__(self, model_path): self.model_path = model_path config_path = os.path.join(model_path, 'config.json') with open(config_path) as fin: params = json.load(fin) self.model_params = params['model_params'] self.modality_mapping = params['modality_mapping'] self.model = self.load_model() def inference(self, image, modality): assert modality in self.modality_mapping, "Modality '{}' not supported".format(modality) image, raw_h, raw_w = self.load_image(image) modality_idx = self.modality_mapping[modality] modality_idx = torch.tensor([modality_idx]) with torch.no_grad(): output = self.model.predict(x=image, device=device, dataset_idx=modality_idx) output = output.data.cpu().numpy()[0][0] output = sigmoid(output) * 255 output = output.astype(np.uint8) output = cv2.resize(output, (raw_w, raw_h)) return output def load_image(self, image): # Load the image and preprocess it if isinstance(image, str): image = cv2.imread(image)[:, :, [2, 1, 0]] raw_h, raw_w = image.shape[:2] image = cv2.resize(image, (self.model_params['size_w'], self.model_params['size_h'])) image = image.astype(np.float32) / 255.0 image = np.transpose(image, (2, 0, 1)) image = np.expand_dims(image, axis=0) image = torch.tensor(image) return image, raw_h, raw_w def load_model(self): print('Loading model from {}'.format(self.model_path)) model = build_model(model_name=self.model_params['net'], model_params=self.model_params, training=False, dataset_idx=list(self.modality_mapping.values()), pretrained=False) #print(model.model.pos_promot3['0']) model.set_device(device) # model.requires_grad_false() model.load_model(os.path.join(self.model_path, 'model.pkl')) model.set_mode('eval') return model if __name__ == '__main__': model_path = 'checkpoints/UNet_DCP_1024' image_paths = [ 'images/FFA.bmp', 'images/CFP.jpg', 'images/SLO.jpg', 'images/UWF.jpg', 'images/OCTA.png' ] modalities = ['FFA', 'CFP', 'SLO', 'UWF', 'OCTA'] output_root = 'output_images' os.makedirs(output_root, exist_ok=True) inference = Inference(model_path) for image_path, modality in zip(image_paths, modalities): output = inference.inference(image_path, modality) cv2.imwrite(os.path.join(output_root, '{}.png'.format(modality)), output)