| import torch | |
| import data_utils as du | |
def run_inference_two_channels(coil_complex_image, autoencoder, device="cuda"):
    """Round-trip a complex coil image through an autoencoder via a two-channel encoding.

    Steps:

    1. Normalise the image with ``du.normalize_complex_coil_image``.
    2. Convert it to a two-channel array with ``du.complex_to_two_channel_image``.
    3. Add a leading batch axis and cast to ``float32`` on ``device``.
    4. Encode, then decode from ``latent_dist.mean``.
    5. Convert the decoder output back with ``du.two_channel_to_complex_image``.

    Args:
        coil_complex_image: Complex coil image in whatever form the ``data_utils``
            helpers accept. Its two-channel form must be a NumPy array, because
            ``torch.from_numpy`` is applied to it.
        autoencoder: Model exposing ``encode(x).latent_dist.mean`` and
            ``decode(z).sample``. ``autoencoder.to(device)`` is called on it; for a
            ``torch.nn.Module`` this moves the caller's model in place.
        device: Torch device to run inference on. Defaults to ``"cuda"``.

    Returns:
        Tuple ``(normalized_image, recon)``:

        - ``normalized_image`` is the normalised complex input.
        - ``recon`` is the decoder output converted back by
          ``du.two_channel_to_complex_image``.
    """
    normalized_image = du.normalize_complex_coil_image(coil_complex_image)
    two_channel_image = du.complex_to_two_channel_image(normalized_image)

    # Add a leading batch axis. Cast to float32 and move to the device in one
    # call instead of the legacy ``.type(torch.FloatTensor).to(device)`` chain.
    two_channel_tensor = torch.from_numpy(two_channel_image)[None, ...].to(
        device=device, dtype=torch.float32
    )

    autoencoder = autoencoder.to(device)
    with torch.no_grad():
        # Decode from the latent distribution's mean rather than drawing a sample.
        # The distribution is presumably a diagonal Gaussian, as in diffusers'
        # AutoencoderKL; confirm against the model in use.
        latents = autoencoder.encode(two_channel_tensor).latent_dist.mean
        decoded_image = autoencoder.decode(latents).sample

    # Tensors produced under no_grad carry no autograd history, so no detach()
    # is needed before converting to NumPy.
    # NOTE(review): the leading batch axis is kept here, whereas the
    # three-channel variant strips it with [0]. Confirm that
    # du.two_channel_to_complex_image expects a batched array.
    recon = du.two_channel_to_complex_image(decoded_image.cpu().numpy())
    return normalized_image, recon
def run_inference_three_channels(coil_complex_image, autoencoder, device="cuda"):
    """Round-trip a complex coil image through an autoencoder via a three-channel encoding.

    Steps:

    1. Normalise the image with ``du.normalize_complex_coil_image``.
    2. Convert it to a three-channel array with ``du.create_three_channel_image``.
    3. Add a leading batch axis and cast to ``float32`` on ``device``.
    4. Encode, then decode from ``latent_dist.mean``.
    5. Drop the batch axis from the decoder output.

    The reconstruction is left in the three-channel representation; it is not
    converted back to complex form.

    Args:
        coil_complex_image: Complex coil image in whatever form the ``data_utils``
            helpers accept. Its three-channel form must be a NumPy array, because
            ``torch.from_numpy`` is applied to it.
        autoencoder: Model exposing ``encode(x).latent_dist.mean`` and
            ``decode(z).sample``. ``autoencoder.to(device)`` is called on it; for a
            ``torch.nn.Module`` this moves the caller's model in place.
        device: Torch device to run inference on. Defaults to ``"cuda"``.

    Returns:
        Tuple ``(three_channel_image, recon)``:

        - ``three_channel_image`` is the three-channel array fed to the encoder,
          without the batch axis.
        - ``recon`` is the first (and only) batch item of the decoder output, as
          a NumPy array.
    """
    normalized_image = du.normalize_complex_coil_image(coil_complex_image)
    three_channel_image = du.create_three_channel_image(normalized_image)

    # Add a leading batch axis. Cast to float32 and move to the device in one
    # call instead of the legacy ``.type(torch.FloatTensor).to(device)`` chain.
    three_channel_tensor = torch.from_numpy(three_channel_image)[None, ...].to(
        device=device, dtype=torch.float32
    )

    autoencoder = autoencoder.to(device)
    with torch.no_grad():
        # Decode from the latent distribution's mean rather than drawing a sample.
        # The distribution is presumably a diagonal Gaussian, as in diffusers'
        # AutoencoderKL; confirm against the model in use.
        latents = autoencoder.encode(three_channel_tensor).latent_dist.mean
        decoded_image = autoencoder.decode(latents).sample

    # Strip the batch axis added above. No detach() is needed because the
    # output was produced under no_grad.
    recon = decoded_image[0].cpu().numpy()
    return three_channel_image, recon