Spaces:
Build error
Build error
NikeZoldyck
commited on
Commit
•
4158574
1
Parent(s):
4ac8bc1
adding the gradio app code
Browse files- app.py +152 -0
- models/__init__.py +0 -0
- models/components/__init__.py +0 -0
- models/components/photo_wct.pth +3 -0
- models/models.py +297 -0
- requirements.txt +18 -0
- utils/__init__.py +12 -0
- utils/photo_smooth.py +101 -0
- utils/photo_wct.py +171 -0
- utils/shared_utils.py +136 -0
- utils/smooth_filter.py +405 -0
app.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import numpy as np
|
3 |
+
import gradio as gr
|
4 |
+
import utils.shared_utils as st
|
5 |
+
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import autocast
|
9 |
+
import torchvision.transforms as T
|
10 |
+
from contextlib import nullcontext
|
11 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
12 |
+
context = autocast if device == "cuda" else nullcontext
|
13 |
+
# Apply the transformations needed
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
def select_input(input_img, webcm_img):
    """Choose which foreground image to use.

    Returns ``input_img`` unless it is ``None``; in that case ``webcm_img`` is
    returned instead (which may itself be ``None``).
    """
    return webcm_img if input_img is None else input_img
|
23 |
+
|
24 |
+
|
25 |
+
def infer(prompt, samples):
    """Generate ``samples`` background images for ``prompt``.

    ``st.stableDiffusionAPICall(prompt)`` is called once per requested image
    inside the module-level ``context(device)`` block, and the results are
    returned as a list in call order.
    """
    # The "Img_NN" label list that used to be built here was never read, so it
    # has been dropped.
    with context(device):
        images = [st.stableDiffusionAPICall(prompt) for _ in range(samples)]
    return images
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
def change_bg_option(choice):
    """Return the component or update matching the selected background option.

    * ``"I have an Image"``     -> a new ``gr.Image`` with shape (800, 800)
    * ``"Generate one for me"`` -> an update showing an 8-line text field
    * anything else             -> an update that hides the component
    """
    if choice == "I have an Image":
        return gr.Image(shape=(800, 800))

    if choice == "Generate one for me":
        return gr.update(lines=8, visible=True, value="Please enter a text prompt")

    return gr.update(visible=False)
|
46 |
+
|
47 |
+
|
48 |
+
# TEXT
|
49 |
+
title = "FSDL- One-Shot, Green-Screen, Composition-Transfer"
|
50 |
+
DEFAULT_TEXT = "Photorealistic scenery of bookshelf in a room"
|
51 |
+
# HTML shown in the header box of the app. The closing tags on the final
# paragraph were previously written as opening tags ("<b>Enjoy!<b><p>"),
# leaving <b> and <p> unclosed.
description = """
<center><a href="https://docs.google.com/document/d/1fde8XKIMT1nNU72859ytd2c58LFBxepS3od9KFBrJbM/edit?usp=sharing">[PAPER]</a> <a href="https://github.com/snknitin/FSDL-Project/blob/main/src/utils/shared_utils.py">[CODE]</a></center>
<details>
<summary><b>Instructions</b></summary>
<p style="margin-top: -3px;">With this app, you can generate a suitable background image to overlay your portrait!<br />You have several ways to set how your final auto-edited image will look:<br /></p>
<ul style="margin-top: -20px;margin-bottom: -15px;">
<li style="margin-bottom: -10px;margin-left: 20px;">Use the "<i>Inputs</i>" tab to either upload an image from your device or allow the use of your webcam to capture</li>
<li style="margin-left: 20px;">Use the "<i>Background Image Inputs</i>" to upload your own background</li>
<li style="margin-left: 20px;">Use the "<i>Text prompt</i>" tab to generate a satisfactory background image.</li>
</ul>
<p>After customization, just hit "<i>Edit</i>" and wait a few seconds.<br />The final image will be available for download <br /> <b>Enjoy!</b></p>
</details>
"""
|
64 |
+
|
65 |
+
# Markdown shown above the three processing buttons.
running = """

### Instructions for running the 3 S's in sequence

* **Superimpose** - This button allows you to isolate the foreground from your image and overlay it on the background. Remove background using alpha matting
* **Style-Transfer** - This button transfers the style from your original image to re-map your new background realistically. Uses Nvidia FastPhotoStyle
* **Smoothing** - Given that image resolution and clarity can be an issue, this smoothing button makes your final image crisp after the stylization transfer. Fair warning - this last process can take 5-10 mins
"""
|
73 |
+
|
74 |
+
|
75 |
+
demo = gr.Blocks()
|
76 |
+
|
77 |
+
with demo:
|
78 |
+
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
|
79 |
+
with gr.Box():
|
80 |
+
gr.Markdown(description)
|
81 |
+
# First row - Inputs
|
82 |
+
with gr.Row(scale=1):
|
83 |
+
with gr.Column():
|
84 |
+
with gr.Tabs():
|
85 |
+
with gr.TabItem("Upload "):
|
86 |
+
input_img = gr.Image(shape=(800, 800), interactive=True, label="You")
|
87 |
+
with gr.TabItem("Webcam Capture"):
|
88 |
+
webcm_img = gr.Image(source="webcam", streaming=True, shape=(800, 800), interactive=True)
|
89 |
+
inp_select_btn = gr.Button("Select")
|
90 |
+
|
91 |
+
with gr.Column():
|
92 |
+
with gr.Tabs():
|
93 |
+
with gr.TabItem("Upload"):
|
94 |
+
bgm_img = gr.Image(shape=(800, 800), type="pil", interactive=True, label="The Background")
|
95 |
+
bgm_select_btn = gr.Button("Select")
|
96 |
+
|
97 |
+
with gr.TabItem("Generate via Text Prompt"):
|
98 |
+
with gr.Box():
|
99 |
+
with gr.Row().style(mobile_collapse=False, equal_height=True):
|
100 |
+
text = gr.Textbox(lines=7,
|
101 |
+
placeholder="Enter your prompt to generate a background image... something like - Photorealistic scenery of bookshelf in a room")
|
102 |
+
|
103 |
+
samples = gr.Slider(label="Number of Images", minimum=1, maximum=5, value=2, step=1)
|
104 |
+
btn = gr.Button("Generate images",variant="primary").style(
|
105 |
+
margin=False,
|
106 |
+
rounded=(False, True, True, False),
|
107 |
+
)
|
108 |
+
|
109 |
+
gallery = gr.Gallery(label="Generated images", show_label=True).style(grid=(1, 3), height="auto")
|
110 |
+
# image_options = gr.Radio(label="Pick", interactive=True, choices=None, type="value")
|
111 |
+
text.submit(infer, inputs=[text, samples], outputs=gallery)
|
112 |
+
btn.click(infer, inputs=[text, samples], outputs=gallery, show_progress=True, status_tracker=None)
|
113 |
+
|
114 |
+
|
115 |
+
# Second Row - Backgrounds
|
116 |
+
with gr.Row(scale=1):
|
117 |
+
with gr.Column():
|
118 |
+
final_input_img = gr.Image(shape=(800, 800), type="pil", label="Foreground")
|
119 |
+
|
120 |
+
with gr.Column():
|
121 |
+
final_back_img = gr.Image(shape=(800, 800), type="pil", label="Background", interactive=True)
|
122 |
+
|
123 |
+
bgm_select_btn.click(fn=lambda x: x, inputs=bgm_img, outputs=final_back_img)
|
124 |
+
|
125 |
+
inp_select_btn.click(select_input, [input_img, webcm_img], final_input_img)
|
126 |
+
|
127 |
+
with gr.Row(scale=1):
|
128 |
+
with gr.Box():
|
129 |
+
gr.Markdown(running)
|
130 |
+
|
131 |
+
with gr.Row(scale=1):
|
132 |
+
|
133 |
+
with gr.Column(scale=1):
|
134 |
+
supimp_btn = gr.Button("SuperImpose")
|
135 |
+
overlay_img = gr.Image(shape=(800, 800), label="Overlay", type="pil")
|
136 |
+
|
137 |
+
|
138 |
+
with gr.Column(scale=1):
|
139 |
+
style_btn = gr.Button("Composition-Transfer",variant="primary")
|
140 |
+
style_img = gr.Image(shape=(800, 800),label="Style-Transfer Image",type="pil")
|
141 |
+
|
142 |
+
with gr.Column(scale=1):
|
143 |
+
submit_btn = gr.Button("Smoothen",variant="primary")
|
144 |
+
output_img = gr.Image(shape=(800, 800),label="FinalSmoothened Image",type="pil")
|
145 |
+
|
146 |
+
supimp_btn.click(fn=st.superimpose, inputs=[final_input_img, final_back_img], outputs=[overlay_img])
|
147 |
+
style_btn.click(fn=st.style_transfer, inputs=[overlay_img,final_input_img], outputs=[style_img])
|
148 |
+
submit_btn.click(fn=st.smoother, inputs=[style_img,overlay_img], outputs=[output_img])
|
149 |
+
|
150 |
+
demo.queue()
|
151 |
+
demo.launch()
|
152 |
+
|
models/__init__.py
ADDED
File without changes
|
models/components/__init__.py
ADDED
File without changes
|
models/components/photo_wct.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bedc114a83833de79e92b7166b37bc522db71a30bbfa13d0c4f36387789c8af5
|
3 |
+
size 33410469
|
models/models.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
|
3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
4 |
+
"""
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
|
8 |
+
class VGGEncoder(nn.Module):
    """VGG-style encoder truncated at ``level`` (1 to 4).

    Only the layers needed for the requested level are created. Every layer
    keeps its original attribute name (``conv0``, ``pad1_1``, ``conv1_1``,
    ``relu1_1``, ``maxpool1`` ...) and creation order, so ``state_dict`` keys
    are unchanged.
    """

    def __init__(self, level):
        super(VGGEncoder, self).__init__()
        self.level = level

        # 1x1 convolution on the 3-channel input, then the first 3x3 stage.
        self.conv0 = nn.Conv2d(3, 3, 1, 1, 0)
        self._add_stage("1_1", 3, 64)
        if level < 2:
            return

        self._add_stage("1_2", 64, 64)
        self.maxpool1 = self._make_pool()  # halves the spatial size
        self._add_stage("2_1", 64, 128)
        if level < 3:
            return

        self._add_stage("2_2", 128, 128)
        self.maxpool2 = self._make_pool()
        self._add_stage("3_1", 128, 256)
        if level < 4:
            return

        for tag in ("3_2", "3_3", "3_4"):
            self._add_stage(tag, 256, 256)
        self.maxpool3 = self._make_pool()
        self._add_stage("4_1", 256, 512)

    @staticmethod
    def _make_pool():
        """Build a 2x2/stride-2 max-pool that also returns the argmax indices."""
        return nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def _add_stage(self, tag, in_ch, out_ch):
        """Register ``pad<tag>``, ``conv<tag>`` and ``relu<tag>`` on the module."""
        setattr(self, "pad" + tag, nn.ReflectionPad2d((1, 1, 1, 1)))
        setattr(self, "conv" + tag, nn.Conv2d(in_ch, out_ch, 3, 1, 0))
        setattr(self, "relu" + tag, nn.ReLU(inplace=True))

    def _stage(self, tag, x):
        """Run reflection pad -> 3x3 conv -> ReLU for the stage named ``tag``."""
        x = getattr(self, "pad" + tag)(x)
        x = getattr(self, "conv" + tag)(x)
        return getattr(self, "relu" + tag)(x)

    def forward(self, x):
        """Encode ``x`` and return what the decoder needs for unpooling.

        Returns the deepest feature map alone for level 1. For higher levels
        it returns a tuple of the feature map followed by, for each pooling
        step taken, the pooling indices and the pre-pool tensor size.
        """
        feat = self._stage("1_1", self.conv0(x))
        if self.level < 2:
            return feat

        pool1 = self._stage("1_2", feat)
        feat, pool1_idx = self.maxpool1(pool1)
        feat = self._stage("2_1", feat)
        if self.level < 3:
            return feat, pool1_idx, pool1.size()

        pool2 = self._stage("2_2", feat)
        feat, pool2_idx = self.maxpool2(pool2)
        feat = self._stage("3_1", feat)
        if self.level < 4:
            return feat, pool1_idx, pool1.size(), pool2_idx, pool2.size()

        for tag in ("3_2", "3_3"):
            feat = self._stage(tag, feat)
        pool3 = self._stage("3_4", feat)
        feat, pool3_idx = self.maxpool3(pool3)
        feat = self._stage("4_1", feat)

        return feat, pool1_idx, pool1.size(), pool2_idx, pool2.size(), pool3_idx, pool3.size()

    def forward_multiple(self, x):
        """Encode ``x`` and return the relu*_1 features of every level reached.

        Returns the feature map alone for level 1. Otherwise returns a tuple
        ordered deepest-first, e.g. ``(relu4_1, relu3_1, relu2_1, relu1_1)``
        for level 4.
        """
        feat = self._stage("1_1", self.conv0(x))
        if self.level < 2:
            return feat
        out1 = feat

        feat, _ = self.maxpool1(self._stage("1_2", feat))
        feat = self._stage("2_1", feat)
        if self.level < 3:
            return feat, out1
        out2 = feat

        feat, _ = self.maxpool2(self._stage("2_2", feat))
        feat = self._stage("3_1", feat)
        if self.level < 4:
            return feat, out2, out1
        out3 = feat

        for tag in ("3_2", "3_3", "3_4"):
            feat = self._stage(tag, feat)
        feat, _ = self.maxpool3(feat)
        feat = self._stage("4_1", feat)

        return feat, out3, out2, out1
|
188 |
+
|
189 |
+
|
190 |
+
class VGGDecoder(nn.Module):
    """Decoder mirroring ``VGGEncoder`` for a given ``level`` (0 to 4).

    Layers are created from the deepest stage down to ``conv1_1``. Each layer
    keeps its original attribute name and creation order, so ``state_dict``
    keys are unchanged. The final ``conv1_1`` maps 64 channels to 3 and has no
    ReLU after it.
    """

    _DEEP_TAGS = ("3_4", "3_3", "3_2")

    def __init__(self, level):
        super(VGGDecoder, self).__init__()
        self.level = level

        if level > 3:
            self._add_stage("4_1", 512, 256)
            self.unpool3 = nn.MaxUnpool2d(kernel_size=2, stride=2)
            for tag in self._DEEP_TAGS:
                self._add_stage(tag, 256, 256)

        if level > 2:
            self._add_stage("3_1", 256, 128)
            self.unpool2 = nn.MaxUnpool2d(kernel_size=2, stride=2)
            self._add_stage("2_2", 128, 128)

        if level > 1:
            self._add_stage("2_1", 128, 64)
            self.unpool1 = nn.MaxUnpool2d(kernel_size=2, stride=2)
            self._add_stage("1_2", 64, 64)

        if level > 0:
            self.pad1_1 = nn.ReflectionPad2d((1, 1, 1, 1))
            self.conv1_1 = nn.Conv2d(64, 3, 3, 1, 0)

    def _add_stage(self, tag, in_ch, out_ch):
        """Register ``pad<tag>``, ``conv<tag>`` and ``relu<tag>`` on the module."""
        setattr(self, "pad" + tag, nn.ReflectionPad2d((1, 1, 1, 1)))
        setattr(self, "conv" + tag, nn.Conv2d(in_ch, out_ch, 3, 1, 0))
        setattr(self, "relu" + tag, nn.ReLU(inplace=True))

    def _stage(self, tag, x):
        """Run reflection pad -> 3x3 conv -> ReLU for the stage named ``tag``."""
        x = getattr(self, "pad" + tag)(x)
        x = getattr(self, "conv" + tag)(x)
        return getattr(self, "relu" + tag)(x)

    def forward(self, x, pool1_idx=None, pool1_size=None, pool2_idx=None, pool2_size=None, pool3_idx=None,
                pool3_size=None):
        """Decode feature map ``x`` back towards a 3-channel output.

        ``poolN_idx`` / ``poolN_size`` are the max-pool indices and pre-pool
        sizes handed to the matching ``MaxUnpool2d``; only those for the
        levels actually present are used. With ``level == 0`` the input is
        returned untouched.
        """
        out = x

        if self.level > 3:
            out = self._stage("4_1", out)
            out = self.unpool3(out, pool3_idx, output_size=pool3_size)
            for tag in self._DEEP_TAGS:
                out = self._stage(tag, out)

        if self.level > 2:
            out = self._stage("3_1", out)
            out = self.unpool2(out, pool2_idx, output_size=pool2_size)
            out = self._stage("2_2", out)

        if self.level > 1:
            out = self._stage("2_1", out)
            out = self.unpool1(out, pool1_idx, output_size=pool1_size)
            out = self._stage("1_2", out)

        if self.level > 0:
            out = self.conv1_1(self.pad1_1(out))

        return out
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu116
|
2 |
+
torch
|
3 |
+
diffusers
|
4 |
+
transformers
|
5 |
+
scipy
|
6 |
+
ftfy
|
7 |
+
gradio
|
8 |
+
torchvision
|
9 |
+
scikit-image
|
10 |
+
rembg
|
11 |
+
replicate
|
12 |
+
requests
|
13 |
+
Pillow
|
14 |
+
numpy
|
15 |
+
scipy
|
16 |
+
pyrootutils
|
17 |
+
pynvrtc
|
18 |
+
cupy
|
utils/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.utils.pylogger import get_pylogger
|
2 |
+
from src.utils.rich_utils import enforce_tags, print_config_tree
|
3 |
+
from src.utils.utils import (
|
4 |
+
close_loggers,
|
5 |
+
extras,
|
6 |
+
get_metric_value,
|
7 |
+
instantiate_callbacks,
|
8 |
+
instantiate_loggers,
|
9 |
+
log_hyperparameters,
|
10 |
+
save_file,
|
11 |
+
task_wrapper,
|
12 |
+
)
|
utils/photo_smooth.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
|
3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
4 |
+
"""
|
5 |
+
from __future__ import division
|
6 |
+
import torch.nn as nn
|
7 |
+
import scipy.misc
|
8 |
+
import scipy._lib
|
9 |
+
import numpy as np
|
10 |
+
import scipy.sparse
|
11 |
+
import scipy.sparse.linalg as linalg
|
12 |
+
from numpy.lib.stride_tricks import as_strided
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
|
16 |
+
class Propagator(nn.Module):
    """Smoothing step that propagates a stylised image along content affinities.

    ``process`` builds an affinity matrix ``W`` from the content image,
    normalises it to ``S = D W D`` with ``D = diag(colsum(W) ** -0.5)``, and
    solves ``(I - beta * S) V = B`` per channel, where ``B`` is the initial
    (stylised) image. The result is scaled by ``(1 - beta)``.
    """

    def __init__(self, beta=0.9999):
        super(Propagator, self).__init__()
        self.beta = beta

    @staticmethod
    def _read_rgb(path):
        """Load the image at ``path`` as an RGB array, closing the file afterwards."""
        with Image.open(path) as handle:
            return np.asarray(handle.convert('RGB'))

    def process(self, initImg, contentImg):
        """Smooth ``initImg`` using the structure of ``contentImg``.

        Args:
            initImg: file path, or anything ``np.asarray`` turns into an
                ``(H, W, C)`` array with values in 0..255.
            contentImg: file path, or an object with ``.copy()`` that
                ``np.array`` turns into an image array (PIL image or ndarray).
                The affinity computation expects 3 channels.

        Returns:
            A ``PIL.Image`` of size ``(W - 4, H - 4)``: a 2-pixel border is
            cropped from the input and not restored.
        """
        # scipy.misc.imread and the scipy.asarray alias are gone from current
        # SciPy releases; PIL and NumPy provide the same behaviour.
        if isinstance(contentImg, str):
            content = self._read_rgb(contentImg)
        else:
            content = contentImg.copy()

        if isinstance(initImg, str):
            init = self._read_rgb(initImg)
        else:
            init = np.asarray(initImg)
        B = init.astype(np.float64) / 255

        h1, w1, k = B.shape
        h = h1 - 4
        w = w1 - 4
        # Drop a 2-pixel border, then put it back by edge replication below.
        B = B[2:2 + h, 2:2 + w, :]
        # PIL's resize takes (width, height). Passing (h, w) only worked for
        # square images and gave the content image the wrong shape otherwise.
        content = np.asarray(Image.fromarray(np.array(content)).resize((w, h), Image.BICUBIC))
        B = self.__replication_padding(B, 2)
        content = self.__replication_padding(content, 2)
        content = content.astype(np.float64) / 255
        B = np.reshape(B, (h1 * w1, k))

        W = self.__compute_laplacian(content)
        W = W.tocsc()
        dd = W.sum(0)
        dd = np.sqrt(np.power(dd, -1))
        dd = np.asarray(dd).squeeze()
        n = w1 * h1
        D = scipy.sparse.csc_matrix((dd, (np.arange(0, n), np.arange(0, n))))
        S = D.dot(W).dot(D)
        A = scipy.sparse.identity(n) - self.beta * S
        A = A.tocsc()
        # Factorise once and reuse the factorisation for every channel.
        solver = linalg.factorized(A)
        V = np.zeros((h1 * w1, k))
        for channel in range(k):
            V[:, channel] = solver(B[:, channel])
        V = V * (1 - self.beta)
        V = V.reshape(h1, w1, k)
        V = V[2:2 + h, 2:2 + w, :]

        img = Image.fromarray(np.uint8(np.clip(V * 255., 0, 255.)))
        return img

    # Returns sparse matting laplacian
    # The implementation of the function is heavily borrowed from
    # https://github.com/MarcoForte/closed-form-matting/blob/master/closed_form_matting.py
    # We thank Marco Forte for sharing his code.
    def __compute_laplacian(self, img, eps=10**(-7), win_rad=1):
        """Build the sparse ``(h*w, h*w)`` window-affinity matrix of ``img``.

        ``img`` is an ``(h, w, d)`` float array. Each ``(2*win_rad+1)``-square
        window contributes ``(1 + (I_i - mu)^T (cov + eps/|win| * I)^-1
        (I_j - mu)) / |win|`` for every pixel pair ``(i, j)`` inside it.
        Duplicate entries are summed when the COO matrix is converted.
        """
        win_size = (win_rad*2+1)**2
        h, w, d = img.shape
        c_h, c_w = h - 2*win_rad, w - 2*win_rad
        win_diam = win_rad*2+1
        indsM = np.arange(h*w).reshape((h, w))
        ravelImg = img.reshape(h*w, d)
        # Flat pixel indices of every window, one row of win_size per window.
        win_inds = self.__rolling_block(indsM, block=(win_diam, win_diam))
        win_inds = win_inds.reshape(c_h, c_w, win_size)
        winI = ravelImg[win_inds]
        win_mu = np.mean(winI, axis=2, keepdims=True)
        win_var = np.einsum('...ji,...jk ->...ik', winI, winI)/win_size - np.einsum('...ji,...jk ->...ik', win_mu, win_mu)
        # Regulariser sized to the channel count (was hard-coded to 3).
        inv = np.linalg.inv(win_var + (eps/win_size)*np.eye(d))
        X = np.einsum('...ij,...jk->...ik', winI - win_mu, inv)
        vals = (1/win_size)*(1 + np.einsum('...ij,...kj->...ik', X, winI - win_mu))
        nz_indsCol = np.tile(win_inds, win_size).ravel()
        nz_indsRow = np.repeat(win_inds, win_size).ravel()
        nz_indsVal = vals.ravel()
        L = scipy.sparse.coo_matrix((nz_indsVal, (nz_indsRow, nz_indsCol)), shape=(h*w, h*w))
        return L

    def __replication_padding(self, arr, pad):
        """Pad the two spatial axes of ``arr`` by ``pad`` using edge values.

        Returns a new float64 array of shape ``(h + 2*pad, w + 2*pad, c)``.
        """
        spatial_only = ((pad, pad), (pad, pad), (0, 0))
        return np.pad(arr, pad_width=spatial_only, mode='edge').astype(np.float64)

    def __rolling_block(self, A, block=(3, 3)):
        """Return a strided view of every ``block``-shaped window of 2-D ``A``.

        The result has shape ``(rows - bh + 1, cols - bw + 1, bh, bw)`` and
        shares memory with ``A``.
        """
        shape = (A.shape[0] - block[0] + 1, A.shape[1] - block[1] + 1) + block
        strides = (A.strides[0], A.strides[1]) + A.strides
        return as_strided(A, shape=shape, strides=strides)
|
utils/photo_wct.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
|
3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
4 |
+
"""
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from models.models import VGGEncoder, VGGDecoder
|
11 |
+
|
12 |
+
|
13 |
+
class PhotoWCT(nn.Module):
|
14 |
+
def __init__(self):
|
15 |
+
super(PhotoWCT, self).__init__()
|
16 |
+
self.e1 = VGGEncoder(1)
|
17 |
+
self.d1 = VGGDecoder(1)
|
18 |
+
self.e2 = VGGEncoder(2)
|
19 |
+
self.d2 = VGGDecoder(2)
|
20 |
+
self.e3 = VGGEncoder(3)
|
21 |
+
self.d3 = VGGDecoder(3)
|
22 |
+
self.e4 = VGGEncoder(4)
|
23 |
+
self.d4 = VGGDecoder(4)
|
24 |
+
|
25 |
+
def transform(self, cont_img, styl_img, cont_seg, styl_seg):
|
26 |
+
self.__compute_label_info(cont_seg, styl_seg)
|
27 |
+
|
28 |
+
sF4, sF3, sF2, sF1 = self.e4.forward_multiple(styl_img)
|
29 |
+
|
30 |
+
cF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3 = self.e4(cont_img)
|
31 |
+
sF4 = sF4.data.squeeze(0)
|
32 |
+
cF4 = cF4.data.squeeze(0)
|
33 |
+
# print(cont_seg)
|
34 |
+
csF4 = self.__feature_wct(cF4, sF4, cont_seg, styl_seg)
|
35 |
+
Im4 = self.d4(csF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3)
|
36 |
+
|
37 |
+
cF3, cpool_idx, cpool1, cpool_idx2, cpool2 = self.e3(Im4)
|
38 |
+
sF3 = sF3.data.squeeze(0)
|
39 |
+
cF3 = cF3.data.squeeze(0)
|
40 |
+
csF3 = self.__feature_wct(cF3, sF3, cont_seg, styl_seg)
|
41 |
+
Im3 = self.d3(csF3, cpool_idx, cpool1, cpool_idx2, cpool2)
|
42 |
+
|
43 |
+
cF2, cpool_idx, cpool = self.e2(Im3)
|
44 |
+
sF2 = sF2.data.squeeze(0)
|
45 |
+
cF2 = cF2.data.squeeze(0)
|
46 |
+
csF2 = self.__feature_wct(cF2, sF2, cont_seg, styl_seg)
|
47 |
+
Im2 = self.d2(csF2, cpool_idx, cpool)
|
48 |
+
|
49 |
+
cF1 = self.e1(Im2)
|
50 |
+
sF1 = sF1.data.squeeze(0)
|
51 |
+
cF1 = cF1.data.squeeze(0)
|
52 |
+
csF1 = self.__feature_wct(cF1, sF1, cont_seg, styl_seg)
|
53 |
+
Im1 = self.d1(csF1)
|
54 |
+
return Im1
|
55 |
+
|
56 |
+
def __compute_label_info(self, cont_seg, styl_seg):
|
57 |
+
if cont_seg.size == False or styl_seg.size == False:
|
58 |
+
return
|
59 |
+
max_label = np.max(cont_seg) + 1
|
60 |
+
self.label_set = np.unique(cont_seg)
|
61 |
+
self.label_indicator = np.zeros(max_label)
|
62 |
+
for l in self.label_set:
|
63 |
+
# if l==0:
|
64 |
+
# continue
|
65 |
+
is_valid = lambda a, b: a > 10 and b > 10 and a / b < 100 and b / a < 100
|
66 |
+
o_cont_mask = np.where(cont_seg.reshape(cont_seg.shape[0] * cont_seg.shape[1]) == l)
|
67 |
+
o_styl_mask = np.where(styl_seg.reshape(styl_seg.shape[0] * styl_seg.shape[1]) == l)
|
68 |
+
self.label_indicator[l] = is_valid(o_cont_mask[0].size, o_styl_mask[0].size)
|
69 |
+
|
70 |
+
def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
|
71 |
+
cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2)
|
72 |
+
styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2)
|
73 |
+
cont_feat_view = cont_feat.view(cont_c, -1).clone()
|
74 |
+
styl_feat_view = styl_feat.view(styl_c, -1).clone()
|
75 |
+
|
76 |
+
if cont_seg.size == False or styl_seg.size == False:
|
77 |
+
target_feature = self.__wct_core(cont_feat_view, styl_feat_view)
|
78 |
+
else:
|
79 |
+
target_feature = cont_feat.view(cont_c, -1).clone()
|
80 |
+
if len(cont_seg.shape) == 2:
|
81 |
+
t_cont_seg = np.asarray(Image.fromarray(cont_seg).resize((cont_w, cont_h), Image.NEAREST))
|
82 |
+
else:
|
83 |
+
t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST))
|
84 |
+
if len(styl_seg.shape) == 2:
|
85 |
+
t_styl_seg = np.asarray(Image.fromarray(styl_seg).resize((styl_w, styl_h), Image.NEAREST))
|
86 |
+
else:
|
87 |
+
t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST))
|
88 |
+
|
89 |
+
for l in self.label_set:
|
90 |
+
if self.label_indicator[l] == 0:
|
91 |
+
continue
|
92 |
+
cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l)
|
93 |
+
styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l)
|
94 |
+
if cont_mask[0].size <= 0 or styl_mask[0].size <= 0:
|
95 |
+
continue
|
96 |
+
|
97 |
+
cont_indi = torch.LongTensor(cont_mask[0])
|
98 |
+
styl_indi = torch.LongTensor(styl_mask[0])
|
99 |
+
if self.is_cuda:
|
100 |
+
cont_indi = cont_indi.cuda(0)
|
101 |
+
styl_indi = styl_indi.cuda(0)
|
102 |
+
|
103 |
+
cFFG = torch.index_select(cont_feat_view, 1, cont_indi)
|
104 |
+
sFFG = torch.index_select(styl_feat_view, 1, styl_indi)
|
105 |
+
# print(len(cont_indi))
|
106 |
+
# print(len(styl_indi))
|
107 |
+
tmp_target_feature = self.__wct_core(cFFG, sFFG)
|
108 |
+
# print(tmp_target_feature.size())
|
109 |
+
if torch.__version__ >= "0.4.0":
|
110 |
+
# This seems to be a bug in PyTorch 0.4.0 to me.
|
111 |
+
new_target_feature = torch.transpose(target_feature, 1, 0)
|
112 |
+
new_target_feature.index_copy_(0, cont_indi, \
|
113 |
+
torch.transpose(tmp_target_feature,1,0))
|
114 |
+
target_feature = torch.transpose(new_target_feature, 1, 0)
|
115 |
+
else:
|
116 |
+
target_feature.index_copy_(1, cont_indi, tmp_target_feature)
|
117 |
+
|
118 |
+
target_feature = target_feature.view_as(cont_feat)
|
119 |
+
ccsF = target_feature.float().unsqueeze(0)
|
120 |
+
return ccsF
|
121 |
+
|
122 |
+
def __wct_core(self, cont_feat, styl_feat):
|
123 |
+
cFSize = cont_feat.size()
|
124 |
+
c_mean = torch.mean(cont_feat, 1) # c x (h x w)
|
125 |
+
c_mean = c_mean.unsqueeze(1).expand_as(cont_feat)
|
126 |
+
cont_feat = cont_feat - c_mean
|
127 |
+
|
128 |
+
iden = torch.eye(cFSize[0]) # .double()
|
129 |
+
if self.is_cuda:
|
130 |
+
iden = iden.cuda()
|
131 |
+
|
132 |
+
contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFSize[1] - 1) + iden
|
133 |
+
# del iden
|
134 |
+
c_u, c_e, c_v = torch.svd(contentConv, some=False)
|
135 |
+
# c_e2, c_v = torch.eig(contentConv, True)
|
136 |
+
# c_e = c_e2[:,0]
|
137 |
+
|
138 |
+
k_c = cFSize[0]
|
139 |
+
for i in range(cFSize[0] - 1, -1, -1):
|
140 |
+
if c_e[i] >= 0.00001:
|
141 |
+
k_c = i + 1
|
142 |
+
break
|
143 |
+
|
144 |
+
sFSize = styl_feat.size()
|
145 |
+
s_mean = torch.mean(styl_feat, 1)
|
146 |
+
styl_feat = styl_feat - s_mean.unsqueeze(1).expand_as(styl_feat)
|
147 |
+
styleConv = torch.mm(styl_feat, styl_feat.t()).div(sFSize[1] - 1)
|
148 |
+
s_u, s_e, s_v = torch.svd(styleConv, some=False)
|
149 |
+
|
150 |
+
k_s = sFSize[0]
|
151 |
+
for i in range(sFSize[0] - 1, -1, -1):
|
152 |
+
if s_e[i] >= 0.00001:
|
153 |
+
k_s = i + 1
|
154 |
+
break
|
155 |
+
|
156 |
+
c_d = (c_e[0:k_c]).pow(-0.5)
|
157 |
+
step1 = torch.mm(c_v[:, 0:k_c], torch.diag(c_d))
|
158 |
+
step2 = torch.mm(step1, (c_v[:, 0:k_c].t()))
|
159 |
+
whiten_cF = torch.mm(step2, cont_feat)
|
160 |
+
|
161 |
+
s_d = (s_e[0:k_s]).pow(0.5)
|
162 |
+
targetFeature = torch.mm(torch.mm(torch.mm(s_v[:, 0:k_s], torch.diag(s_d)), (s_v[:, 0:k_s].t())), whiten_cF)
|
163 |
+
targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)
|
164 |
+
return targetFeature
|
165 |
+
|
166 |
+
@property
def is_cuda(self):
    """Report whether the first parameter yielded by ``self.parameters()`` is on CUDA."""
    first_param = next(self.parameters())
    return first_param.is_cuda
|
169 |
+
|
170 |
+
def forward(self, *input):
    """No-op placeholder: accepts any positional arguments and returns ``None``."""
    return None
|
utils/shared_utils.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from rembg import remove
|
3 |
+
import io
|
4 |
+
|
5 |
+
# Apply the transformations needed
|
6 |
+
from torch import autocast, nn
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch
|
10 |
+
import torchvision.transforms as transforms
|
11 |
+
import torchvision.utils as utils
|
12 |
+
import torch.nn as nn
|
13 |
+
import pyrootutils
|
14 |
+
from PIL import Image
|
15 |
+
import numpy as np
|
16 |
+
from utils.photo_wct import PhotoWCT
|
17 |
+
from utils.photo_smooth import Propagator
|
18 |
+
|
19 |
+
# Load models
# The project root is resolved from the current working directory.
root = pyrootutils.setup_root(Path.cwd(), pythonpath=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model
p_wct = PhotoWCT()
# map_location remaps the stored tensors onto the device chosen above, so a
# checkpoint that was saved from a GPU still loads on a CPU-only host.
p_wct.load_state_dict(
    torch.load(root/"models/components/photo_wct.pth", map_location=device)
)
p_pro = Propagator()
# Module-level aliases used by style_transfer() and smoother() below.
stylization_module=p_wct
smoothing_module=p_pro
|
28 |
+
|
29 |
+
|
30 |
+
#Dependencies - to be installed -
|
31 |
+
#!pip install replicate
|
32 |
+
#Token - To be authenticated -
|
33 |
+
#API TOKEN - <redacted: never commit real credentials; revoke any token that was published here>
|
34 |
+
#To be declared as an environment variable -
|
35 |
+
#export REPLICATE_API_TOKEN =
|
36 |
+
import replicate
|
37 |
+
import os
|
38 |
+
import requests
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
def stableDiffusionAPICall(text_prompt):
    """Generate an image for ``text_prompt`` via Replicate and return it as a PIL image.

    The Replicate credential is read from the ``REPLICATE_API_TOKEN``
    environment variable. It is deliberately NOT embedded in the source:
    a token committed to the repository is readable by anyone with access
    to it and must be treated as compromised.

    :param text_prompt: prompt passed as ``prompt=`` to the model's ``predict``.
    :return: ``PIL.Image`` opened from the bytes downloaded from the first
        element of the prediction result.
    :raises RuntimeError: if ``REPLICATE_API_TOKEN`` is unset or empty.
    :raises requests.HTTPError: if downloading the generated image fails.
    """
    if not os.environ.get('REPLICATE_API_TOKEN'):
        raise RuntimeError(
            "REPLICATE_API_TOKEN is not set; export it in the environment "
            "(or configure it as a secret) before requesting an image."
        )

    model = replicate.models.get("stability-ai/stable-diffusion")
    # The first element of the prediction result is used as the image URL.
    image_url = model.predict(prompt=text_prompt)[0]

    # Bound the download time and fail loudly on an HTTP error instead of
    # handing an error page to the image decoder.
    response = requests.get(image_url, timeout=60)
    response.raise_for_status()

    return Image.open(io.BytesIO(response.content))
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
def memory_limit_image_resize(cont_img):
    """Bound the size of ``cont_img`` in place and report the resulting size.

    ``cont_img.thumbnail`` is called with an aspect-ratio-preserving target
    when the longer side is below ``MINSIZE`` or the shorter side is above
    ``MAXSIZE``. The before/after dimensions are always printed.

    :param cont_img: image object exposing ``width``, ``height`` and
        ``thumbnail``; it is modified by the ``thumbnail`` calls.
    :return: ``(width, height)`` of ``cont_img`` after processing.
    """
    # prevent too small or too big images
    MINSIZE = 400
    MAXSIZE = 800
    orig_width, orig_height = cont_img.width, cont_img.height

    # NOTE(review): PIL documents thumbnail() as only ever shrinking an
    # image, so this "too small" branch may have no effect - confirm.
    if max(cont_img.width, cont_img.height) < MINSIZE:
        if cont_img.width > cont_img.height:
            target = (int(cont_img.width*1.0/cont_img.height*MINSIZE), MINSIZE)
        else:
            target = (MINSIZE, int(cont_img.height*1.0/cont_img.width*MINSIZE))
        cont_img.thumbnail(target, Image.BICUBIC)

    # Re-read the dimensions: the branch above may already have changed them.
    if min(cont_img.width, cont_img.height) > MAXSIZE:
        if cont_img.width > cont_img.height:
            target = (MAXSIZE, int(cont_img.height*1.0/cont_img.width*MAXSIZE))
        else:
            target = (int(cont_img.width*1.0/cont_img.height*MAXSIZE), MAXSIZE)
        cont_img.thumbnail(target, Image.BICUBIC)

    final_size = (cont_img.width, cont_img.height)
    print("Resize image: (%d,%d)->(%d,%d)" % ((orig_width, orig_height) + final_size))
    return final_size
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
def superimpose(input_img,back_img):
    """Paste the ``rembg``-processed ``input_img`` onto ``back_img`` at the top-left corner.

    The result of ``remove(input_img)`` is used both as the pasted image
    and as the paste mask. ``back_img`` is modified in place and returned.
    """
    cutout = remove(input_img)
    top_left = (0, 0)
    back_img.paste(cutout, top_left, cutout)
    return back_img
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
def style_transfer(cont_img,styl_img):
    """Stylise ``cont_img`` with the look of ``styl_img`` and return a PIL image.

    Both images are first passed through ``memory_limit_image_resize``
    (which modifies them in place), converted to tensors with a leading
    batch dimension, moved to the module-level ``device`` and handed to
    ``stylization_module.transform`` with empty segmentation arrays.

    :param cont_img: content PIL image (resized in place).
    :param styl_img: style PIL image (resized in place).
    :return: stylised image as a ``PIL.Image``.
    """
    with torch.no_grad():
        new_cw, new_ch = memory_limit_image_resize(cont_img)
        memory_limit_image_resize(styl_img)

        # NOTE(review): these dimensions are read after the in-place resize
        # above, so they always equal (new_cw, new_ch) and the resize-back
        # branch further down never runs. Capture the size before resizing
        # if restoring the original resolution is intended.
        cw, ch = cont_img.width, cont_img.height

        # .to(device) is a no-op when everything is already on that device,
        # so no explicit CUDA check is needed.
        cont_tensor = transforms.ToTensor()(cont_img).unsqueeze(0).to(device)
        styl_tensor = transforms.ToTensor()(styl_img).unsqueeze(0).to(device)
        stylization_module.to(device)

        # Empty segmentation placeholders, always passed as arrays
        # regardless of the device in use.
        cont_seg = np.asarray([])
        styl_seg = np.asarray([])

        stylized_img = stylization_module.transform(cont_tensor, styl_tensor, cont_seg, styl_seg)
        if ch != new_ch or cw != new_cw:
            # nn.functional.upsample is deprecated in favour of interpolate.
            stylized_img = nn.functional.interpolate(stylized_img, size=(ch, cw), mode='bilinear')

        grid = utils.make_grid(stylized_img.data, nrow=1, padding=0)
        # Scale to 0..255, reorder CHW -> HWC and hand the bytes to PIL.
        ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        return Image.fromarray(ndarr)
|
115 |
+
|
116 |
+
def smoother(stylized_img, over_img):
    """Return ``smoothing_module.process(stylized_img, over_img)``.

    Thin wrapper around the module-level ``smoothing_module``.
    """
    return smoothing_module.process(stylized_img, over_img)
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == "__main__":
    # Manual check: composite the sample foreground onto the sample
    # background and write the overlay next to them.
    root = pyrootutils.setup_root(__file__, pythonpath=True)
    fg_path = root/"notebooks/profile_new.png"
    bg_path = root/"notebooks/back_img.png"

    # resize() returns a new image, so the source files can be closed as
    # soon as the resized copies exist.
    with Image.open(fg_path) as fg_src, Image.open(bg_path) as bg_src:
        fg_img = fg_src.resize((800,800))
        bg_img = bg_src.resize((800,800))

    img = superimpose(fg_img,bg_img)
    img.save(root/"notebooks/overlay.png")
|
135 |
+
|
136 |
+
|
utils/smooth_filter.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
|
3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
4 |
+
"""
|
5 |
+
src = '''
|
6 |
+
#include "/usr/local/cuda/include/math_functions.h"
|
7 |
+
#define TB 256
|
8 |
+
#define EPS 1e-7
|
9 |
+
|
10 |
+
__device__ bool InverseMat4x4(double m_in[4][4], double inv_out[4][4]) {
|
11 |
+
double m[16], inv[16];
|
12 |
+
for (int i = 0; i < 4; i++) {
|
13 |
+
for (int j = 0; j < 4; j++) {
|
14 |
+
m[i * 4 + j] = m_in[i][j];
|
15 |
+
}
|
16 |
+
}
|
17 |
+
|
18 |
+
inv[0] = m[5] * m[10] * m[15] -
|
19 |
+
m[5] * m[11] * m[14] -
|
20 |
+
m[9] * m[6] * m[15] +
|
21 |
+
m[9] * m[7] * m[14] +
|
22 |
+
m[13] * m[6] * m[11] -
|
23 |
+
m[13] * m[7] * m[10];
|
24 |
+
|
25 |
+
inv[4] = -m[4] * m[10] * m[15] +
|
26 |
+
m[4] * m[11] * m[14] +
|
27 |
+
m[8] * m[6] * m[15] -
|
28 |
+
m[8] * m[7] * m[14] -
|
29 |
+
m[12] * m[6] * m[11] +
|
30 |
+
m[12] * m[7] * m[10];
|
31 |
+
|
32 |
+
inv[8] = m[4] * m[9] * m[15] -
|
33 |
+
m[4] * m[11] * m[13] -
|
34 |
+
m[8] * m[5] * m[15] +
|
35 |
+
m[8] * m[7] * m[13] +
|
36 |
+
m[12] * m[5] * m[11] -
|
37 |
+
m[12] * m[7] * m[9];
|
38 |
+
|
39 |
+
inv[12] = -m[4] * m[9] * m[14] +
|
40 |
+
m[4] * m[10] * m[13] +
|
41 |
+
m[8] * m[5] * m[14] -
|
42 |
+
m[8] * m[6] * m[13] -
|
43 |
+
m[12] * m[5] * m[10] +
|
44 |
+
m[12] * m[6] * m[9];
|
45 |
+
|
46 |
+
inv[1] = -m[1] * m[10] * m[15] +
|
47 |
+
m[1] * m[11] * m[14] +
|
48 |
+
m[9] * m[2] * m[15] -
|
49 |
+
m[9] * m[3] * m[14] -
|
50 |
+
m[13] * m[2] * m[11] +
|
51 |
+
m[13] * m[3] * m[10];
|
52 |
+
|
53 |
+
inv[5] = m[0] * m[10] * m[15] -
|
54 |
+
m[0] * m[11] * m[14] -
|
55 |
+
m[8] * m[2] * m[15] +
|
56 |
+
m[8] * m[3] * m[14] +
|
57 |
+
m[12] * m[2] * m[11] -
|
58 |
+
m[12] * m[3] * m[10];
|
59 |
+
|
60 |
+
inv[9] = -m[0] * m[9] * m[15] +
|
61 |
+
m[0] * m[11] * m[13] +
|
62 |
+
m[8] * m[1] * m[15] -
|
63 |
+
m[8] * m[3] * m[13] -
|
64 |
+
m[12] * m[1] * m[11] +
|
65 |
+
m[12] * m[3] * m[9];
|
66 |
+
|
67 |
+
inv[13] = m[0] * m[9] * m[14] -
|
68 |
+
m[0] * m[10] * m[13] -
|
69 |
+
m[8] * m[1] * m[14] +
|
70 |
+
m[8] * m[2] * m[13] +
|
71 |
+
m[12] * m[1] * m[10] -
|
72 |
+
m[12] * m[2] * m[9];
|
73 |
+
|
74 |
+
inv[2] = m[1] * m[6] * m[15] -
|
75 |
+
m[1] * m[7] * m[14] -
|
76 |
+
m[5] * m[2] * m[15] +
|
77 |
+
m[5] * m[3] * m[14] +
|
78 |
+
m[13] * m[2] * m[7] -
|
79 |
+
m[13] * m[3] * m[6];
|
80 |
+
|
81 |
+
inv[6] = -m[0] * m[6] * m[15] +
|
82 |
+
m[0] * m[7] * m[14] +
|
83 |
+
m[4] * m[2] * m[15] -
|
84 |
+
m[4] * m[3] * m[14] -
|
85 |
+
m[12] * m[2] * m[7] +
|
86 |
+
m[12] * m[3] * m[6];
|
87 |
+
|
88 |
+
inv[10] = m[0] * m[5] * m[15] -
|
89 |
+
m[0] * m[7] * m[13] -
|
90 |
+
m[4] * m[1] * m[15] +
|
91 |
+
m[4] * m[3] * m[13] +
|
92 |
+
m[12] * m[1] * m[7] -
|
93 |
+
m[12] * m[3] * m[5];
|
94 |
+
|
95 |
+
inv[14] = -m[0] * m[5] * m[14] +
|
96 |
+
m[0] * m[6] * m[13] +
|
97 |
+
m[4] * m[1] * m[14] -
|
98 |
+
m[4] * m[2] * m[13] -
|
99 |
+
m[12] * m[1] * m[6] +
|
100 |
+
m[12] * m[2] * m[5];
|
101 |
+
|
102 |
+
inv[3] = -m[1] * m[6] * m[11] +
|
103 |
+
m[1] * m[7] * m[10] +
|
104 |
+
m[5] * m[2] * m[11] -
|
105 |
+
m[5] * m[3] * m[10] -
|
106 |
+
m[9] * m[2] * m[7] +
|
107 |
+
m[9] * m[3] * m[6];
|
108 |
+
|
109 |
+
inv[7] = m[0] * m[6] * m[11] -
|
110 |
+
m[0] * m[7] * m[10] -
|
111 |
+
m[4] * m[2] * m[11] +
|
112 |
+
m[4] * m[3] * m[10] +
|
113 |
+
m[8] * m[2] * m[7] -
|
114 |
+
m[8] * m[3] * m[6];
|
115 |
+
|
116 |
+
inv[11] = -m[0] * m[5] * m[11] +
|
117 |
+
m[0] * m[7] * m[9] +
|
118 |
+
m[4] * m[1] * m[11] -
|
119 |
+
m[4] * m[3] * m[9] -
|
120 |
+
m[8] * m[1] * m[7] +
|
121 |
+
m[8] * m[3] * m[5];
|
122 |
+
|
123 |
+
inv[15] = m[0] * m[5] * m[10] -
|
124 |
+
m[0] * m[6] * m[9] -
|
125 |
+
m[4] * m[1] * m[10] +
|
126 |
+
m[4] * m[2] * m[9] +
|
127 |
+
m[8] * m[1] * m[6] -
|
128 |
+
m[8] * m[2] * m[5];
|
129 |
+
|
130 |
+
double det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12];
|
131 |
+
|
132 |
+
if (abs(det) < 1e-9) {
|
133 |
+
return false;
|
134 |
+
}
|
135 |
+
|
136 |
+
|
137 |
+
det = 1.0 / det;
|
138 |
+
|
139 |
+
for (int i = 0; i < 4; i++) {
|
140 |
+
for (int j = 0; j < 4; j++) {
|
141 |
+
inv_out[i][j] = inv[i * 4 + j] * det;
|
142 |
+
}
|
143 |
+
}
|
144 |
+
|
145 |
+
return true;
|
146 |
+
}
|
147 |
+
|
148 |
+
extern "C"
|
149 |
+
__global__ void best_local_affine_kernel(
|
150 |
+
float *output, float *input, float *affine_model,
|
151 |
+
int h, int w, float epsilon, int kernel_radius
|
152 |
+
)
|
153 |
+
{
|
154 |
+
int size = h * w;
|
155 |
+
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
156 |
+
|
157 |
+
if (id < size) {
|
158 |
+
int x = id % w, y = id / w;
|
159 |
+
|
160 |
+
double Mt_M[4][4] = {}; // 4x4
|
161 |
+
double invMt_M[4][4] = {};
|
162 |
+
double Mt_S[3][4] = {}; // RGB -> 1x4
|
163 |
+
double A[3][4] = {};
|
164 |
+
for (int i = 0; i < 4; i++)
|
165 |
+
for (int j = 0; j < 4; j++) {
|
166 |
+
Mt_M[i][j] = 0, invMt_M[i][j] = 0;
|
167 |
+
if (i != 3) {
|
168 |
+
Mt_S[i][j] = 0, A[i][j] = 0;
|
169 |
+
if (i == j)
|
170 |
+
Mt_M[i][j] = 1e-3;
|
171 |
+
}
|
172 |
+
}
|
173 |
+
|
174 |
+
for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
|
175 |
+
for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
|
176 |
+
|
177 |
+
int xx = x + dx, yy = y + dy;
|
178 |
+
int id2 = yy * w + xx;
|
179 |
+
|
180 |
+
if (0 <= xx && xx < w && 0 <= yy && yy < h) {
|
181 |
+
|
182 |
+
Mt_M[0][0] += input[id2 + 2*size] * input[id2 + 2*size];
|
183 |
+
Mt_M[0][1] += input[id2 + 2*size] * input[id2 + size];
|
184 |
+
Mt_M[0][2] += input[id2 + 2*size] * input[id2];
|
185 |
+
Mt_M[0][3] += input[id2 + 2*size];
|
186 |
+
|
187 |
+
Mt_M[1][0] += input[id2 + size] * input[id2 + 2*size];
|
188 |
+
Mt_M[1][1] += input[id2 + size] * input[id2 + size];
|
189 |
+
Mt_M[1][2] += input[id2 + size] * input[id2];
|
190 |
+
Mt_M[1][3] += input[id2 + size];
|
191 |
+
|
192 |
+
Mt_M[2][0] += input[id2] * input[id2 + 2*size];
|
193 |
+
Mt_M[2][1] += input[id2] * input[id2 + size];
|
194 |
+
Mt_M[2][2] += input[id2] * input[id2];
|
195 |
+
Mt_M[2][3] += input[id2];
|
196 |
+
|
197 |
+
Mt_M[3][0] += input[id2 + 2*size];
|
198 |
+
Mt_M[3][1] += input[id2 + size];
|
199 |
+
Mt_M[3][2] += input[id2];
|
200 |
+
Mt_M[3][3] += 1;
|
201 |
+
|
202 |
+
Mt_S[0][0] += input[id2 + 2*size] * output[id2 + 2*size];
|
203 |
+
Mt_S[0][1] += input[id2 + size] * output[id2 + 2*size];
|
204 |
+
Mt_S[0][2] += input[id2] * output[id2 + 2*size];
|
205 |
+
Mt_S[0][3] += output[id2 + 2*size];
|
206 |
+
|
207 |
+
Mt_S[1][0] += input[id2 + 2*size] * output[id2 + size];
|
208 |
+
Mt_S[1][1] += input[id2 + size] * output[id2 + size];
|
209 |
+
Mt_S[1][2] += input[id2] * output[id2 + size];
|
210 |
+
Mt_S[1][3] += output[id2 + size];
|
211 |
+
|
212 |
+
Mt_S[2][0] += input[id2 + 2*size] * output[id2];
|
213 |
+
Mt_S[2][1] += input[id2 + size] * output[id2];
|
214 |
+
Mt_S[2][2] += input[id2] * output[id2];
|
215 |
+
Mt_S[2][3] += output[id2];
|
216 |
+
}
|
217 |
+
}
|
218 |
+
}
|
219 |
+
|
220 |
+
bool success = InverseMat4x4(Mt_M, invMt_M);
|
221 |
+
|
222 |
+
for (int i = 0; i < 3; i++) {
|
223 |
+
for (int j = 0; j < 4; j++) {
|
224 |
+
for (int k = 0; k < 4; k++) {
|
225 |
+
A[i][j] += invMt_M[j][k] * Mt_S[i][k];
|
226 |
+
}
|
227 |
+
}
|
228 |
+
}
|
229 |
+
|
230 |
+
for (int i = 0; i < 3; i++) {
|
231 |
+
for (int j = 0; j < 4; j++) {
|
232 |
+
int affine_id = i * 4 + j;
|
233 |
+
affine_model[12 * id + affine_id] = A[i][j];
|
234 |
+
}
|
235 |
+
}
|
236 |
+
}
|
237 |
+
return ;
|
238 |
+
}
|
239 |
+
|
240 |
+
extern "C"
|
241 |
+
__global__ void bilateral_smooth_kernel(
|
242 |
+
float *affine_model, float *filtered_affine_model, float *guide,
|
243 |
+
int h, int w, int kernel_radius, float sigma1, float sigma2
|
244 |
+
)
|
245 |
+
{
|
246 |
+
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
247 |
+
int size = h * w;
|
248 |
+
if (id < size) {
|
249 |
+
int x = id % w;
|
250 |
+
int y = id / w;
|
251 |
+
|
252 |
+
double sum_affine[12] = {};
|
253 |
+
double sum_weight = 0;
|
254 |
+
for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
|
255 |
+
for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
|
256 |
+
int yy = y + dy, xx = x + dx;
|
257 |
+
int id2 = yy * w + xx;
|
258 |
+
if (0 <= xx && xx < w && 0 <= yy && yy < h) {
|
259 |
+
float color_diff1 = guide[yy*w + xx] - guide[y*w + x];
|
260 |
+
float color_diff2 = guide[yy*w + xx + size] - guide[y*w + x + size];
|
261 |
+
float color_diff3 = guide[yy*w + xx + 2*size] - guide[y*w + x + 2*size];
|
262 |
+
float color_diff_sqr =
|
263 |
+
(color_diff1*color_diff1 + color_diff2*color_diff2 + color_diff3*color_diff3) / 3;
|
264 |
+
|
265 |
+
float v1 = exp(-(dx * dx + dy * dy) / (2 * sigma1 * sigma1));
|
266 |
+
float v2 = exp(-(color_diff_sqr) / (2 * sigma2 * sigma2));
|
267 |
+
float weight = v1 * v2;
|
268 |
+
|
269 |
+
for (int i = 0; i < 3; i++) {
|
270 |
+
for (int j = 0; j < 4; j++) {
|
271 |
+
int affine_id = i * 4 + j;
|
272 |
+
sum_affine[affine_id] += weight * affine_model[id2*12 + affine_id];
|
273 |
+
}
|
274 |
+
}
|
275 |
+
sum_weight += weight;
|
276 |
+
}
|
277 |
+
}
|
278 |
+
}
|
279 |
+
|
280 |
+
for (int i = 0; i < 3; i++) {
|
281 |
+
for (int j = 0; j < 4; j++) {
|
282 |
+
int affine_id = i * 4 + j;
|
283 |
+
filtered_affine_model[id*12 + affine_id] = sum_affine[affine_id] / sum_weight;
|
284 |
+
}
|
285 |
+
}
|
286 |
+
}
|
287 |
+
return ;
|
288 |
+
}
|
289 |
+
|
290 |
+
|
291 |
+
extern "C"
|
292 |
+
__global__ void reconstruction_best_kernel(
|
293 |
+
float *input, float *filtered_affine_model, float *filtered_best_output,
|
294 |
+
int h, int w
|
295 |
+
)
|
296 |
+
{
|
297 |
+
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
298 |
+
int size = h * w;
|
299 |
+
if (id < size) {
|
300 |
+
double out1 =
|
301 |
+
input[id + 2*size] * filtered_affine_model[id*12 + 0] + // A[0][0] +
|
302 |
+
input[id + size] * filtered_affine_model[id*12 + 1] + // A[0][1] +
|
303 |
+
input[id] * filtered_affine_model[id*12 + 2] + // A[0][2] +
|
304 |
+
filtered_affine_model[id*12 + 3]; //A[0][3];
|
305 |
+
double out2 =
|
306 |
+
input[id + 2*size] * filtered_affine_model[id*12 + 4] + //A[1][0] +
|
307 |
+
input[id + size] * filtered_affine_model[id*12 + 5] + //A[1][1] +
|
308 |
+
input[id] * filtered_affine_model[id*12 + 6] + //A[1][2] +
|
309 |
+
filtered_affine_model[id*12 + 7]; //A[1][3];
|
310 |
+
double out3 =
|
311 |
+
input[id + 2*size] * filtered_affine_model[id*12 + 8] + //A[2][0] +
|
312 |
+
input[id + size] * filtered_affine_model[id*12 + 9] + //A[2][1] +
|
313 |
+
input[id] * filtered_affine_model[id*12 + 10] + //A[2][2] +
|
314 |
+
filtered_affine_model[id*12 + 11]; // A[2][3];
|
315 |
+
|
316 |
+
filtered_best_output[id] = out1;
|
317 |
+
filtered_best_output[id + size] = out2;
|
318 |
+
filtered_best_output[id + 2*size] = out3;
|
319 |
+
}
|
320 |
+
return ;
|
321 |
+
}
|
322 |
+
'''
|
323 |
+
|
324 |
+
import torch
|
325 |
+
import numpy as np
|
326 |
+
from PIL import Image
|
327 |
+
from cupy.cuda import function
|
328 |
+
from pynvrtc.compiler import Program
|
329 |
+
from collections import namedtuple
|
330 |
+
|
331 |
+
|
332 |
+
def smooth_local_affine(output_cpu, input_cpu, epsilon, patch, h, w, f_r, f_e):
    """Run the three CUDA kernels defined in ``src`` and return the result.

    The kernels are launched in this order on the current torch CUDA stream:
    ``best_local_affine_kernel``, ``bilateral_smooth_kernel``,
    ``reconstruction_best_kernel``. The CUDA source is compiled on every call.

    :param output_cpu: numpy array uploaded as the first kernel's ``output``.
    :param input_cpu: numpy array uploaded as ``input`` (and as the bilateral
        kernel's ``guide``); the returned array has its shape.
    :param epsilon: forwarded to ``best_local_affine_kernel`` as a float32.
    :param patch: patch size; ``(patch - 1) / 2`` is passed as the first
        kernel's ``kernel_radius``.
    :param h: image height passed to every kernel.
    :param w: image width passed to every kernel.
    :param f_r: bilateral kernel radius; also sets ``sigma1 = f_r / 3``.
    :param f_e: passed to the bilateral kernel as ``sigma2``.
    :return: numpy array copied back from ``filtered_best_output``.
    """
    # Compile the CUDA source to PTX and load it as a module.
    compiled_ptx = Program(src, 'best_local_affine_kernel.cu').compile(
        ['-I/usr/local/cuda/include']
    )
    module = function.Module()
    module.load(bytes(compiled_ptx.encode()))

    fit_affine = module.get_function('best_local_affine_kernel')
    smooth_affine = module.get_function('bilateral_smooth_kernel')
    reconstruct = module.get_function('reconstruction_best_kernel')

    # Minimal stand-in exposing the raw pointer of torch's current stream.
    Stream = namedtuple('Stream', ['ptr'])
    current = Stream(ptr=torch.cuda.current_stream().cuda_stream)

    sigma1 = f_r / 3
    sigma2 = f_e
    radius = (patch - 1) / 2

    # Device buffers: the result image plus one 12-coefficient model per pixel.
    num_pixels = h * w
    filtered_best_output = torch.zeros(np.shape(input_cpu)).cuda()
    affine_model = torch.zeros((num_pixels, 12)).cuda()
    filtered_affine_model = torch.zeros((num_pixels, 12)).cuda()
    input_ = torch.from_numpy(input_cpu).cuda()
    output_ = torch.from_numpy(output_cpu).cuda()

    # One thread per pixel, 256 threads per block; all launches share this.
    launch = dict(
        grid=(int(num_pixels / 256 + 1), 1),
        block=(256, 1, 1),
        stream=current,
    )
    height, width = np.int32(h), np.int32(w)

    fit_affine(
        args=[output_.data_ptr(), input_.data_ptr(), affine_model.data_ptr(),
              height, width, np.float32(epsilon), np.int32(radius)],
        **launch
    )
    smooth_affine(
        args=[affine_model.data_ptr(), filtered_affine_model.data_ptr(), input_.data_ptr(),
              height, width, np.int32(f_r), np.float32(sigma1), np.float32(sigma2)],
        **launch
    )
    reconstruct(
        args=[input_.data_ptr(), filtered_affine_model.data_ptr(),
              filtered_best_output.data_ptr(), height, width],
        **launch
    )

    return filtered_best_output.cpu().numpy()
|
378 |
+
|
379 |
+
|
380 |
+
def smooth_filter(initImg, contentImg, f_radius=15,f_edge=1e-1):
    '''
    Smooth a stylized image using the content image as reference.

    Both inputs are normalised to 3-channel RGB, channel-reversed, laid out
    channel-first, scaled to [0, 1] and passed to ``smooth_local_affine``.

    :param initImg: intermediate output. Either image path or PIL Image
    :param contentImg: content image output. Either path or PIL Image;
        it is resized to the size of ``initImg``.
    :param f_radius: forwarded to ``smooth_local_affine`` as ``f_r``.
    :param f_edge: forwarded to ``smooth_local_affine`` as ``f_e``.
    :return: stylized output image. PIL Image
    '''

    def _as_rgb(image):
        # Paths are opened (and closed again once converted); PIL images are
        # converted too, so RGBA / greyscale inputs cannot break the
        # three-channel layout assumed below.
        if isinstance(image, str):
            with Image.open(image) as opened:
                return opened.convert("RGB")
        return image.convert("RGB")

    def _to_planar_reversed(image):
        # HWC float32 array -> reversed channel order -> CHW.
        pixels = np.array(image, dtype=np.float32)
        return pixels[:, :, ::-1].transpose((2, 0, 1))

    stylized = _as_rgb(initImg)
    best_image_bgr = _to_planar_reversed(stylized)
    _, rows, cols = best_image_bgr.shape

    # PIL's resize takes (width, height), i.e. (columns, rows).
    content = _as_rgb(contentImg).resize((cols, rows))
    content_input = _to_planar_reversed(content)

    input_ = np.ascontiguousarray(content_input, dtype=np.float32) / 255.
    _, H, W = np.shape(input_)
    output_ = np.ascontiguousarray(best_image_bgr, dtype=np.float32) / 255.

    best_ = smooth_local_affine(output_, input_, 1e-7, 3, H, W, f_radius, f_edge)
    best_ = best_.transpose(1, 2, 0)
    return Image.fromarray(np.uint8(np.clip(best_ * 255., 0, 255.)))
|