import sys

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .ProbUNet_model import (
    InjectionConvEncoder2D,
    InjectionConvEncoder3D,
    InjectionUNet2D,
    InjectionUNet3D,
    ProbabilisticSegmentationNet,
)
from .PULASkiConfigs import ProbUNetConfig

# Maps ``config.dim`` to the (task_op, prior_op, posterior_op) classes handed to
# ProbabilisticSegmentationNet. Prior and posterior use the same encoder class.
_OPS_BY_DIM = {
    2: (InjectionUNet2D, InjectionConvEncoder2D, InjectionConvEncoder2D),
    3: (InjectionUNet3D, InjectionConvEncoder3D, InjectionConvEncoder3D),
}

# Maps ``config.latent_distribution`` names to torch distribution classes.
_LATENT_DISTRIBUTIONS = {
    "normal": torch.distributions.Normal,
}


class ProbUNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``ProbabilisticSegmentationNet``.

    All hyper-parameters come from a ``ProbUNetConfig``. The wrapped network is
    stored as ``self.model`` and ``forward`` delegates to it.
    """

    config_class = ProbUNetConfig

    def __init__(self, config):
        """Build the wrapped network from ``config``.

        Args:
            config: ``ProbUNetConfig`` providing ``dim``, ``latent_distribution``,
                ``in_channels``, ``out_channels``, ``num_feature_maps``,
                ``latent_size``, ``depth``, ``no_outact_op`` and
                ``prob_injection_at``.

        Raises:
            ValueError: if ``config.dim`` is not 2 or 3, or if
                ``config.latent_distribution`` is not ``"normal"``. (Previously
                the process was terminated via ``sys.exit``; raising lets
                callers handle a bad config instead of killing the interpreter.)
        """
        super().__init__(config)

        if config.dim not in _OPS_BY_DIM:
            raise ValueError(
                f"Invalid dim {config.dim!r}! Only configured for dim 2 and 3."
            )
        task_op, prior_op, posterior_op = _OPS_BY_DIM[config.dim]

        if config.latent_distribution not in _LATENT_DISTRIBUTIONS:
            raise ValueError(
                f"Invalid latent_distribution {config.latent_distribution!r}. "
                "Only normal has been implemented."
            )
        latent_distribution = _LATENT_DISTRIBUTIONS[config.latent_distribution]

        # A truthy ``no_outact_op`` disables the output activation (identity);
        # otherwise the task network ends with a sigmoid.
        output_activation_op = nn.Identity if config.no_outact_op else nn.Sigmoid

        # Each sub-network gets its own fresh kwargs dicts (not shared objects),
        # matching the original construction. "inplace": True is presumably
        # forwarded to the activation op to save memory — TODO confirm in
        # ProbUNet_model.
        self.model = ProbabilisticSegmentationNet(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            num_feature_maps=config.num_feature_maps,
            latent_size=config.latent_size,
            depth=config.depth,
            latent_distribution=latent_distribution,
            task_op=task_op,
            task_kwargs={
                "output_activation_op": output_activation_op,
                "activation_kwargs": {"inplace": True},
                # NOTE(review): semantics of injection_at are defined in
                # ProbUNet_model (where the latent sample is injected) — confirm.
                "injection_at": config.prob_injection_at,
            },
            prior_op=prior_op,
            # NOTE(review): meaning of norm_depth=2 is defined by the encoder
            # class; value kept from the original.
            prior_kwargs={"activation_kwargs": {"inplace": True}, "norm_depth": 2},
            posterior_op=posterior_op,
            posterior_kwargs={"activation_kwargs": {"inplace": True}, "norm_depth": 2},
        )

    def forward(self, x):
        """Run the wrapped ``ProbabilisticSegmentationNet`` on ``x`` and return its output."""
        return self.model(x)