"""1-channel-in / 3-channel-out U-Net plus a pretrained inference instance.

The checkpoint name suggests a grayscale colorizer. Importing this module
builds ``UNETFruitColor``, wraps it in ``nn.DataParallel``, loads weights from
``CHECKPOINT_PATH`` and switches the model to eval mode.
"""
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# Preprocessing for colour input images: resize to 256x256, convert to a float
# tensor, apply Normalize with mean 0 / std 1 (an identity), then collapse to a
# single grayscale channel so the result matches the model's 1-channel input.
image_transforms_rgb = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
    torchvision.transforms.Grayscale(),
])

# Preprocessing for images that are already single-channel (the 1-element
# mean/std presumably assume grayscale input — confirm with callers).
# Normalize with mean 0 / std 1 is an identity here as well.
image_transforms_gs = torchvision.transforms.Compose([
    torchvision.transforms.Resize((256, 256)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.0], std=[1.0]),
])


class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages.

    Stride 1 with padding 1 keeps the spatial size unchanged.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels produced by both convolutions.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # The attribute name and layer order define the state_dict keys
        # ("main.0.weight", "main.1.running_mean", ...) the checkpoint expects.
        self.main = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(True),
        )

    def forward(self, x):
        return self.main(x)


class UNETFruitColor(nn.Module):
    """U-Net mapping a 1-channel image to a 3-channel image.

    Encoder: four ConvBlocks (1 -> 64 -> 128 -> 256 -> 512 channels), each
    followed by 2x2 max pooling. Bottleneck: ConvBlock 512 -> 1024.
    Decoder: per level (512, 256, 128, 64) a stride-2 transposed conv
    upsamples, the matching encoder output is resized to the same spatial size
    and concatenated in front, and a ConvBlock fuses the two.
    Head: 1x1 conv to 3 channels with no output activation.
    """

    def __init__(self):
        super().__init__()
        self.convs = [64, 128, 256, 512]

        self.convEncoder = nn.ModuleList()
        in_feature = 1
        for conv in self.convs:
            self.convEncoder.append(ConvBlock(in_feature, conv))
            in_feature = conv

        self.bottleNeck = ConvBlock(self.convs[-1], self.convs[-1] * 2)
        in_feature = self.convs[-1] * 2

        # At each decoder level the up-conv halves the channels
        # (in_feature -> conv). Concatenating the `conv`-channel skip restores
        # 2 * conv == in_feature, which is what the ConvBlock consumes.
        self.convDecoder = nn.ModuleList()
        self.decoderUpConvs = nn.ModuleList()
        for conv in reversed(self.convs):
            self.convDecoder.append(ConvBlock(in_feature, conv))
            self.decoderUpConvs.append(
                nn.ConvTranspose2d(in_feature, conv, kernel_size=2, stride=2, padding=0)
            )
            in_feature = conv

        # Final 1x1 projection to 3 output channels.
        self.finalUpConv = nn.Conv2d(in_feature, 3, (1, 1))
        # Kept for interface compatibility; forward() does not apply it.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (N, 1, H, W); the first conv expects 1 channel.

        Returns:
            Unbounded tensor of shape (N, 3, H', W'). H' and W' equal H and W
            when both are divisible by 16; otherwise the four 2x2 poolings
            round them down to a multiple of 16.
        """
        skip_conns = []
        for encoder_block in self.convEncoder:
            x = encoder_block(x)
            skip_conns.append(x)
            x = F.max_pool2d(x, (2, 2), stride=2)

        x = self.bottleNeck(x)

        # The deepest encoder output pairs with the first decoder level.
        for up_conv, decoder_block, skip in zip(
            self.decoderUpConvs, self.convDecoder, reversed(skip_conns)
        ):
            x = up_conv(x)
            # Align the skip tensor with the upsampled map. They differ when
            # flooring in max pooling dropped a row/column.
            skip = torchvision.transforms.Resize((x.shape[2], x.shape[3]))(skip)
            x = torch.cat([skip, x], dim=1)
            x = decoder_block(x)

        return self.finalUpConv(x)


# Pretrained weights loaded at import time.
CHECKPOINT_PATH = "unet_colorizer_flickr_5_93_Ploss_10_14K.pth"

model = UNETFruitColor()
# Wrap before loading. The checkpoint keys must carry the "module." prefix
# that nn.DataParallel adds, because strict=True requires an exact key match.
model = nn.DataParallel(model).to(device)
# NOTE(security): torch.load unpickles the file and can run arbitrary code
# (unless weights_only=True, the default from PyTorch 2.6). Only load trusted
# checkpoints.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device), strict=True)
model.eval()