# UNet_DCP_1024 / inference.py
# qijie.wei
# update
# 6735c2f
# Example code for running inference on a pre-trained model
import os
import json
import numpy as np
import cv2
import torch
from models import build_model
# Optional: uncomment to choose which GPUs are visible to this process.
# os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

# Module-wide device shared by the model (set_device) and prediction calls:
# the default CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device("cpu")
def sigmoid(arr):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-arr))``.

    Uses the numerically stable split formulation so that large-magnitude
    logits never overflow ``np.exp`` (the naive form overflows, and warns,
    for strongly negative inputs).

    Args:
        arr: NumPy array, scalar, or array-like of logits.

    Returns:
        np.ndarray of values in [0, 1] with the same shape as ``arr``
        (0-d for scalar input); float32 input stays float32.
    """
    arr = np.asarray(arr)
    # exp(-|x|) lies in (0, 1], so it can underflow to 0 but never overflow.
    e = np.exp(-np.abs(arr))
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x)  (algebraically equal).
    return np.where(arr >= 0, 1. / (1. + e), e / (1. + e))
class Inference(object):
    """Run a trained multi-modality segmentation checkpoint on single images.

    The checkpoint directory must contain ``config.json`` (with
    ``model_params`` and ``modality_mapping`` entries) and ``model.pkl``.
    """

    def __init__(self, model_path):
        """Read ``<model_path>/config.json`` and build/load the model.

        Args:
            model_path: Checkpoint directory holding ``config.json`` and
                ``model.pkl``.
        """
        self.model_path = model_path
        config_path = os.path.join(model_path, 'config.json')
        # JSON is UTF-8 by spec; don't depend on the platform's locale.
        with open(config_path, encoding='utf-8') as fin:
            params = json.load(fin)
        # Forwarded to build_model; also supplies 'size_w'/'size_h' (the
        # network input size used in load_image) and 'net'.
        self.model_params = params['model_params']
        # Modality name -> index passed to the model as dataset_idx.
        self.modality_mapping = params['modality_mapping']
        self.model = self.load_model()

    def inference(self, image, modality):
        """Predict a probability map for one image.

        Args:
            image: Image path (str or os.PathLike) or an already-loaded RGB
                array; see ``load_image``.
            modality: Key of ``modality_mapping`` selecting which modality
                index is passed to the model.

        Returns:
            ``np.uint8`` map holding ``sigmoid(logit) * 255`` (truncated),
            resized back to the input image's original width and height.

        Raises:
            ValueError: If ``modality`` is not in ``modality_mapping``.
            FileNotFoundError: If ``image`` is a path OpenCV cannot read.
        """
        # Explicit raise rather than assert so the check survives `python -O`.
        if modality not in self.modality_mapping:
            raise ValueError("Modality '{}' not supported".format(modality))
        image, raw_h, raw_w = self.load_image(image)
        modality_idx = torch.tensor([self.modality_mapping[modality]])
        with torch.no_grad():
            output = self.model.predict(x=image, device=device, dataset_idx=modality_idx)
        # Take the first batch item and first channel; presumably the output
        # is (batch, channel, h, w) logits -- TODO confirm against predict().
        output = output.data.cpu().numpy()[0][0]
        output = (sigmoid(output) * 255).astype(np.uint8)
        # cv2.resize takes dsize as (width, height).
        return cv2.resize(output, (raw_w, raw_h))

    def load_image(self, image):
        """Load and preprocess an image into a model input tensor.

        Args:
            image: Path (str or os.PathLike), read with OpenCV and converted
                BGR -> RGB; or an array assumed to already be RGB.
                Grayscale ``(H, W)`` / ``(H, W, 1)`` arrays are replicated to
                3 channels and a 4th (alpha) channel is dropped, mirroring
                what ``cv2.imread`` does for such files.

        Returns:
            ``(tensor, raw_h, raw_w)``: a float32 tensor of shape
            ``(1, C, size_h, size_w)`` with pixel values divided by 255, and
            the original image height and width.

        Raises:
            FileNotFoundError: If ``image`` is a path OpenCV cannot read.
        """
        if isinstance(image, (str, os.PathLike)):
            path = os.fspath(image)
            bgr = cv2.imread(path)
            # cv2.imread signals a missing/unreadable file by returning None.
            if bgr is None:
                raise FileNotFoundError("Could not read image '{}'".format(path))
            image = bgr[:, :, [2, 1, 0]]  # BGR -> RGB
        else:
            image = np.asarray(image)
            if image.ndim == 3 and image.shape[2] == 1:
                image = image[:, :, 0]
            if image.ndim == 2:
                # Replicate grayscale to three identical channels.
                image = np.stack([image] * 3, axis=-1)
            elif image.shape[2] == 4:
                # Drop the alpha channel.
                image = image[:, :, :3]
        raw_h, raw_w = image.shape[:2]
        image = cv2.resize(image, (self.model_params['size_w'], self.model_params['size_h']))
        image = image.astype(np.float32) / 255.0
        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = np.expand_dims(image, axis=0)   # add batch dimension
        return torch.tensor(image), raw_h, raw_w

    def load_model(self):
        """Build the network from ``model_params`` and load its weights.

        Returns:
            The object returned by ``build_model``, after ``set_device(device)``,
            ``load_model(<model_path>/model.pkl)`` and ``set_mode('eval')``.
        """
        print('Loading model from {}'.format(self.model_path))
        model = build_model(model_name=self.model_params['net'],
                            model_params=self.model_params,
                            training=False,
                            # every modality index listed in the config
                            dataset_idx=list(self.modality_mapping.values()),
                            pretrained=False)
        model.set_device(device)
        # NOTE(review): a '.pkl' checkpoint is presumably unpickled -- only
        # load checkpoints from trusted sources.
        model.load_model(os.path.join(self.model_path, 'model.pkl'))
        model.set_mode('eval')
        return model
if __name__ == '__main__':
    # Demo: run the checkpoint on one sample image per modality and write
    # each predicted map to output_images/<modality>.png.
    model_path = 'checkpoints/UNet_DCP_1024'
    # Keep each sample image paired with its modality so the two can never
    # drift out of step (zip() over two lists would silently truncate).
    samples = [
        ('images/FFA.bmp', 'FFA'),
        ('images/CFP.jpg', 'CFP'),
        ('images/SLO.jpg', 'SLO'),
        ('images/UWF.jpg', 'UWF'),
        ('images/OCTA.png', 'OCTA'),
    ]
    output_root = 'output_images'
    os.makedirs(output_root, exist_ok=True)
    inference = Inference(model_path)
    for image_path, modality in samples:
        output = inference.inference(image_path, modality)
        output_path = os.path.join(output_root, '{}.png'.format(modality))
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(output_path, output):
            raise OSError('Failed to write {}'.format(output_path))