import torch import torch.nn as nn from .ProbUNet_utils import make_onehot as make_onehot_segmentation, make_slices, match_to def is_conv(op): conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d) if type(op) == type and issubclass(op, conv_types): return True elif type(op) in conv_types: return True else: return False class ConvModule(nn.Module): def __init__(self, *args, **kwargs): super(ConvModule, self).__init__() def init_weights(self, init_fn, *args, **kwargs): class init_(object): def __init__(self): self.fn = init_fn self.args = args self.kwargs = kwargs def __call__(self, module): if is_conv(type(module)): module.weight = self.fn(module.weight, *self.args, **self.kwargs) _init_ = init_() self.apply(_init_) def init_bias(self, init_fn, *args, **kwargs): class init_(object): def __init__(self): self.fn = init_fn self.args = args self.kwargs = kwargs def __call__(self, module): if is_conv(type(module)) and module.bias is not None: module.bias = self.fn(module.bias, *self.args, **self.kwargs) _init_ = init_() self.apply(_init_) class ConcatCoords(nn.Module): def forward(self, input_): dim = input_.dim() - 2 coord_channels = [] for i in range(dim): view = [1, ] * dim view[i] = -1 repeat = list(input_.shape[2:]) repeat[i] = 1 coord_channels.append( torch.linspace(-0.5, 0.5, input_.shape[i+2]) .view(*view) .repeat(*repeat) .to(device=input_.device, dtype=input_.dtype)) coord_channels = torch.stack(coord_channels).unsqueeze(0) repeat = [1, ] * input_.dim() repeat[0] = input_.shape[0] coord_channels = coord_channels.repeat(*repeat).contiguous() return torch.cat([input_, coord_channels], 1) class InjectionConvEncoder(ConvModule): _default_activation_kwargs = dict(inplace=True) _default_norm_kwargs = dict() _default_conv_kwargs = dict(kernel_size=3, padding=1) _default_pool_kwargs = dict(kernel_size=2) _default_dropout_kwargs = dict() _default_global_pool_kwargs = dict() def __init__(self, in_channels=1, out_channels=6, depth=4, injection_depth="last", injection_channels=0, block_depth=2, num_feature_maps=24, feature_map_multiplier=2, activation_op=nn.LeakyReLU, activation_kwargs=None, norm_op=nn.InstanceNorm2d, norm_kwargs=None, norm_depth=0, conv_op=nn.Conv2d, conv_kwargs=None, pool_op=nn.AvgPool2d, pool_kwargs=None, dropout_op=None, dropout_kwargs=None, global_pool_op=nn.AdaptiveAvgPool2d, global_pool_kwargs=None, **kwargs): super(InjectionConvEncoder, self).__init__(**kwargs) self.in_channels = in_channels self.out_channels = out_channels self.depth = depth self.injection_depth = depth - 1 if injection_depth == "last" else injection_depth self.injection_channels = injection_channels self.block_depth = block_depth self.num_feature_maps = num_feature_maps self.feature_map_multiplier = feature_map_multiplier self.activation_op = activation_op self.activation_kwargs = self._default_activation_kwargs if activation_kwargs is not None: self.activation_kwargs.update(activation_kwargs) self.norm_op = norm_op self.norm_kwargs = self._default_norm_kwargs if norm_kwargs is not None: self.norm_kwargs.update(norm_kwargs) self.norm_depth = depth if norm_depth == "full" else norm_depth self.conv_op = conv_op self.conv_kwargs = self._default_conv_kwargs if conv_kwargs is not None: self.conv_kwargs.update(conv_kwargs) self.pool_op = pool_op self.pool_kwargs = self._default_pool_kwargs if pool_kwargs is not None: self.pool_kwargs.update(pool_kwargs) self.dropout_op = dropout_op self.dropout_kwargs = self._default_dropout_kwargs if dropout_kwargs is not None: self.dropout_kwargs.update(dropout_kwargs) self.global_pool_op = global_pool_op self.global_pool_kwargs = self._default_global_pool_kwargs if global_pool_kwargs is not None: self.global_pool_kwargs.update(global_pool_kwargs) for d in range(self.depth): in_ = self.in_channels if d == 0 else self.num_feature_maps * (self.feature_map_multiplier**(d-1)) out_ = self.num_feature_maps * (self.feature_map_multiplier**d) if d == self.injection_depth + 1: in_ += self.injection_channels layers = [] if d > 0: layers.append(self.pool_op(**self.pool_kwargs)) for b in range(self.block_depth): current_in = in_ if b == 0 else out_ layers.append(self.conv_op(current_in, out_, **self.conv_kwargs)) if self.norm_op is not None and d < self.norm_depth: layers.append(self.norm_op(out_, **self.norm_kwargs)) if self.activation_op is not None: layers.append(self.activation_op(**self.activation_kwargs)) if self.dropout_op is not None: layers.append(self.dropout_op(**self.dropout_kwargs)) if d == self.depth - 1: current_conv_kwargs = self.conv_kwargs.copy() current_conv_kwargs["kernel_size"] = 1 current_conv_kwargs["padding"] = 0 current_conv_kwargs["bias"] = False layers.append(self.conv_op(out_, out_channels, **current_conv_kwargs)) self.add_module("encode_{}".format(d), nn.Sequential(*layers)) if self.global_pool_op is not None: self.add_module("global_pool", self.global_pool_op(1, **self.global_pool_kwargs)) def forward(self, x, injection=None): for d in range(self.depth): x = self._modules["encode_{}".format(d)](x) if d == self.injection_depth and self.injection_channels > 0: injection = match_to(injection, x, self.injection_channels) x = torch.cat([x, injection], 1) if hasattr(self, "global_pool"): x = self.global_pool(x) return x class InjectionConvEncoder3D(InjectionConvEncoder): def __init__(self, *args, **kwargs): update_kwargs = dict( norm_op=nn.InstanceNorm3d, conv_op=nn.Conv3d, pool_op=nn.AvgPool3d, global_pool_op=nn.AdaptiveAvgPool3d ) for (arg, val) in update_kwargs.items(): if arg not in kwargs: kwargs[arg] = val super(InjectionConvEncoder3D, self).__init__(*args, **kwargs) class InjectionConvEncoder2D(InjectionConvEncoder): #Created by Soumick def __init__(self, *args, **kwargs): update_kwargs = dict( norm_op=nn.InstanceNorm2d, conv_op=nn.Conv2d, pool_op=nn.AvgPool2d, global_pool_op=nn.AdaptiveAvgPool2d ) for (arg, val) in update_kwargs.items(): if arg not in kwargs: kwargs[arg] = val super(InjectionConvEncoder2D, self).__init__(*args, **kwargs) class InjectionUNet(ConvModule): def __init__( self, depth=5, in_channels=4, out_channels=4, kernel_size=3, dilation=1, num_feature_maps=24, block_depth=2, num_1x1_at_end=3, injection_channels=3, injection_at="end", activation_op=nn.LeakyReLU, activation_kwargs=None, pool_op=nn.AvgPool2d, pool_kwargs=dict(kernel_size=2), dropout_op=None, dropout_kwargs=None, norm_op=nn.InstanceNorm2d, norm_kwargs=None, conv_op=nn.Conv2d, conv_kwargs=None, upconv_op=nn.ConvTranspose2d, upconv_kwargs=None, output_activation_op=None, output_activation_kwargs=None, return_bottom=False, coords=False, coords_dim=2, **kwargs ): super(InjectionUNet, self).__init__(**kwargs) self.depth = depth self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.dilation = dilation self.padding = (self.kernel_size + (self.kernel_size-1) * (self.dilation-1)) // 2 self.num_feature_maps = num_feature_maps self.block_depth = block_depth self.num_1x1_at_end = num_1x1_at_end self.injection_channels = injection_channels self.injection_at = injection_at self.activation_op = activation_op self.activation_kwargs = {} if activation_kwargs is None else activation_kwargs self.pool_op = pool_op self.pool_kwargs = {} if pool_kwargs is None else pool_kwargs self.dropout_op = dropout_op self.dropout_kwargs = {} if dropout_kwargs is None else dropout_kwargs self.norm_op = norm_op self.norm_kwargs = {} if norm_kwargs is None else norm_kwargs self.conv_op = conv_op self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs self.upconv_op = upconv_op self.upconv_kwargs = {} if upconv_kwargs is None else upconv_kwargs self.output_activation_op = output_activation_op self.output_activation_kwargs = {} if output_activation_kwargs is None else output_activation_kwargs self.return_bottom = return_bottom if not coords: self.coords = [[], []] elif coords is True: self.coords = [list(range(depth)), []] else: self.coords = coords self.coords_dim = coords_dim self.last_activations = None # BUILD ENCODER for d in range(self.depth): block = [] if d > 0: block.append(self.pool_op(**self.pool_kwargs)) for i in range(self.block_depth): # bottom block fixed to have depth 1 if d == self.depth - 1 and i > 0: continue out_size = self.num_feature_maps * 2**d if d == 0 and i == 0: in_size = self.in_channels elif i == 0: in_size = self.num_feature_maps * 2**(d - 1) else: in_size = out_size # check for coord appending at this depth if d in self.coords[0] and i == 0: block.append(ConcatCoords()) in_size += self.coords_dim block.append(self.conv_op(in_size, out_size, self.kernel_size, padding=self.padding, dilation=self.dilation, **self.conv_kwargs)) if self.dropout_op is not None: block.append(self.dropout_op(**self.dropout_kwargs)) if self.norm_op is not None: block.append(self.norm_op(out_size, **self.norm_kwargs)) block.append(self.activation_op(**self.activation_kwargs)) self.add_module("encode-{}".format(d), nn.Sequential(*block)) # BUILD DECODER for d in reversed(range(self.depth)): block = [] for i in range(self.block_depth): # bottom block fixed to have depth 1 if d == self.depth - 1 and i > 0: continue out_size = self.num_feature_maps * 2**(d) if i == 0 and d < self.depth - 1: in_size = self.num_feature_maps * 2**(d+1) elif i == 0 and self.injection_at == "bottom": in_size = out_size + self.injection_channels else: in_size = out_size # check for coord appending at this depth if d in self.coords[0] and i == 0 and d < self.depth - 1: block.append(ConcatCoords()) in_size += self.coords_dim block.append(self.conv_op(in_size, out_size, self.kernel_size, padding=self.padding, dilation=self.dilation, **self.conv_kwargs)) if self.dropout_op is not None: block.append(self.dropout_op(**self.dropout_kwargs)) if self.norm_op is not None: block.append(self.norm_op(out_size, **self.norm_kwargs)) block.append(self.activation_op(**self.activation_kwargs)) if d > 0: block.append(self.upconv_op(out_size, out_size // 2, self.kernel_size, 2, padding=self.padding, dilation=self.dilation, output_padding=1, **self.upconv_kwargs)) self.add_module("decode-{}".format(d), nn.Sequential(*block)) if self.injection_at == "end": out_size += self.injection_channels in_size = out_size for i in range(self.num_1x1_at_end): if i == self.num_1x1_at_end - 1: out_size = self.out_channels current_conv_kwargs = self.conv_kwargs.copy() current_conv_kwargs["bias"] = True self.add_module("reduce-{}".format(i), self.conv_op(in_size, out_size, 1, **current_conv_kwargs)) if i != self.num_1x1_at_end - 1: self.add_module("reduce-{}-nonlin".format(i), self.activation_op(**self.activation_kwargs)) if self.output_activation_op is not None: self.add_module("output-activation", self.output_activation_op(**self.output_activation_kwargs)) def reset(self): self.last_activations = None def forward(self, x, injection=None, reuse_last_activations=False, store_activations=False): if self.injection_at == "bottom": # not worth it for now reuse_last_activations = False store_activations = False if self.last_activations is None or reuse_last_activations is False: enc = [x] for i in range(self.depth - 1): enc.append(self._modules["encode-{}".format(i)](enc[-1])) bottom_rep = self._modules["encode-{}".format(self.depth - 1)](enc[-1]) if self.injection_at == "bottom" and self.injection_channels > 0: injection = match_to(injection, bottom_rep, (0, 1)) bottom_rep = torch.cat((bottom_rep, injection), 1) x = self._modules["decode-{}".format(self.depth - 1)](bottom_rep) for i in reversed(range(self.depth - 1)): x = self._modules["decode-{}".format(i)](torch.cat((enc[-(self.depth - 1 - i)], x), 1)) if store_activations: self.last_activations = x.detach() else: x = self.last_activations if self.injection_at == "end" and self.injection_channels > 0: injection = match_to(injection, x, (0, 1)) x = torch.cat((x, injection), 1) for i in range(self.num_1x1_at_end): x = self._modules["reduce-{}".format(i)](x) if self.output_activation_op is not None: x = self._modules["output-activation"](x) if self.return_bottom and not reuse_last_activations: return x, bottom_rep else: return x class InjectionUNet3D(InjectionUNet): def __init__(self, *args, **kwargs): update_kwargs = dict( pool_op=nn.AvgPool3d, norm_op=nn.InstanceNorm3d, conv_op=nn.Conv3d, upconv_op=nn.ConvTranspose3d, coords_dim=3 ) for (arg, val) in update_kwargs.items(): if arg not in kwargs: kwargs[arg] = val super(InjectionUNet3D, self).__init__(*args, **kwargs) class InjectionUNet2D(InjectionUNet): #Created by Soumick def __init__(self, *args, **kwargs): update_kwargs = dict( pool_op=nn.AvgPool2d, norm_op=nn.InstanceNorm2d, conv_op=nn.Conv2d, upconv_op=nn.ConvTranspose2d, coords_dim=2 ) for (arg, val) in update_kwargs.items(): if arg not in kwargs: kwargs[arg] = val super(InjectionUNet2D, self).__init__(*args, **kwargs) class ProbabilisticSegmentationNet(ConvModule): def __init__(self, in_channels=4, out_channels=4, num_feature_maps=24, latent_size=3, depth=5, latent_distribution=torch.distributions.Normal, task_op=InjectionUNet3D, task_kwargs=None, prior_op=InjectionConvEncoder3D, prior_kwargs=None, posterior_op=InjectionConvEncoder3D, posterior_kwargs=None, **kwargs): super(ProbabilisticSegmentationNet, self).__init__(**kwargs) self.task_op = task_op self.task_kwargs = {} if task_kwargs is None else task_kwargs self.prior_op = prior_op self.prior_kwargs = {} if prior_kwargs is None else prior_kwargs self.posterior_op = posterior_op self.posterior_kwargs = {} if posterior_kwargs is None else posterior_kwargs default_task_kwargs = dict( in_channels=in_channels, out_channels=out_channels, num_feature_maps=num_feature_maps, injection_size=latent_size, depth=depth ) default_prior_kwargs = dict( in_channels=in_channels, out_channels=latent_size*2, #Soumick num_feature_maps=num_feature_maps, z_dim=latent_size, depth=depth ) default_posterior_kwargs = dict( in_channels=in_channels+out_channels, out_channels=latent_size*2, #Soumick num_feature_maps=num_feature_maps, z_dim=latent_size, depth=depth ) default_task_kwargs.update(self.task_kwargs) self.task_kwargs = default_task_kwargs default_prior_kwargs.update(self.prior_kwargs) self.prior_kwargs = default_prior_kwargs default_posterior_kwargs.update(self.posterior_kwargs) self.posterior_kwargs = default_posterior_kwargs self.latent_distribution = latent_distribution self._prior = None self._posterior = None self.make_modules() def make_modules(self): if type(self.task_op) == type: self.add_module("task_net", self.task_op(**self.task_kwargs)) else: self.add_module("task_net", self.task_op) if type(self.prior_op) == type: self.add_module("prior_net", self.prior_op(**self.prior_kwargs)) else: self.add_module("prior_net", self.prior_op) if type(self.posterior_op) == type: self.add_module("posterior_net", self.posterior_op(**self.posterior_kwargs)) else: self.add_module("posterior_net", self.posterior_op) @property def prior(self): return self._prior @property def posterior(self): return self._posterior @property def last_activations(self): return self.task_net.last_activations def train(self, mode=True): super(ProbabilisticSegmentationNet, self).train(mode) self.reset() def reset(self): self.task_net.reset() self._prior = None self._posterior = None def forward(self, input_, seg=None, make_onehot=True, make_onehot_classes=None, newaxis=False, distlossN=0): """Forward pass includes reparametrization sampling during training, otherwise it'll just take the prior mean.""" self.encode_prior(input_) if distlossN == 0: if self.training: self.encode_posterior(input_, seg, make_onehot, make_onehot_classes, newaxis) sample = self.posterior.rsample() else: sample = self.prior.loc return self.task_net(input_, sample, store_activations=not self.training) else: if self.training: self.encode_posterior(input_, seg, make_onehot, make_onehot_classes, newaxis) segs = [] for i in range(distlossN): sample = self.posterior.rsample() segs.append(self.task_net(input_, sample, store_activations=not self.training)) return segs #torch.concat(segs, dim=0) else: #I'm not totally sure about this!! sample = self.prior.loc return self.task_net(input_, sample, store_activations=not self.training) def encode_prior(self, input_): rep = self.prior_net(input_) if isinstance(rep, tuple): mean, logvar = rep elif torch.is_tensor(rep): mean, logvar = torch.split(rep, rep.shape[1] // 2, dim=1) self._prior = self.latent_distribution(mean, logvar.mul(0.5).exp()) return self._prior def encode_posterior(self, input_, seg, make_onehot=True, make_onehot_classes=None, newaxis=False): if make_onehot: if make_onehot_classes is None: make_onehot_classes = tuple(range(self.posterior_net.in_channels - input_.shape[1])) seg = make_onehot_segmentation(seg, make_onehot_classes, newaxis=newaxis) rep = self.posterior_net(torch.cat((input_, seg.float()), 1)) if isinstance(rep, tuple): mean, logvar = rep elif torch.is_tensor(rep): mean, logvar = torch.split(rep, rep.shape[1] // 2, dim=1) self._posterior = self.latent_distribution(mean, logvar.mul(0.5).exp()) return self._posterior def sample_prior(self, N=1, out_device=None, input_=None, pred_with_mean=False): """Draw multiple samples from the current prior. * input_ is required if no activations are stored in task_net. * If input_ is given, prior will automatically be encoded again. * Returns either a single sample or a list of samples. """ if out_device is None: if self.last_activations is not None: out_device = self.last_activations.device elif input_ is not None: out_device = input_.device else: out_device = next(self.task_net.parameters()).device with torch.no_grad(): if self.prior is None or input_ is not None: self.encode_prior(input_) result = [] if input_ is not None: result.append(self.task_net(input_, self.prior.sample(), reuse_last_activations=False, store_activations=True).to(device=out_device)) while len(result) < N: result.append(self.task_net(input_, self.prior.sample(), reuse_last_activations=self.last_activations is not None, store_activations=False).to(device=out_device)) if pred_with_mean: result.append(self.task_net(input_, self.prior.mean, reuse_last_activations=False, store_activations=True).to(device=out_device)) if len(result) == 1: return result[0] else: return result def reconstruct(self, sample=None, use_posterior_mean=True, out_device=None, input_=None): """Reconstruct a sample or the current posterior mean. Will not compute gradients!""" if self.posterior is None and sample is None: raise ValueError("'posterior' is currently None. Please pass an input and a segmentation first.") if out_device is None: out_device = next(self.task_net.parameters()).device if sample is None: if use_posterior_mean: sample = self.posterior.loc else: sample = self.posterior.sample() else: sample = sample.to(next(self.task_net.parameters()).device) with torch.no_grad(): return self.task_net(input_, sample, reuse_last_activations=True).to(device=out_device) def kl_divergence(self): """Compute current KL, requires existing prior and posterior.""" if self.posterior is None or self.prior is None: raise ValueError("'prior' and 'posterior' must not be None, but prior={} and posterior={}".format(self.prior, self.posterior)) return torch.distributions.kl_divergence(self.posterior, self.prior).sum() def elbo(self, seg, input_=None, nll_reduction="sum", beta=1.0, make_onehot=True, make_onehot_classes=None, newaxis=False): """Compute the ELBO with seg as ground truth. * Prior is expected and will not be encoded. * If input_ is given, posterior will automatically be encoded. * Either input_ or stored activations must be available. """ if self.last_activations is None: raise ValueError("'last_activations' is currently None. Please pass an input first.") if input_ is not None: with torch.no_grad(): self.encode_posterior(input_, seg, make_onehot=make_onehot, make_onehot_classes=make_onehot_classes, newaxis=newaxis) if make_onehot and newaxis: pass # seg will already be (B x SPACE) elif make_onehot and not newaxis: seg = seg[:, 0] # in this case seg will hopefully be (B x 1 x SPACE) else: seg = torch.argmax(seg, 1, keepdim=False) # seg is already onehot kl = self.kl_divergence() nll = nn.NLLLoss(reduction=nll_reduction)(self.reconstruct(sample=None, use_posterior_mean=True, out_device=None), seg.long()) return - (beta * nll + kl)