# LCM-based bone suppression inference on the SZCH-X-Rays test set (VQGAN latent space).
from config import config

import numpy as np
from dataset import myC2BDataset
from transform import myTransform
from torch.utils.data import DataLoader
from diffusers import LCMScheduler
from tqdm import tqdm
import cv2 as cv
import torch
import time
import os

from monai.utils import set_determinism

set_determinism(42)


def eval():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Output directories for the global, masked, and fused bone-suppressed images.
    output_path = os.path.join("lcm_output_bs", "BS")
    masked_output_path = os.path.join("lcm_output_bs", "Masked_BS")
    fusion_output_path = os.path.join("lcm_output_bs", "Fusion_BS")
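    # Create the output directories if they do not already exist, so cv.imwrite does not
    # fail silently (assumes the relative paths above are intended as-is).
    for path in (output_path, masked_output_path, fusion_output_path):
        os.makedirs(path, exist_ok=True)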

    # Test inputs: original CXRs, lesion-masked CXRs, and the corresponding lesion masks.
    cxr_path = os.path.join("SZCH-X-Rays", "CXR")
    masked_cxr_path = os.path.join("SZCH-X-Rays", "Masked_CXR")
    mask_path = os.path.join("SZCH-X-Rays", "Mask")

    # Pretrained LCM denoiser and VQGAN autoencoder (saved as full models).
    model = torch.load("masked_lcm-600-2024-12-19-myModel.pth").to(device).eval()
    VQGAN = torch.load("2024-12-12-Mask-SZCH-VQGAN.pth").to(device).eval()

    testset_list = "SZCH.txt"
    myTestSet = myC2BDataset(testset_list, cxr_path, masked_cxr_path, myTransform['testTransform'])
    myTestLoader = DataLoader(myTestSet, batch_size=1, shuffle=False)

    noise_scheduler = LCMScheduler(
        num_train_timesteps=config.num_train_timesteps,
        clip_sample=config.clip_sample,
        clip_sample_range=config.initial_clip_sample_range_g,
    )
    noise_scheduler.set_timesteps(config.num_infer_timesteps)
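
    # Both branches below start from the same noise: the "global" branch is conditioned on
    # the full CXR, the "masked" branch on the lesion-masked CXR. The scheduler is rebuilt
    # at every step so clip_sample_range changes by config.clip_rate, and the two decoded
    # results are finally fused inside the lesion mask with cv.seamlessClone.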
    with torch.no_grad():
        progress_bar = tqdm(enumerate(myTestLoader), total=len(myTestLoader), ncols=100)
        total_start = time.time()
        for step, batch in progress_bar:
            cxr = batch[0].to(device=device, non_blocking=True).float()
            masked_cxr = batch[1].to(device=device, non_blocking=True).float()
            filename = batch[2][0]

            # Keep a uint8 copy of the input CXR so the background (zero pixels) can be
            # zeroed again in the decoded outputs.
            cxr_copy = np.array(cxr.detach().to("cpu"))
            cxr_copy = np.squeeze(cxr_copy)
            cxr_copy = cxr_copy * 0.5 + 0.5
            cxr_copy *= 255
            cxr_copy = cxr_copy.astype(np.uint8)  # uint8 (not int8): pixel values go up to 255

            # Encode both inputs into the VQGAN latent space and build the model inputs:
            # the noise latent (channels 0-3) concatenated with the conditioning latent.
            cxr = VQGAN.encode_stage_2_inputs(cxr)
            masked_cxr = VQGAN.encode_stage_2_inputs(masked_cxr)

            noise = torch.randn_like(cxr).to(device)
            sample = torch.cat((noise, cxr), dim=1).to(device)
            masked_sample = torch.cat((noise, masked_cxr), dim=1).to(device)

            for j, t in tqdm(enumerate(noise_scheduler.timesteps)):
                residual = model(sample, torch.Tensor((t,)).to(device).long()).to(device)
                masked_residual = model(masked_sample, torch.Tensor((t,)).to(device).long()).to(device)

                # Mix the masked branch's prediction with Gaussian noise scaled by
                # 1 / std(masked_residual).
                masked_residual = (config.alpha * masked_residual
                                   + (1 - config.alpha) * torch.randn_like(masked_residual).to(device)
                                   / torch.std(masked_residual))

                # Global branch: rebuild the scheduler so clip_sample_range is adjusted with
                # the step index j (by config.clip_rate).
                noise_scheduler = LCMScheduler(
                    num_train_timesteps=config.num_train_timesteps,
                    clip_sample=config.clip_sample,
                    clip_sample_range=config.initial_clip_sample_range_g + config.clip_rate * j,
                )
                noise_scheduler.set_timesteps(config.num_infer_timesteps)
                sample = noise_scheduler.step(residual, t, sample).prev_sample

                # Masked (local) branch: same schedule, but with its own initial clip range.
                noise_scheduler = LCMScheduler(
                    num_train_timesteps=config.num_train_timesteps,
                    clip_sample=config.clip_sample,
                    clip_sample_range=config.initial_clip_sample_range_l + config.clip_rate * j,
                )
                noise_scheduler.set_timesteps(config.num_infer_timesteps)
                masked_sample = noise_scheduler.step(masked_residual, t, masked_sample).prev_sample

                # Keep only the denoised latent (first 4 channels) and re-attach the clean
                # conditioning latent for the next step.
                sample = torch.cat((sample[:, :4], cxr), dim=1)
                masked_sample = torch.cat((masked_sample[:, :4], masked_cxr), dim=1).to(device)

                # Optional live preview of the first latent channel during denoising.
                if config.output_feature_map:
                    bs_show = np.array(sample[:, 0].detach().to("cpu"))
                    bs_show = np.squeeze(bs_show)
                    bs_show = bs_show * 0.5 + 0.5
                    bs_show = np.clip(bs_show, 0, 1)

                    masked_bs_show = np.array(masked_sample[:, 0].detach().to("cpu"))
                    masked_bs_show = np.squeeze(masked_bs_show)
                    masked_bs_show = masked_bs_show * 0.5 + 0.5
                    masked_bs_show = np.clip(masked_bs_show, 0, 1)

                    if not config.use_server:
                        cv.imshow("win1", bs_show)
                        cv.imshow("win2", masked_bs_show)
                        cv.waitKey(1)

            # Lesion mask for this image; keep only fully white (255) pixels.
            mask = cv.imread(os.path.join(mask_path, filename), 0)
            mask[mask < 255] = 0

            # Decode the global branch and zero out the background.
            bs = VQGAN.decode(sample[:, :4])
            bs = np.array(bs.detach().to("cpu"))
            bs = np.squeeze(bs)
            bs = bs * 0.5 + 0.5
            bs[cxr_copy == 0] = 0

            # Decode the masked branch; shift its intensity inside the mask to match the
            # global branch's mean, then zero out the background.
            masked_bs = VQGAN.decode(masked_sample[:, :4])
            masked_bs = np.array(masked_bs.detach().to("cpu"))
            masked_bs = np.squeeze(masked_bs)
            masked_bs = masked_bs * 0.5 + 0.5
            masked_bs[mask > 0] = masked_bs[mask > 0] + np.mean(bs[mask > 0]) - np.mean(masked_bs[mask > 0])
            masked_bs[cxr_copy == 0] = 0

            if not config.use_server:
                cv.imshow("win3", bs)
                cv.imshow("win4", masked_bs)
                cv.waitKey(1)

            # Save the two intermediate results (scaled back to 0-255).
            bs *= 255
            cv.imwrite(os.path.join(output_path, filename), bs)
            masked_bs *= 255
            cv.imwrite(os.path.join(masked_output_path, filename), masked_bs)

            # Remove small connected components (< min_area px) from the mask before blending.
            num_labels, labels, stats, _ = cv.connectedComponentsWithStats(mask)
            min_area = 100
            for i in range(1, num_labels):
                if stats[i, cv.CC_STAT_AREA] < min_area:
                    labels[labels == i] = 0
            mask[labels == 0] = 0

            # Center of the mask's bounding rectangle, used as the clone center.
            # (Assumes the mask is non-empty; an all-zero mask gives a degenerate rectangle.)
            br = cv.boundingRect(mask)
            p = (br[0] + br[2] // 2, br[1] + br[3] // 2)

            # Convert both images to 8-bit BGR as required by seamlessClone.
            masked_bs = np.clip(masked_bs, 0, 255)
            masked_bs = cv.cvtColor(masked_bs, cv.COLOR_GRAY2BGR).astype(np.uint8)
            bs = np.clip(bs, 0, 255)
            bs = cv.cvtColor(bs, cv.COLOR_GRAY2BGR).astype(np.uint8)

            # Poisson-blend the masked result into the global result inside the mask region.
            fusion_bs = cv.seamlessClone(masked_bs, bs, mask, p, cv.MONOCHROME_TRANSFER)
            cv.imwrite(os.path.join(fusion_output_path, filename), fusion_bs)

    total_time = time.time() - total_start
    print(f"Total time: {total_time:.2f} s.")


if __name__ == "__main__":
    eval()