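# Provenance (Hugging Face Hub file-view residue, kept as a comment so this module
# is valid Python): uploaded by soumickmj with commit message "Upload ProbUNet"
# (commit e257499, verified); original file size 2.45 kB.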
import sys
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .ProbUNet_model import InjectionConvEncoder2D, InjectionUNet2D, InjectionConvEncoder3D, InjectionUNet3D, ProbabilisticSegmentationNet
from .PULASkiConfigs import ProbUNetConfig
class ProbUNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``ProbabilisticSegmentationNet``.

    Every hyper-parameter comes from a ``ProbUNetConfig``. The attributes read here are
    ``dim``, ``latent_distribution``, ``in_channels``, ``out_channels``,
    ``num_feature_maps``, ``latent_size``, ``depth``, ``no_outact_op`` and
    ``prob_injection_at``. The wrapped network is stored as ``self.model``.
    """

    config_class = ProbUNetConfig

    def __init__(self, config):
        """Build the probabilistic segmentation network described by ``config``.

        Args:
            config: A ``ProbUNetConfig`` instance.

        Raises:
            ValueError: If ``config.dim`` is not 2 or 3, or if
                ``config.latent_distribution`` is not ``"normal"``.
        """
        super().__init__(config)

        # (task_op, prior_op, posterior_op) for each supported spatial dimensionality.
        ops_by_dim = {
            2: (InjectionUNet2D, InjectionConvEncoder2D, InjectionConvEncoder2D),
            3: (InjectionUNet3D, InjectionConvEncoder3D, InjectionConvEncoder3D),
        }
        # Raise instead of calling sys.exit(): a model constructor must not terminate
        # the whole interpreter (e.g. a notebook kernel or training job) on a bad config.
        try:
            task_op, prior_op, posterior_op = ops_by_dim[config.dim]
        except KeyError:
            raise ValueError(
                f"Invalid dim {config.dim!r}! Only configured for dim 2 and 3."
            ) from None

        latent_distributions = {"normal": torch.distributions.Normal}
        try:
            latent_distribution = latent_distributions[config.latent_distribution]
        except KeyError:
            raise ValueError(
                f"Invalid latent_distribution {config.latent_distribution!r}. "
                "Only normal has been implemented."
            ) from None

        task_kwargs = {
            # no_outact_op=True leaves the task network's outputs unactivated;
            # otherwise they pass through a sigmoid.
            "output_activation_op": nn.Identity if config.no_outact_op else nn.Sigmoid,
            "activation_kwargs": {"inplace": True},
            "injection_at": config.prob_injection_at,
        }
        # Prior and posterior encoders get identical settings. Separate dict objects
        # are built for each so neither sub-network can mutate the other's kwargs.
        prior_kwargs = {"activation_kwargs": {"inplace": True}, "norm_depth": 2}
        posterior_kwargs = {"activation_kwargs": {"inplace": True}, "norm_depth": 2}

        self.model = ProbabilisticSegmentationNet(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            num_feature_maps=config.num_feature_maps,
            latent_size=config.latent_size,
            depth=config.depth,
            latent_distribution=latent_distribution,
            task_op=task_op,
            task_kwargs=task_kwargs,
            prior_op=prior_op,
            prior_kwargs=prior_kwargs,
            posterior_op=posterior_op,
            posterior_kwargs=posterior_kwargs,
        )

    def forward(self, x, *args, **kwargs):
        """Run the wrapped ``ProbabilisticSegmentationNet``.

        Args:
            x: Input passed as the first positional argument to ``self.model``.
            *args, **kwargs: Forwarded unchanged to ``self.model``, so callers can use
                any extra arguments its ``forward`` accepts. NOTE(review): presumably
                this includes e.g. a target segmentation for the posterior branch
                during training; confirm against
                ``ProbabilisticSegmentationNet.forward``.

        Returns:
            Whatever ``self.model`` returns.
        """
        return self.model(x, *args, **kwargs)