zaidmehdi committed on
Commit
01d43bf
1 Parent(s): d62571a

Upload folder using huggingface_hub

Browse files
src/__pycache__/model.cpython-311.pyc ADDED
Binary file (2.14 kB). View file
 
src/__pycache__/model.cpython-312.pyc ADDED
Binary file (1.93 kB). View file
 
src/__pycache__/utils.cpython-311.pyc ADDED
Binary file (3.84 kB). View file
 
src/__pycache__/utils.cpython-312.pyc ADDED
Binary file (3.45 kB). View file
 
src/get_data.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import shutil
4
+
5
+
6
def extract_images(source_folder, destination_folder):
    """Copy every JPEG/PNG image under *source_folder* into *destination_folder*.

    The source tree is walked recursively and each image is copied under a
    zero-padded sequential name (``00000.jpg``, ``00001.png``, ...) that keeps
    the file's original extension.

    Args:
        source_folder: Root directory searched recursively for images.
        destination_folder: Directory receiving the copies; created if missing.

    Returns:
        The number of images copied.
    """
    # Every suffix carries its leading dot so that an extension-less name
    # such as "fakejpg" is not mistaken for an image.
    image_extensions = ('.jpg', '.jpeg', '.png')

    os.makedirs(destination_folder, exist_ok=True)
    destination_abs = os.path.abspath(destination_folder)

    count = 0
    for root, dirs, files in os.walk(source_folder):
        # Prune the destination (in case it lives inside the source tree, so
        # copies are not re-copied) and sort for a reproducible numbering.
        dirs[:] = sorted(
            d for d in dirs
            if os.path.abspath(os.path.join(root, d)) != destination_abs
        )
        for file in sorted(files):
            # Compare case-insensitively so "PAGE.JPG" is picked up as well.
            if not file.lower().endswith(image_extensions):
                continue
            src_path = os.path.join(root, file)
            dst_name = f"{count:05d}" + os.path.splitext(file)[1]
            shutil.copy(src_path, os.path.join(destination_folder, dst_name))
            count += 1

    return count
15
+
16
+
17
def split_data(data_folder, train_ratio=0.7, validation_ratio=0.2, seed=None):
    """Shuffle the files in *data_folder* into train/validation/test sub-folders.

    Only regular files directly inside *data_folder* are considered. They are
    shuffled and then moved into ``train``, ``validation`` and ``test``
    sub-directories (created when missing). The test split receives whatever
    is left after the first two.

    Args:
        data_folder: Directory holding the files to split.
        train_ratio: Fraction of files moved to ``train`` (default 0.7).
        validation_ratio: Fraction of files moved to ``validation``
            (default 0.2).
        seed: Optional seed making the shuffle reproducible. When ``None``
            the module-level ``random`` state is used, as before.

    Raises:
        ValueError: If a ratio is outside [0, 1] or the two ratios sum to
            more than 1.
    """
    if not (0 <= train_ratio <= 1 and 0 <= validation_ratio <= 1):
        raise ValueError("train_ratio and validation_ratio must be within [0, 1]")
    # Small tolerance so that pairs like 0.7 + 0.3 survive float rounding.
    if train_ratio + validation_ratio > 1 + 1e-9:
        raise ValueError("train_ratio + validation_ratio must not exceed 1")

    train_folder = os.path.join(data_folder, "train")
    validation_folder = os.path.join(data_folder, "validation")
    test_folder = os.path.join(data_folder, "test")
    for folder in (train_folder, validation_folder, test_folder):
        os.makedirs(folder, exist_ok=True)

    # os.listdir order is arbitrary; sort first so a given seed always
    # produces the same split.
    image_files = sorted(
        f for f in os.listdir(data_folder)
        if os.path.isfile(os.path.join(data_folder, f))
    )
    if seed is None:
        random.shuffle(image_files)
    else:
        random.Random(seed).shuffle(image_files)

    total_images = len(image_files)
    train_end = int(train_ratio * total_images)
    validation_end = train_end + int(validation_ratio * total_images)

    assignments = (
        (train_folder, image_files[:train_end]),
        (validation_folder, image_files[train_end:validation_end]),
        (test_folder, image_files[validation_end:]),
    )
    for folder, names in assignments:
        for name in names:
            shutil.move(os.path.join(data_folder, name), folder)
41
+
42
+
43
if __name__ == "__main__":
    # Step 1: flatten the raw manga tree into one folder of numbered images.
    # Step 2: distribute those images over train/validation/test sub-folders.
    raw_dir, dataset_dir = "manga/", "data/"
    extract_images(raw_dir, dataset_dir)
    split_data(dataset_dir)
src/main.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from PIL import Image
6
+
7
+ from model import MangaColorizer
8
+ from utils import pil_to_torch, torch_to_pil
9
+
10
def load_html_template():
    """Read ``templates/index.html`` located next to this module.

    Returns:
        The template text, or ``None`` (after printing an error message)
        when the file does not exist.
    """
    index_html_path = os.path.join(
        os.path.dirname(__file__), "templates", "index.html"
    )
    try:
        # Explicit UTF-8 so the page is decoded the same way regardless of
        # the platform's default locale encoding.
        with open(index_html_path, "r", encoding="utf-8") as html_file:
            return html_file.read()
    except FileNotFoundError:
        print(f"Error: {index_html_path} not found.")
        return None
20
+
21
+
22
def load_model():
    """Build a ``MangaColorizer`` and load trained weights from disk if present.

    The checkpoint is looked up at ``../model/best_model_checkpoint.pth``
    relative to this module and is expected to be a state dict, since it is
    passed straight to ``load_state_dict``. When the file is missing an error
    is printed and the model keeps its freshly initialised weights.

    Returns:
        The model, switched to evaluation mode.
    """
    model = MangaColorizer()
    model_file = os.path.join(
        os.path.dirname(__file__), '..', 'model', 'best_model_checkpoint.pth'
    )
    if os.path.exists(model_file):
        with open(model_file, "rb") as f:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot execute code.
            # NOTE(review): needs a torch release that supports the
            # `weights_only` argument - confirm the deployed version.
            checkpoint = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(checkpoint)
    else:
        print(f"Error: {model_file} not found.")

    # The model is only used for inference by this app.
    model.eval()
    return model
34
model = load_model()


def colorize_image(image):
    """Colorize an image with the module-level ``model``.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``. It is converted to
            single-channel grayscale before being fed to the model.
            ``None`` is tolerated (presumably what the UI passes when nothing
            was uploaded - TODO confirm) and yields ``None``.

    Returns:
        The image produced by ``torch_to_pil`` from the model output, or
        ``None`` when *image* is ``None``.
    """
    if image is None:
        return None

    img = Image.fromarray(image).convert("L")
    # Pure inference: skip autograd bookkeeping instead of building a graph
    # and detaching from it afterwards.
    with torch.no_grad():
        output = model(pil_to_torch(img))

    return torch_to_pil(output.detach().cpu())
43
+
44
+
45
def main():
    """Assemble the Gradio page (header, colorizer interface, footer) and serve it.

    The app listens on all interfaces (0.0.0.0) on port 8080.
    """
    header_html = load_html_template()
    footer_html = """
            <p style="text-align: center;font-size: large;">
                Checkout the <a href="https://github.com/zaidmehdi/manga-colorizer">Github Repo</a>
            </p>
        """

    with gr.Blocks() as demo:
        gr.HTML(header_html)
        gr.Interface(
            colorize_image,
            inputs=["image"],
            outputs=["image"],
            allow_flagging="never",
        )
        gr.HTML(footer_html)

    demo.launch(server_name="0.0.0.0", server_port=8080)


if __name__ == "__main__":
    main()
src/model.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
class MangaColorizer(nn.Module):
    """Convolutional encoder-decoder mapping a 1-channel image to 3 channels.

    The encoder halves the spatial resolution twice (two stride-2
    convolutions); the decoder doubles it twice with transposed convolutions
    and ends in ``Tanh``, so output values lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # NOTE: the attribute names and layer order define the state-dict
        # keys ("encoder.0.weight", ...); keep them stable so existing
        # checkpoints still load.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Return the 3-channel prediction for *x*, at the same height/width.

        Args:
            x: Tensor whose last two dimensions are height and width and
                whose channel dimension has size 1.

        Returns:
            Tensor with 3 channels and the same trailing spatial size as *x*.
        """
        target_size = x.shape[-2:]
        output = self.decoder(self.encoder(x))

        # When height or width is not a multiple of 4 the two stride-2
        # stages round up and the decoder overshoots; resize back so the
        # output always lines up with the input.
        if output.shape[-2:] != target_size:
            unbatched = output.dim() == 3
            if unbatched:
                # Bilinear interpolation needs a batch dimension.
                output = output.unsqueeze(0)
            output = nn.functional.interpolate(
                output, size=target_size, mode="bilinear", align_corners=False
            )
            if unbatched:
                output = output.squeeze(0)
        return output
src/model_training.ipynb ADDED
@@ -0,0 +1,838 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "06eef311",
7
+ "metadata": {
8
+ "execution": {
9
+ "iopub.execute_input": "2024-04-01T01:04:34.377065Z",
10
+ "iopub.status.busy": "2024-04-01T01:04:34.376156Z",
11
+ "iopub.status.idle": "2024-04-01T01:04:41.159287Z",
12
+ "shell.execute_reply": "2024-04-01T01:04:41.158448Z"
13
+ },
14
+ "papermill": {
15
+ "duration": 6.79219,
16
+ "end_time": "2024-04-01T01:04:41.161646",
17
+ "exception": false,
18
+ "start_time": "2024-04-01T01:04:34.369456",
19
+ "status": "completed"
20
+ },
21
+ "tags": []
22
+ },
23
+ "outputs": [],
24
+ "source": [
25
+ "import matplotlib.pyplot as plt\n",
26
+ "import numpy as np\n",
27
+ "import torch\n",
28
+ "import torch.nn as nn\n",
29
+ "import torch.optim as optim\n",
30
+ "from torchvision import transforms\n",
31
+ "from torch.utils.data import DataLoader\n",
32
+ "from tqdm.auto import tqdm\n",
33
+ "\n",
34
+ "from model import MangaColorizer\n",
35
+ "from utils import ImageDataset, adjust_output_shape"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "id": "5e7ff784",
41
+ "metadata": {
42
+ "papermill": {
43
+ "duration": 0.004403,
44
+ "end_time": "2024-04-01T01:04:41.171084",
45
+ "exception": false,
46
+ "start_time": "2024-04-01T01:04:41.166681",
47
+ "status": "completed"
48
+ },
49
+ "tags": []
50
+ },
51
+ "source": [
52
+ "## Model architecture"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": 3,
58
+ "id": "87d03ce6",
59
+ "metadata": {
60
+ "execution": {
61
+ "iopub.execute_input": "2024-04-01T01:04:41.182184Z",
62
+ "iopub.status.busy": "2024-04-01T01:04:41.181258Z",
63
+ "iopub.status.idle": "2024-04-01T01:04:41.190651Z",
64
+ "shell.execute_reply": "2024-04-01T01:04:41.189724Z"
65
+ },
66
+ "papermill": {
67
+ "duration": 0.017191,
68
+ "end_time": "2024-04-01T01:04:41.192743",
69
+ "exception": false,
70
+ "start_time": "2024-04-01T01:04:41.175552",
71
+ "status": "completed"
72
+ },
73
+ "tags": []
74
+ },
75
+ "outputs": [
76
+ {
77
+ "data": {
78
+ "text/plain": [
79
+ "MangaColorizer(\n",
80
+ " (encoder): Sequential(\n",
81
+ " (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
82
+ " (1): ReLU(inplace=True)\n",
83
+ " (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
84
+ " (3): ReLU(inplace=True)\n",
85
+ " (4): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
86
+ " (5): ReLU(inplace=True)\n",
87
+ " )\n",
88
+ " (decoder): Sequential(\n",
89
+ " (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
90
+ " (1): ReLU(inplace=True)\n",
91
+ " (2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
92
+ " (3): ReLU(inplace=True)\n",
93
+ " (4): ConvTranspose2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
94
+ " (5): Tanh()\n",
95
+ " )\n",
96
+ ")"
97
+ ]
98
+ },
99
+ "execution_count": 3,
100
+ "metadata": {},
101
+ "output_type": "execute_result"
102
+ }
103
+ ],
104
+ "source": [
105
+ "model = MangaColorizer()\n",
106
+ "model"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "markdown",
111
+ "id": "c4b5ff4a",
112
+ "metadata": {
113
+ "papermill": {
114
+ "duration": 0.004206,
115
+ "end_time": "2024-04-01T01:04:41.201565",
116
+ "exception": false,
117
+ "start_time": "2024-04-01T01:04:41.197359",
118
+ "status": "completed"
119
+ },
120
+ "tags": []
121
+ },
122
+ "source": [
123
+ "## Loading the Data"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 5,
129
+ "id": "42198e39",
130
+ "metadata": {
131
+ "execution": {
132
+ "iopub.execute_input": "2024-04-01T01:04:41.247525Z",
133
+ "iopub.status.busy": "2024-04-01T01:04:41.247244Z",
134
+ "iopub.status.idle": "2024-04-01T01:04:41.627306Z",
135
+ "shell.execute_reply": "2024-04-01T01:04:41.626292Z"
136
+ },
137
+ "papermill": {
138
+ "duration": 0.387773,
139
+ "end_time": "2024-04-01T01:04:41.629778",
140
+ "exception": false,
141
+ "start_time": "2024-04-01T01:04:41.242005",
142
+ "status": "completed"
143
+ },
144
+ "tags": []
145
+ },
146
+ "outputs": [],
147
+ "source": [
148
+ "transform = transforms.Compose([\n",
149
+ " transforms.ToTensor()\n",
150
+ "])\n",
151
+ "\n",
152
+ "train_dataset = ImageDataset(dir=\"/kaggle/input/manga-panels-colored/data/train\", \n",
153
+ " transform=transform)\n",
154
+ "validation_dataset = ImageDataset(dir=\"/kaggle/input/manga-panels-colored/data/validation\", \n",
155
+ " transform=transform)\n",
156
+ "test_dataset = ImageDataset(dir=\"/kaggle/input/manga-panels-colored/data/test\", \n",
157
+ " transform=transform)"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 6,
163
+ "id": "8e5ea6dd",
164
+ "metadata": {
165
+ "execution": {
166
+ "iopub.execute_input": "2024-04-01T01:04:41.640890Z",
167
+ "iopub.status.busy": "2024-04-01T01:04:41.640098Z",
168
+ "iopub.status.idle": "2024-04-01T01:04:41.645385Z",
169
+ "shell.execute_reply": "2024-04-01T01:04:41.644485Z"
170
+ },
171
+ "papermill": {
172
+ "duration": 0.012881,
173
+ "end_time": "2024-04-01T01:04:41.647392",
174
+ "exception": false,
175
+ "start_time": "2024-04-01T01:04:41.634511",
176
+ "status": "completed"
177
+ },
178
+ "tags": []
179
+ },
180
+ "outputs": [],
181
+ "source": [
182
+ "train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)\n",
183
+ "validation_loader = DataLoader(validation_dataset, batch_size=1, shuffle=True)\n",
184
+ "test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "markdown",
189
+ "id": "fd0bbc4c",
190
+ "metadata": {
191
+ "papermill": {
192
+ "duration": 0.004236,
193
+ "end_time": "2024-04-01T01:04:41.656333",
194
+ "exception": false,
195
+ "start_time": "2024-04-01T01:04:41.652097",
196
+ "status": "completed"
197
+ },
198
+ "tags": []
199
+ },
200
+ "source": [
201
+ "## Training the model"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": 8,
207
+ "id": "6bb853cd",
208
+ "metadata": {
209
+ "execution": {
210
+ "iopub.execute_input": "2024-04-01T01:04:41.683769Z",
211
+ "iopub.status.busy": "2024-04-01T01:04:41.683460Z",
212
+ "iopub.status.idle": "2024-04-01T01:04:41.721713Z",
213
+ "shell.execute_reply": "2024-04-01T01:04:41.720922Z"
214
+ },
215
+ "papermill": {
216
+ "duration": 0.04614,
217
+ "end_time": "2024-04-01T01:04:41.724036",
218
+ "exception": false,
219
+ "start_time": "2024-04-01T01:04:41.677896",
220
+ "status": "completed"
221
+ },
222
+ "tags": []
223
+ },
224
+ "outputs": [],
225
+ "source": [
226
+ "criterion = nn.MSELoss()\n",
227
+ "optimizer = optim.Adam(model.parameters(), lr=0.0001)"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": null,
233
+ "id": "7b70952d",
234
+ "metadata": {
235
+ "execution": {
236
+ "iopub.execute_input": "2024-04-01T01:04:41.735281Z",
237
+ "iopub.status.busy": "2024-04-01T01:04:41.734495Z",
238
+ "iopub.status.idle": "2024-04-01T01:04:41.955794Z",
239
+ "shell.execute_reply": "2024-04-01T01:04:41.954865Z"
240
+ },
241
+ "papermill": {
242
+ "duration": 0.229072,
243
+ "end_time": "2024-04-01T01:04:41.957864",
244
+ "exception": false,
245
+ "start_time": "2024-04-01T01:04:41.728792",
246
+ "status": "completed"
247
+ },
248
+ "tags": []
249
+ },
250
+ "outputs": [],
251
+ "source": [
252
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
253
+ "model.to(device)"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "id": "4d252b1e",
260
+ "metadata": {
261
+ "execution": {
262
+ "iopub.execute_input": "2024-04-01T01:04:41.969269Z",
263
+ "iopub.status.busy": "2024-04-01T01:04:41.968485Z",
264
+ "iopub.status.idle": "2024-04-01T10:06:12.760575Z",
265
+ "shell.execute_reply": "2024-04-01T10:06:12.759664Z"
266
+ },
267
+ "papermill": {
268
+ "duration": 32490.811819,
269
+ "end_time": "2024-04-01T10:06:12.774567",
270
+ "exception": false,
271
+ "start_time": "2024-04-01T01:04:41.962748",
272
+ "status": "completed"
273
+ },
274
+ "tags": []
275
+ },
276
+ "outputs": [],
277
+ "source": [
278
+ "num_epochs = 100\n",
279
+ "num_training_steps = num_epochs * len(train_loader)\n",
280
+ "progress_bar = tqdm(range(num_training_steps))\n",
281
+ "\n",
282
+ "train_losses = []\n",
283
+ "valid_losses = []\n",
284
+ "\n",
285
+ "best_valid_loss = float(\"inf\")\n",
286
+ "epochs_no_improve = 0\n",
287
+ "patience = 10\n",
288
+ "best_model = None\n",
289
+ "\n",
290
+ "for epoch in range(num_epochs):\n",
291
+ " model.train()\n",
292
+ " train_loss = 0.0\n",
293
+ " for images, targets in train_loader:\n",
294
+ " images = images.to(device)\n",
295
+ " targets = targets.to(device)\n",
296
+ " outputs = model(images)\n",
297
+ " try:\n",
298
+ " loss = criterion(outputs, targets)\n",
299
+ " except RuntimeError:\n",
300
+ " adjusted_output = adjust_output_shape(outputs, targets)\n",
301
+ " loss = criterion(adjusted_output, targets)\n",
302
+ " loss.backward()\n",
303
+ "\n",
304
+ " optimizer.step()\n",
305
+ " optimizer.zero_grad()\n",
306
+ " progress_bar.update(1)\n",
307
+ "\n",
308
+ " train_loss += loss.item()\n",
309
+ " \n",
310
+ " train_losses.append(train_loss / len(train_loader))\n",
311
+ "\n",
312
+ " model.eval()\n",
313
+ " valid_loss = 0.0\n",
314
+ " with torch.no_grad():\n",
315
+ " for images, targets in validation_loader:\n",
316
+ " images = images.to(device)\n",
317
+ " targets = targets.to(device)\n",
318
+ " outputs = model(images)\n",
319
+ " try:\n",
320
+ " loss = criterion(outputs, targets)\n",
321
+ " except RuntimeError:\n",
322
+ " adjusted_output = adjust_output_shape(outputs, targets)\n",
323
+ " loss = criterion(adjusted_output, targets)\n",
324
+ " valid_loss += loss.item()\n",
325
+ " \n",
326
+ " valid_loss /= len(validation_loader)\n",
327
+ " valid_losses.append(valid_loss)\n",
328
+ "\n",
329
+ " print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_losses[-1]:.4f}, Valid Loss: {valid_loss:.4f}') \n",
330
+ " torch.save(model.state_dict(), \"last_checkpoint.pth\")\n",
331
+ "\n",
332
+ " if valid_loss < best_valid_loss:\n",
333
+ " best_valid_loss = valid_loss\n",
334
+ " epochs_no_improve = 0\n",
335
+ " best_model = model.state_dict()\n",
336
+ " torch.save(best_model, \"best_model_checkpoint.pth\")\n",
337
+ " else:\n",
338
+ " epochs_no_improve += 1\n",
339
+ " if epochs_no_improve == patience:\n",
340
+ " print(f\"Early stopping after {epoch+1} epochs with no improvement.\")\n",
341
+ " break\n",
342
+ "\n",
343
+ "model.load_state_dict(best_model)"
344
+ ]
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "execution_count": 11,
349
+ "id": "e3d447f1",
350
+ "metadata": {
351
+ "execution": {
352
+ "iopub.execute_input": "2024-04-01T10:06:12.800929Z",
353
+ "iopub.status.busy": "2024-04-01T10:06:12.800630Z",
354
+ "iopub.status.idle": "2024-04-01T10:06:13.054441Z",
355
+ "shell.execute_reply": "2024-04-01T10:06:13.053544Z"
356
+ },
357
+ "papermill": {
358
+ "duration": 0.269281,
359
+ "end_time": "2024-04-01T10:06:13.056431",
360
+ "exception": false,
361
+ "start_time": "2024-04-01T10:06:12.787150",
362
+ "status": "completed"
363
+ },
364
+ "tags": []
365
+ },
366
+ "outputs": [
367
+ {
368
+ "data": {
369
+ "image/png": "",
370
+ "text/plain": [
371
+ "<Figure size 1000x500 with 1 Axes>"
372
+ ]
373
+ },
374
+ "metadata": {},
375
+ "output_type": "display_data"
376
+ }
377
+ ],
378
+ "source": [
379
+ "epochs = np.arange(1, len(train_losses) + 1)\n",
380
+ "plt.figure(figsize=(10, 5))\n",
381
+ "\n",
382
+ "plt.plot(epochs, train_losses, label='Train Loss')\n",
383
+ "plt.plot(epochs, valid_losses, label='Valid Loss')\n",
384
+ "plt.xlabel('Epoch')\n",
385
+ "plt.ylabel('Loss')\n",
386
+ "plt.title('Training and Validation Loss')\n",
387
+ "plt.legend()\n",
388
+ "\n",
389
+ "plt.show()"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": 12,
395
+ "id": "c3b04bf7",
396
+ "metadata": {
397
+ "execution": {
398
+ "iopub.execute_input": "2024-04-01T10:06:13.084256Z",
399
+ "iopub.status.busy": "2024-04-01T10:06:13.083932Z",
400
+ "iopub.status.idle": "2024-04-01T10:06:38.988309Z",
401
+ "shell.execute_reply": "2024-04-01T10:06:38.987316Z"
402
+ },
403
+ "papermill": {
404
+ "duration": 25.920465,
405
+ "end_time": "2024-04-01T10:06:38.990405",
406
+ "exception": false,
407
+ "start_time": "2024-04-01T10:06:13.069940",
408
+ "status": "completed"
409
+ },
410
+ "tags": []
411
+ },
412
+ "outputs": [
413
+ {
414
+ "name": "stdout",
415
+ "output_type": "stream",
416
+ "text": [
417
+ "The loss on the test set is 0.008598181701702099\n"
418
+ ]
419
+ }
420
+ ],
421
+ "source": [
422
+ "model.eval()\n",
423
+ "test_loss = 0.0\n",
424
+ "with torch.no_grad():\n",
425
+ " for images, targets in test_loader:\n",
426
+ " images = images.to(device)\n",
427
+ " targets = targets.to(device)\n",
428
+ " outputs = model(images)\n",
429
+ " try:\n",
430
+ " loss = criterion(outputs, targets)\n",
431
+ " except RuntimeError:\n",
432
+ " adjusted_output = adjust_output_shape(outputs, targets)\n",
433
+ " loss = criterion(adjusted_output, targets)\n",
434
+ " test_loss += loss.item()\n",
435
+ " test_loss /= len(test_loader)\n",
436
+ "print(f\"The loss on the test set is {test_loss}\")"
437
+ ]
438
+ }
439
+ ],
440
+ "metadata": {
441
+ "kaggle": {
442
+ "accelerator": "gpu",
443
+ "dataSources": [
444
+ {
445
+ "datasetId": 4705836,
446
+ "sourceId": 7993213,
447
+ "sourceType": "datasetVersion"
448
+ }
449
+ ],
450
+ "dockerImageVersionId": 30674,
451
+ "isGpuEnabled": true,
452
+ "isInternetEnabled": true,
453
+ "language": "python",
454
+ "sourceType": "notebook"
455
+ },
456
+ "kernelspec": {
457
+ "display_name": "Python 3",
458
+ "language": "python",
459
+ "name": "python3"
460
+ },
461
+ "language_info": {
462
+ "codemirror_mode": {
463
+ "name": "ipython",
464
+ "version": 3
465
+ },
466
+ "file_extension": ".py",
467
+ "mimetype": "text/x-python",
468
+ "name": "python",
469
+ "nbconvert_exporter": "python",
470
+ "pygments_lexer": "ipython3",
471
+ "version": "3.12.2"
472
+ },
473
+ "papermill": {
474
+ "default_parameters": {},
475
+ "duration": 32529.008488,
476
+ "end_time": "2024-04-01T10:06:40.496191",
477
+ "environment_variables": {},
478
+ "exception": null,
479
+ "input_path": "__notebook__.ipynb",
480
+ "output_path": "__notebook__.ipynb",
481
+ "parameters": {},
482
+ "start_time": "2024-04-01T01:04:31.487703",
483
+ "version": "2.5.0"
484
+ },
485
+ "widgets": {
486
+ "application/vnd.jupyter.widget-state+json": {
487
+ "state": {
488
+ "0d705fdb53a0460cb06e86e2212618f5": {
489
+ "model_module": "@jupyter-widgets/controls",
490
+ "model_module_version": "1.5.0",
491
+ "model_name": "DescriptionStyleModel",
492
+ "state": {
493
+ "_model_module": "@jupyter-widgets/controls",
494
+ "_model_module_version": "1.5.0",
495
+ "_model_name": "DescriptionStyleModel",
496
+ "_view_count": null,
497
+ "_view_module": "@jupyter-widgets/base",
498
+ "_view_module_version": "1.2.0",
499
+ "_view_name": "StyleView",
500
+ "description_width": ""
501
+ }
502
+ },
503
+ "1a8db5a1fed14afca913a6edb7794b17": {
504
+ "model_module": "@jupyter-widgets/controls",
505
+ "model_module_version": "1.5.0",
506
+ "model_name": "ProgressStyleModel",
507
+ "state": {
508
+ "_model_module": "@jupyter-widgets/controls",
509
+ "_model_module_version": "1.5.0",
510
+ "_model_name": "ProgressStyleModel",
511
+ "_view_count": null,
512
+ "_view_module": "@jupyter-widgets/base",
513
+ "_view_module_version": "1.2.0",
514
+ "_view_name": "StyleView",
515
+ "bar_color": null,
516
+ "description_width": ""
517
+ }
518
+ },
519
+ "1e977edae4d54b5fbd5b7018ffb9858f": {
520
+ "model_module": "@jupyter-widgets/base",
521
+ "model_module_version": "1.2.0",
522
+ "model_name": "LayoutModel",
523
+ "state": {
524
+ "_model_module": "@jupyter-widgets/base",
525
+ "_model_module_version": "1.2.0",
526
+ "_model_name": "LayoutModel",
527
+ "_view_count": null,
528
+ "_view_module": "@jupyter-widgets/base",
529
+ "_view_module_version": "1.2.0",
530
+ "_view_name": "LayoutView",
531
+ "align_content": null,
532
+ "align_items": null,
533
+ "align_self": null,
534
+ "border": null,
535
+ "bottom": null,
536
+ "display": null,
537
+ "flex": null,
538
+ "flex_flow": null,
539
+ "grid_area": null,
540
+ "grid_auto_columns": null,
541
+ "grid_auto_flow": null,
542
+ "grid_auto_rows": null,
543
+ "grid_column": null,
544
+ "grid_gap": null,
545
+ "grid_row": null,
546
+ "grid_template_areas": null,
547
+ "grid_template_columns": null,
548
+ "grid_template_rows": null,
549
+ "height": null,
550
+ "justify_content": null,
551
+ "justify_items": null,
552
+ "left": null,
553
+ "margin": null,
554
+ "max_height": null,
555
+ "max_width": null,
556
+ "min_height": null,
557
+ "min_width": null,
558
+ "object_fit": null,
559
+ "object_position": null,
560
+ "order": null,
561
+ "overflow": null,
562
+ "overflow_x": null,
563
+ "overflow_y": null,
564
+ "padding": null,
565
+ "right": null,
566
+ "top": null,
567
+ "visibility": null,
568
+ "width": null
569
+ }
570
+ },
571
+ "38d95f1d4d65453895e3b9f5ea41723c": {
572
+ "model_module": "@jupyter-widgets/base",
573
+ "model_module_version": "1.2.0",
574
+ "model_name": "LayoutModel",
575
+ "state": {
576
+ "_model_module": "@jupyter-widgets/base",
577
+ "_model_module_version": "1.2.0",
578
+ "_model_name": "LayoutModel",
579
+ "_view_count": null,
580
+ "_view_module": "@jupyter-widgets/base",
581
+ "_view_module_version": "1.2.0",
582
+ "_view_name": "LayoutView",
583
+ "align_content": null,
584
+ "align_items": null,
585
+ "align_self": null,
586
+ "border": null,
587
+ "bottom": null,
588
+ "display": null,
589
+ "flex": null,
590
+ "flex_flow": null,
591
+ "grid_area": null,
592
+ "grid_auto_columns": null,
593
+ "grid_auto_flow": null,
594
+ "grid_auto_rows": null,
595
+ "grid_column": null,
596
+ "grid_gap": null,
597
+ "grid_row": null,
598
+ "grid_template_areas": null,
599
+ "grid_template_columns": null,
600
+ "grid_template_rows": null,
601
+ "height": null,
602
+ "justify_content": null,
603
+ "justify_items": null,
604
+ "left": null,
605
+ "margin": null,
606
+ "max_height": null,
607
+ "max_width": null,
608
+ "min_height": null,
609
+ "min_width": null,
610
+ "object_fit": null,
611
+ "object_position": null,
612
+ "order": null,
613
+ "overflow": null,
614
+ "overflow_x": null,
615
+ "overflow_y": null,
616
+ "padding": null,
617
+ "right": null,
618
+ "top": null,
619
+ "visibility": null,
620
+ "width": null
621
+ }
622
+ },
623
+ "50e7ac5f3f4b4fe58e2e110fea403bbc": {
624
+ "model_module": "@jupyter-widgets/controls",
625
+ "model_module_version": "1.5.0",
626
+ "model_name": "HBoxModel",
627
+ "state": {
628
+ "_dom_classes": [],
629
+ "_model_module": "@jupyter-widgets/controls",
630
+ "_model_module_version": "1.5.0",
631
+ "_model_name": "HBoxModel",
632
+ "_view_count": null,
633
+ "_view_module": "@jupyter-widgets/controls",
634
+ "_view_module_version": "1.5.0",
635
+ "_view_name": "HBoxView",
636
+ "box_style": "",
637
+ "children": [
638
+ "IPY_MODEL_66e2c724e5e54c878bb614d51a452b26",
639
+ "IPY_MODEL_6901f2bab73b4155b0f53c30467d46c3",
640
+ "IPY_MODEL_771460c15f794c71af4f6eeb0ea7b1ad"
641
+ ],
642
+ "layout": "IPY_MODEL_1e977edae4d54b5fbd5b7018ffb9858f"
643
+ }
644
+ },
645
+ "66e2c724e5e54c878bb614d51a452b26": {
646
+ "model_module": "@jupyter-widgets/controls",
647
+ "model_module_version": "1.5.0",
648
+ "model_name": "HTMLModel",
649
+ "state": {
650
+ "_dom_classes": [],
651
+ "_model_module": "@jupyter-widgets/controls",
652
+ "_model_module_version": "1.5.0",
653
+ "_model_name": "HTMLModel",
654
+ "_view_count": null,
655
+ "_view_module": "@jupyter-widgets/controls",
656
+ "_view_module_version": "1.5.0",
657
+ "_view_name": "HTMLView",
658
+ "description": "",
659
+ "description_tooltip": null,
660
+ "layout": "IPY_MODEL_38d95f1d4d65453895e3b9f5ea41723c",
661
+ "placeholder": "​",
662
+ "style": "IPY_MODEL_0d705fdb53a0460cb06e86e2212618f5",
663
+ "value": "100%"
664
+ }
665
+ },
666
+ "6901f2bab73b4155b0f53c30467d46c3": {
667
+ "model_module": "@jupyter-widgets/controls",
668
+ "model_module_version": "1.5.0",
669
+ "model_name": "FloatProgressModel",
670
+ "state": {
671
+ "_dom_classes": [],
672
+ "_model_module": "@jupyter-widgets/controls",
673
+ "_model_module_version": "1.5.0",
674
+ "_model_name": "FloatProgressModel",
675
+ "_view_count": null,
676
+ "_view_module": "@jupyter-widgets/controls",
677
+ "_view_module_version": "1.5.0",
678
+ "_view_name": "ProgressView",
679
+ "bar_style": "",
680
+ "description": "",
681
+ "description_tooltip": null,
682
+ "layout": "IPY_MODEL_dd05d612550a4db28ebf2c7cfc2312fe",
683
+ "max": 75500,
684
+ "min": 0,
685
+ "orientation": "horizontal",
686
+ "style": "IPY_MODEL_1a8db5a1fed14afca913a6edb7794b17",
687
+ "value": 75500
688
+ }
689
+ },
690
+ "771460c15f794c71af4f6eeb0ea7b1ad": {
691
+ "model_module": "@jupyter-widgets/controls",
692
+ "model_module_version": "1.5.0",
693
+ "model_name": "HTMLModel",
694
+ "state": {
695
+ "_dom_classes": [],
696
+ "_model_module": "@jupyter-widgets/controls",
697
+ "_model_module_version": "1.5.0",
698
+ "_model_name": "HTMLModel",
699
+ "_view_count": null,
700
+ "_view_module": "@jupyter-widgets/controls",
701
+ "_view_module_version": "1.5.0",
702
+ "_view_name": "HTMLView",
703
+ "description": "",
704
+ "description_tooltip": null,
705
+ "layout": "IPY_MODEL_d08ee627a3cc414489e61161d8a917ae",
706
+ "placeholder": "​",
707
+ "style": "IPY_MODEL_b2818288e9b9459fb75a1ea2e6f35117",
708
+ "value": " 75500/75500 [9:00:45&lt;00:00, 2.32it/s]"
709
+ }
710
+ },
711
+ "b2818288e9b9459fb75a1ea2e6f35117": {
712
+ "model_module": "@jupyter-widgets/controls",
713
+ "model_module_version": "1.5.0",
714
+ "model_name": "DescriptionStyleModel",
715
+ "state": {
716
+ "_model_module": "@jupyter-widgets/controls",
717
+ "_model_module_version": "1.5.0",
718
+ "_model_name": "DescriptionStyleModel",
719
+ "_view_count": null,
720
+ "_view_module": "@jupyter-widgets/base",
721
+ "_view_module_version": "1.2.0",
722
+ "_view_name": "StyleView",
723
+ "description_width": ""
724
+ }
725
+ },
726
+ "d08ee627a3cc414489e61161d8a917ae": {
727
+ "model_module": "@jupyter-widgets/base",
728
+ "model_module_version": "1.2.0",
729
+ "model_name": "LayoutModel",
730
+ "state": {
731
+ "_model_module": "@jupyter-widgets/base",
732
+ "_model_module_version": "1.2.0",
733
+ "_model_name": "LayoutModel",
734
+ "_view_count": null,
735
+ "_view_module": "@jupyter-widgets/base",
736
+ "_view_module_version": "1.2.0",
737
+ "_view_name": "LayoutView",
738
+ "align_content": null,
739
+ "align_items": null,
740
+ "align_self": null,
741
+ "border": null,
742
+ "bottom": null,
743
+ "display": null,
744
+ "flex": null,
745
+ "flex_flow": null,
746
+ "grid_area": null,
747
+ "grid_auto_columns": null,
748
+ "grid_auto_flow": null,
749
+ "grid_auto_rows": null,
750
+ "grid_column": null,
751
+ "grid_gap": null,
752
+ "grid_row": null,
753
+ "grid_template_areas": null,
754
+ "grid_template_columns": null,
755
+ "grid_template_rows": null,
756
+ "height": null,
757
+ "justify_content": null,
758
+ "justify_items": null,
759
+ "left": null,
760
+ "margin": null,
761
+ "max_height": null,
762
+ "max_width": null,
763
+ "min_height": null,
764
+ "min_width": null,
765
+ "object_fit": null,
766
+ "object_position": null,
767
+ "order": null,
768
+ "overflow": null,
769
+ "overflow_x": null,
770
+ "overflow_y": null,
771
+ "padding": null,
772
+ "right": null,
773
+ "top": null,
774
+ "visibility": null,
775
+ "width": null
776
+ }
777
+ },
778
+ "dd05d612550a4db28ebf2c7cfc2312fe": {
779
+ "model_module": "@jupyter-widgets/base",
780
+ "model_module_version": "1.2.0",
781
+ "model_name": "LayoutModel",
782
+ "state": {
783
+ "_model_module": "@jupyter-widgets/base",
784
+ "_model_module_version": "1.2.0",
785
+ "_model_name": "LayoutModel",
786
+ "_view_count": null,
787
+ "_view_module": "@jupyter-widgets/base",
788
+ "_view_module_version": "1.2.0",
789
+ "_view_name": "LayoutView",
790
+ "align_content": null,
791
+ "align_items": null,
792
+ "align_self": null,
793
+ "border": null,
794
+ "bottom": null,
795
+ "display": null,
796
+ "flex": null,
797
+ "flex_flow": null,
798
+ "grid_area": null,
799
+ "grid_auto_columns": null,
800
+ "grid_auto_flow": null,
801
+ "grid_auto_rows": null,
802
+ "grid_column": null,
803
+ "grid_gap": null,
804
+ "grid_row": null,
805
+ "grid_template_areas": null,
806
+ "grid_template_columns": null,
807
+ "grid_template_rows": null,
808
+ "height": null,
809
+ "justify_content": null,
810
+ "justify_items": null,
811
+ "left": null,
812
+ "margin": null,
813
+ "max_height": null,
814
+ "max_width": null,
815
+ "min_height": null,
816
+ "min_width": null,
817
+ "object_fit": null,
818
+ "object_position": null,
819
+ "order": null,
820
+ "overflow": null,
821
+ "overflow_x": null,
822
+ "overflow_y": null,
823
+ "padding": null,
824
+ "right": null,
825
+ "top": null,
826
+ "visibility": null,
827
+ "width": null
828
+ }
829
+ }
830
+ },
831
+ "version_major": 2,
832
+ "version_minor": 0
833
+ }
834
+ }
835
+ },
836
+ "nbformat": 4,
837
+ "nbformat_minor": 5
838
+ }
src/templates/index.html ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <title>MangaColorizer</title>
5
+ </head>
6
+ <body>
7
+ <h1 style="text-align: center;font-size: xx-large;">MangaColorizer</h1>
8
+ <p style="text-align: center;font-size: large;">Upload a black and white drawing and get its colorized version</p>
9
+ </body>
10
+ </html>
src/test.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/utils.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import matplotlib.pyplot as plt
5
+ from PIL import Image
6
+ from torch.utils.data import Dataset
7
+ from torchvision import transforms
8
+
9
+
10
+ class ImageDataset(Dataset):
11
+ def __init__(self, dir, transform=None) -> None:
12
+ self.dir = dir
13
+ self.transform = transform
14
+ self.file_list = sorted(os.listdir(self.dir))
15
+
16
+ def __len__(self):
17
+ return len(self.file_list)
18
+
19
+ def __getitem__(self, idx):
20
+ image_name = self.file_list[idx]
21
+ image_path = os.path.join(self.dir, image_name)
22
+
23
+ grayscale_image = Image.open(image_path).convert('L')
24
+ colorized_image = Image.open(image_path).convert('RGB')
25
+
26
+ if self.transform:
27
+ grayscale_image = self.transform(grayscale_image)
28
+ colorized_image = self.transform(colorized_image)
29
+
30
+ return grayscale_image, colorized_image
31
+
32
+
33
def show_image(image_tensor):
    """Draw a channel-first image with matplotlib.

    A single-channel image is drawn with the gray colormap; anything else is
    transposed from (channel, height, width) to (height, width, channel)
    first. Display is best-effort: any failure is printed, not raised.

    Args:
        image_tensor: Channel-first image, as a torch tensor or an
            array-like that supports ``len``, indexing and ``transpose``.
    """
    try:
        image = image_tensor
        if isinstance(image, torch.Tensor):
            # Detach and move to host memory so tensors that require grad or
            # live on an accelerator can be converted to numpy.
            image = image.detach().cpu().numpy()

        if len(image) == 1:
            plt.imshow(image[0], cmap="gray")
        else:
            plt.imshow(image.transpose(1, 2, 0))
    except Exception as e:
        print(f"Exception when showing image: {e}")
41
+
42
+
43
def adjust_output_shape(output_tensor, target_tensor):
    """Resize *output_tensor* to the spatial size of *target_tensor*.

    Makes it possible to compute a loss such as MSE when the output tensor
    has a different shape from the target tensor. The size is taken from
    dimension 2 onwards of the target and applied with bilinear
    interpolation.
    """
    spatial_size = target_tensor.shape[2:]
    return torch.nn.functional.interpolate(
        output_tensor,
        size=spatial_size,
        mode="bilinear",
        align_corners=False,
    )
47
+
48
+
49
def pil_to_torch(pil_image):
    """Convert *pil_image* with ``ToTensor`` and prepend a dimension of size 1."""
    tensor = transforms.ToTensor()(pil_image)
    return tensor.unsqueeze(0)
52
+
53
def torch_to_pil(torch_image):
    """Convert an image tensor with a leading size-1 dimension to a PIL image.

    Args:
        torch_image: Image whose dimension 0 (squeezed away here) has size 1.
            Floating-point tensors are clamped to [0, 1] before conversion.

    Returns:
        The image produced by ``transforms.ToPILImage``.
    """
    image = torch_image.squeeze(0)
    if isinstance(image, torch.Tensor) and image.is_floating_point():
        # ToPILImage scales floats by 255 and casts to 8-bit; values outside
        # [0, 1] (e.g. negatives from a Tanh output) would wrap around and
        # show up as speckles, so saturate them instead.
        image = image.detach().clamp(0.0, 1.0)
    return transforms.ToPILImage()(image)