My-AI-Projects committed on
Commit
c42a584
1 Parent(s): e949949
Files changed (1)
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
+ import argparse
+ import cv2
+ import numpy as np
+ import os
+ from tqdm import tqdm
+ import torch
+ from basicsr.archs.ddcolor_arch import DDColor
+ import torch.nn.functional as F
+ import gradio as gr
+ from gradio_imageslider import ImageSlider
+ import uuid
+ from PIL import Image  # used by colorize_image below
+
+ model_path = r"C:\Users\abohamam\Desktop\pytorch_model.pt"
+ input_size = 512
+ model_size = 'large'
+
+
+ # Create Image Colorization Pipeline
+ class ImageColorizationPipeline(object):
+
+     def __init__(self, model_path, input_size=256, model_size='large'):
+
+         self.input_size = input_size
+         if torch.cuda.is_available():
+             self.device = torch.device('cuda')
+         else:
+             self.device = torch.device('cpu')
+
+         if model_size == 'tiny':
+             self.encoder_name = 'convnext-t'
+         else:
+             self.encoder_name = 'convnext-l'
+
+         self.decoder_type = "MultiScaleColorDecoder"
+
+         if self.decoder_type == 'MultiScaleColorDecoder':
+             self.model = DDColor(
+                 encoder_name=self.encoder_name,
+                 decoder_name='MultiScaleColorDecoder',
+                 input_size=[self.input_size, self.input_size],
+                 num_output_channels=2,
+                 last_norm='Spectral',
+                 do_normalize=False,
+                 num_queries=100,
+                 num_scales=3,
+                 dec_layers=9,
+             ).to(self.device)
+         else:
+             self.model = DDColor(
+                 encoder_name=self.encoder_name,
+                 decoder_name='SingleColorDecoder',
+                 input_size=[self.input_size, self.input_size],
+                 num_output_channels=2,
+                 last_norm='Spectral',
+                 do_normalize=False,
+                 num_queries=256,
+             ).to(self.device)
+
+         self.model.load_state_dict(
+             torch.load(model_path, map_location=torch.device('cpu'))['params'],
+             strict=False)
+         self.model.eval()
+
+     @torch.no_grad()
+     def process(self, img):
+         self.height, self.width = img.shape[:2]
+         img = (img / 255.0).astype(np.float32)
+         orig_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]  # (h, w, 1)
+
+         # resize rgb image -> lab -> get grey -> rgb
+         img = cv2.resize(img, (self.input_size, self.input_size))
+         img_l = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)[:, :, :1]
+         img_gray_lab = np.concatenate((img_l, np.zeros_like(img_l), np.zeros_like(img_l)), axis=-1)
+         img_gray_rgb = cv2.cvtColor(img_gray_lab, cv2.COLOR_LAB2RGB)
+
+         tensor_gray_rgb = torch.from_numpy(img_gray_rgb.transpose((2, 0, 1))).float().unsqueeze(0).to(self.device)
+         output_ab = self.model(tensor_gray_rgb).cpu()  # (1, 2, self.input_size, self.input_size)
+
+         # resize ab -> concat original l -> rgb
+         output_ab_resize = F.interpolate(output_ab, size=(self.height, self.width))[0].float().numpy().transpose(1, 2, 0)
+         output_lab = np.concatenate((orig_l, output_ab_resize), axis=-1)
+         output_bgr = cv2.cvtColor(output_lab, cv2.COLOR_LAB2BGR)
+
+         output_img = (output_bgr * 255.0).round().astype(np.uint8)
+
+         return output_img
+
+
+ # Instantiate the colorization pipeline once so the functions below can reuse it
+ colorizer = ImageColorizationPipeline(model_path=model_path, input_size=input_size, model_size=model_size)
+
+
+ def colorize_image(image):
+     """Colorizes a grayscale image using the DDColor model."""
+
+     # Convert to a NumPy array and drop any existing colour so the model recolorizes from scratch
+     img_array = np.array(image)
+     if len(img_array.shape) == 3 and img_array.shape[2] == 3:
+         img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
+     # process() expects a 3-channel BGR image, so replicate the grey channel
+     img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
+
+     # Colorize the image
+     colorized_img = colorizer.process(img_array)
+
+     # Convert colorized image (BGR) to PIL format (RGB)
+     colorized_img = Image.fromarray(cv2.cvtColor(colorized_img, cv2.COLOR_BGR2RGB))
+
+     return colorized_img
+
+ # Create inference function for gradio app
+ def colorize(img):
+     image_out = colorizer.process(img)
+     # Generate a unique filename using UUID
+     unique_imgfilename = str(uuid.uuid4()) + '.png'
+     # cv2.imwrite expects a BGR array, so the result is written to disk and
+     # returned as a file path for the slider to display
+     cv2.imwrite(unique_imgfilename, image_out)
+     return (img, unique_imgfilename)
+
+
+
+ # Gradio demo using the Image-Slider custom component
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             bw_image = gr.Image(label='Black and White Input Image')
+             btn = gr.Button('Convert using DDColor')
+         with gr.Column():
+             col_image_slider = ImageSlider(position=0.5,
+                                            label='Colored Image with Slider-view')
+
+     btn.click(colorize, bw_image, col_image_slider)
+ demo.launch()
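A minimal usage sketch of the pipeline outside the Gradio demo, for quick local testing. It assumes the ImageColorizationPipeline class above has been copied into a hypothetical module named ddcolor_pipeline, and the checkpoint and image paths below are placeholders to replace with real files:

import cv2
from ddcolor_pipeline import ImageColorizationPipeline  # hypothetical module holding the class above

# Placeholder paths; point these at a real DDColor checkpoint and input image
colorizer = ImageColorizationPipeline(
    model_path="pytorch_model.pt",
    input_size=512,
    model_size="large",
)

bw_img = cv2.imread("black_and_white.jpg")  # cv2.imread returns a 3-channel BGR array
colorized = colorizer.process(bw_img)       # BGR uint8 output at the input's original resolution
cv2.imwrite("colorized.png", colorized)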