from typing import List
import numpy as np
import torch
import torch.nn.functional as F
from fastai.vision.all import *
# from backend.StampRemoval.util import *
__all__ = ["CustomUnetBlock", "CustomDynamicUnet", "UnetInference"]
class CustomUnetBlock(Module):
"""A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."""
@delegates(ConvLayer.__init__)
def __init__(
self,
up_in_c,
x_in_c,
hook,
final_div=True,
blur=False,
act_cls=defaults.activation,
self_attention=False,
init=nn.init.kaiming_normal_,
norm_type=None,
**kwargs,
):
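        # `hook` captures the matching encoder activation, used as the skip connection in `forward`.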
self.hook = hook
self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c // 2, blur=blur, act_cls=act_cls, norm_type=norm_type)
self.bn = BatchNorm(x_in_c)
ni = up_in_c // 2 + x_in_c
        # Use half the channel count of fastai's stock UnetBlock (which uses `ni if final_div else ni // 2`).
        nf = ni // 2 if final_div else ni // 4
self.conv1 = ConvLayer(ni, nf, act_cls=act_cls, norm_type=norm_type, **kwargs)
self.conv2 = ConvLayer(
nf,
nf,
act_cls=act_cls,
norm_type=norm_type,
xtra=SelfAttention(nf) if self_attention else None,
**kwargs,
)
self.relu = act_cls()
apply_init(nn.Sequential(self.conv1, self.conv2), init)
def forward(self, up_in):
s = self.hook.stored
up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            # Match the spatial size of the skip connection before concatenating.
            up_out = F.interpolate(up_out, ssh, mode="nearest")
cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
return self.conv2(self.conv1(cat_x))
class CustomDynamicUnet(SequentialEx):
"""Create a U-Net from a given architecture."""
def __init__(
self,
encoder,
n_out,
img_size,
blur=False,
blur_final=True,
self_attention=False,
y_range=None,
last_cross=True,
bottle=False,
act_cls=defaults.activation,
init=nn.init.kaiming_normal_,
norm_type=None,
**kwargs,
):
imsize = img_size
sizes = model_sizes(encoder, size=imsize)
sz_chg_idxs = list(reversed(self._get_sz_change_idxs(sizes)))
self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
x = dummy_eval(encoder, imsize).detach()
ni = sizes[-1][1]
middle_conv = nn.Sequential(
ConvLayer(ni, ni, act_cls=act_cls, norm_type=norm_type, **kwargs),
ConvLayer(ni, ni, act_cls=act_cls, norm_type=norm_type, **kwargs),
).eval()
x = middle_conv(x)
layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]
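        # Build one decoder block per point where the encoder's feature-map size changes,
        # working from the deepest layer back up.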
for i, idx in enumerate(sz_chg_idxs):
not_final = i != len(sz_chg_idxs) - 1
up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
do_blur = blur and (not_final or blur_final)
sa = self_attention and (i == len(sz_chg_idxs) - 3)
unet_block = CustomUnetBlock(
up_in_c,
x_in_c,
self.sfs[i],
final_div=not_final,
blur=do_blur,
self_attention=sa,
act_cls=act_cls,
init=init,
norm_type=norm_type,
**kwargs,
).eval()
layers.append(unet_block)
x = unet_block(x)
ni = x.shape[1]
if imsize != sizes[0][-2:]:
layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))
layers.append(ResizeToOrig())
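        # Optionally concatenate the original input back in (a final cross connection)
        # and refine the result with a ResBlock.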
if last_cross:
layers.append(MergeLayer(dense=True))
ni += in_channels(encoder)
layers.append(
ResBlock(
1,
ni,
ni // 2 if bottle else ni,
act_cls=act_cls,
norm_type=norm_type,
**kwargs,
)
)
layers += [ConvLayer(ni, n_out, ks=1, act_cls=None, norm_type=norm_type, **kwargs)]
apply_init(nn.Sequential(layers[3], layers[-2]), init)
# apply_init(nn.Sequential(layers[2]), init)
if y_range is not None:
layers.append(SigmoidRange(*y_range))
super().__init__(*layers)
def _get_sz_change_idxs(self, sizes):
"Get the indexes of the layers where the size of the activation changes."
feature_szs = [size[-1] for size in sizes]
sz_chg_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
return sz_chg_idxs
def __del__(self):
if hasattr(self, "sfs"):
self.sfs.remove()
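# Empty placeholder: presumably kept so that a learner pickled with a
# `PerceptualLoss` reference can still be unpickled by `load_learner`.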
class PerceptualLoss:
pass
class UnetInference:
def __init__(self, model_path):
"""Inference interface for unet model"""
self.learn = load_learner(model_path)
self.learn.model.eval()
    def __call__(self, image_array: List[np.ndarray], bs: int = 1) -> List[np.ndarray]:
        """Run a forward pass and decode the predictions of the UNet model.

        Args:
            image_array (list): list of numpy arrays, one per input image.
            bs (int, optional): batch size. Defaults to 1.

        Returns:
            list: list of numpy arrays (H x W x C, uint8), one per input image.
        """
if len(image_array) < 1:
return []
batches = self.__build_batches(image_array, bs=bs)
outs = []
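        # Run each batch through the model without tracking gradients.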
with torch.no_grad():
for b in batches:
outs.append(self.learn.model(b))
del b
pil_images = self.__decode_prediction(outs)
return pil_images
    def __decode_prediction(self, preds):
        out = []
        i2f = IntToFloatTensor()
        for batch in preds:
            # Iterate over the images in each batch so that bs > 1 is handled correctly.
            for pred in batch:
                img_np = i2f.decodes(pred.cpu()).numpy()  # scale back to the 0-255 range
                img_np = img_np.transpose(1, 2, 0)  # CHW -> HWC
                img_np = img_np.astype(np.uint8)
                out.append(img_np)
                # out.append(Image.fromarray(img_np))
                del img_np
        return out
def __build_batches(self, image_array: list, bs=1):
"Builds batches to skip `DataLoader` overhead"
type_tfms = [PILImage.create]
item_tfms = [ToTensor()]
type_pipe = Pipeline(type_tfms)
item_pipe = Pipeline(item_tfms)
i2f = IntToFloatTensor()
batches = []
batch = []
k = 0
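        # Accumulate transformed images and emit a batch every `bs` items (and once more for any remainder).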
for i, im in enumerate(image_array):
batch.append(item_pipe(type_pipe(im)))
k += 1
if i == len(image_array) - 1 or k == bs:
# batches.append(torch.cat([norm(i2f(b.cuda())) for b in batch]))
                batches.append(torch.stack([i2f(b.cpu()) for b in batch], dim=0))
batch = []
k = 0
return batches
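
# Minimal usage sketch (illustrative only, not part of the original pipeline):
# the exported-learner path "unet_export.pkl" and the image path "sample.png"
# are hypothetical placeholders.
if __name__ == "__main__":
    from PIL import Image

    inference = UnetInference("unet_export.pkl")  # hypothetical export created with learn.export()
    image = np.array(Image.open("sample.png").convert("RGB"))  # hypothetical RGB input image
    outputs = inference([image], bs=1)
    print([o.shape for o in outputs])  # each output is an H x W x C uint8 array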