"""Inference helpers that pass complex coil images through an autoencoder,
using either a two-channel or a three-channel image representation."""
import torch
import data_utils as du

def run_inference_two_channels(coil_complex_image, autoencoder, device="cuda"):
    """Encode/decode a complex coil image through the autoencoder using a
    two-channel representation. Returns the normalized complex input and the
    reconstruction converted back to a complex image."""
    # Normalize and split the complex image into a two-channel array.
    coil_complex_image = du.normalize_complex_coil_image(coil_complex_image)
    two_channel_image = du.complex_to_two_channel_image(coil_complex_image)
    # Add a batch dimension and move to the target device as float32.
    two_channel_tensor = torch.from_numpy(two_channel_image)[None, ...].float().to(device)
    autoencoder = autoencoder.to(device)
    autoencoder.eval()  # disable training-only behavior for inference
    with torch.no_grad():
        # Use the mean of the latent distribution rather than sampling from it.
        autoencoder_output = autoencoder.encode(two_channel_tensor)
        latents = autoencoder_output.latent_dist.mean
        decoded_image = autoencoder.decode(latents).sample
    # Convert the decoded two-channel output back to a complex image.
    recon = du.two_channel_to_complex_image(decoded_image.detach().cpu().numpy())
    input = coil_complex_image
    return input, recon

def run_inference_three_channels(coil_complex_image, autoencoder, device="cuda"):
    """Encode/decode a complex coil image through the autoencoder using a
    three-channel representation. Returns the three-channel input array and
    the reconstructed three-channel array."""
    coil_complex_image = du.normalize_complex_coil_image(coil_complex_image)
    three_channel_image = du.create_three_channel_image(coil_complex_image)
    # Add a batch dimension and move to the target device as float32.
    three_channel_tensor = torch.from_numpy(three_channel_image)[None, ...].float().to(device)
    autoencoder = autoencoder.to(device)
    autoencoder.eval()  # disable training-only behavior for inference
    with torch.no_grad():
        # Use the mean of the latent distribution rather than sampling from it.
        autoencoder_output = autoencoder.encode(three_channel_tensor)
        latents = autoencoder_output.latent_dist.mean
        decoded_image = autoencoder.decode(latents).sample
    # Drop the batch dimension and move the reconstruction back to numpy.
    recon = decoded_image[0].detach().cpu().numpy()
    input = three_channel_image
    return input, recon
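

# Example usage: a minimal sketch, assuming the autoencoder is a diffusers-style
# AutoencoderKL (its encode()/decode() calls return objects exposing `latent_dist`
# and `sample`, matching the access pattern above). The tiny 2-channel VAE config,
# the random input image, and its size are illustrative placeholders, not part of
# this module or its data_utils helpers.
if __name__ == "__main__":
    import numpy as np
    from diffusers import AutoencoderKL  # assumed dependency providing the encode/decode API used above

    # Hypothetical complex-valued coil image; real data would come from the project's loaders.
    coil_complex_image = np.random.randn(64, 64) + 1j * np.random.randn(64, 64)

    # Small untrained VAE with 2 input/output channels, just to exercise the pipeline end to end.
    autoencoder = AutoencoderKL(in_channels=2, out_channels=2)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    normalized_input, recon = run_inference_two_channels(coil_complex_image, autoencoder, device=device)
    print("input:", normalized_input.shape, "reconstruction:", recon.shape)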