import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from skimage.color import lab2rgb
import os
from torch import nn
# Define the model architecture (same as in the training script)
class UNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, bn=True, dropout=False):
        super(UNetBlock, self).__init__()
        # Stride-2 conv halves the spatial size on the way down;
        # stride-2 transposed conv doubles it on the way up.
        self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False) if down \
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels) if bn else None
        self.dropout = nn.Dropout(0.5) if dropout else None
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.dropout is not None:
            x = self.dropout(x)
        # Both branches of the original ternary applied ReLU, so one call suffices.
        return torch.relu(x)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Encoder: 1x256x256 L channel down to a 512x1x1 bottleneck.
        self.down1 = UNetBlock(1, 64, bn=False)
        self.down2 = UNetBlock(64, 128)
        self.down3 = UNetBlock(128, 256)
        self.down4 = UNetBlock(256, 512)
        self.down5 = UNetBlock(512, 512)
        self.down6 = UNetBlock(512, 512)
        self.down7 = UNetBlock(512, 512)
        self.down8 = UNetBlock(512, 512, bn=False)
        # Decoder: input channels double from up2 onward because each step
        # concatenates the matching encoder feature map (U-Net skip connection).
        self.up1 = UNetBlock(512, 512, down=False, dropout=True)
        self.up2 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up3 = UNetBlock(1024, 512, down=False, dropout=True)
        self.up4 = UNetBlock(1024, 512, down=False)
        self.up5 = UNetBlock(1024, 256, down=False)
        self.up6 = UNetBlock(512, 128, down=False)
        self.up7 = UNetBlock(256, 64, down=False)
        # Final layer predicts the two ab chrominance channels.
        self.up8 = nn.ConvTranspose2d(128, 2, 4, 2, 1)
    def forward(self, x):
        # Encoder-decoder pass with skip connections; tanh bounds the ab output to [-1, 1].
d1 = self.down1(x)
d2 = self.down2(d1)
d3 = self.down3(d2)
d4 = self.down4(d3)
d5 = self.down5(d4)
d6 = self.down6(d5)
d7 = self.down7(d6)
d8 = self.down8(d7)
u1 = self.up1(d8)
u2 = self.up2(torch.cat([u1, d7], 1))
u3 = self.up3(torch.cat([u2, d6], 1))
u4 = self.up4(torch.cat([u3, d5], 1))
u5 = self.up5(torch.cat([u4, d4], 1))
u6 = self.up6(torch.cat([u5, d3], 1))
u7 = self.up7(torch.cat([u6, d2], 1))
return torch.tanh(self.up8(torch.cat([u7, d1], 1)))
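
# Illustrative sanity check (not part of the original app): a 1x1x256x256
# L-channel input should yield a 1x2x256x256 ab prediction; the eight stride-2
# downsamplings take 256 -> 1 at the bottleneck.
def _check_generator_shapes():
    g = Generator().eval()
    with torch.no_grad():
        out = g(torch.zeros(1, 1, 256, 256))
    assert out.shape == (1, 2, 256, 256), out.shape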
# Load the generator weights from a training checkpoint. The file is expected
# to contain at least {'epoch': ..., 'generator_state_dict': ...}.
def load_checkpoint(filename, generator, map_location):
    if not os.path.isfile(filename):
        # Fail fast: serving an uninitialized generator would silently output noise.
        raise FileNotFoundError(f"No checkpoint found at '{filename}'")
    print(f"Loading checkpoint '{filename}'")
    checkpoint = torch.load(filename, map_location=map_location)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    print(f"Loaded checkpoint '{filename}' (epoch {checkpoint['epoch']})")
# Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator().to(device)
checkpoint_path = "checkpoints/latest_checkpoint.pth.tar"
load_checkpoint(checkpoint_path, generator, map_location=device)
generator.eval()
# Preprocessing: resize, collapse to one channel (used as a stand-in for the
# Lab L channel), and scale to [-1, 1] so the input matches the (L + 1.) * 50.
# denormalization in colorize_image below.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),                  # [0, 1]
    transforms.Normalize((0.5,), (0.5,))    # [0, 1] -> [-1, 1]
])
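
# Illustrative check of the scaling convention (inferred from the (L + 1.) * 50.
# denormalization below): a white input maps to 1.0, i.e. Lab L = 100.
def _check_l_scaling():
    white = transform(Image.new("L", (256, 256), 255))
    assert abs(float(white.max()) - 1.0) < 1e-6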
# Inference: predict ab channels from the L channel, then convert Lab -> RGB.
def colorize_image(input_image):
    try:
        original_size = input_image.size
        input_image = transform(input_image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = generator(input_image)
        output = output.squeeze(0).cpu().numpy()
        L = input_image.squeeze(0).cpu().numpy()
        # Undo the [-1, 1] normalization: L back to [0, 100], ab to [-128, 128].
        L = (L + 1.) * 50.
        ab = output * 128.
        # Stack L and ab into an HxWx3 Lab image and convert to RGB.
        Lab = np.concatenate([L, ab], axis=0).transpose(1, 2, 0)
        rgb_image = lab2rgb(Lab)
        rgb_image = Image.fromarray((rgb_image * 255).astype(np.uint8))
        # Restore the original resolution for display.
        rgb_image = rgb_image.resize(original_size, Image.LANCZOS)
        return rgb_image
    except Exception as e:
        print(f"Error in colorize_image: {str(e)}")
        return None
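
# Illustrative helper (not in the original script): colorize an image file on
# disk, handy for testing without launching the Gradio UI. The example paths
# below are hypothetical.
def colorize_file(in_path, out_path):
    result = colorize_image(Image.open(in_path))
    if result is not None:
        result.save(out_path)
# e.g. colorize_file("samples/photo_bw.jpg", "samples/photo_color.png")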
# Create the Gradio interface
iface = gr.Interface(
fn=colorize_image,
inputs=gr.Image(type="pil"),
outputs=gr.Image(type="pil"),
title="Image Colorizer",
    description="Upload a grayscale image to colorize it. Color uploads are converted to grayscale before colorization."
)
# Launch the app (share=True also creates a temporary public Gradio link)
if __name__ == "__main__":
iface.launch(share=True)