Max Reimann
committed on
Commit
•
6124669
0
Parent(s):
Initial commit of app
Browse files- .gitattributes +37 -0
- .gitmodules +3 -0
- README.md +14 -0
- Whitebox_style_transfer.py +307 -0
- demo_config.py +1 -0
- pages/Apply_preset.py +118 -0
- pages/Local_edits.py +241 -0
- pages/Readme.py +30 -0
- requirements.txt +13 -0
.gitattributes
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
29 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "wise"]
|
2 |
+
path = wise
|
3 |
+
url = https://github.com/winfried-ripken/wise
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: White-box Style Transfer Editing (WISE)
|
3 |
+
emoji: 🎨
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: red
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.10.0
|
8 |
+
app_file: Whitebox_style_transfer.py
|
9 |
+
tags: [Style Transfer,Image Synthesis,Editing,Painting]
|
10 |
+
pinned: false
|
11 |
+
license: mit
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
Whitebox_style_transfer.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import datetime
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from io import BytesIO
|
6 |
+
from pathlib import Path
|
7 |
+
import numpy as np
|
8 |
+
import requests
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
PACKAGE_PARENT = 'wise'
|
14 |
+
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
|
15 |
+
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
|
16 |
+
|
17 |
+
import streamlit as st
|
18 |
+
from streamlit.logger import get_logger
|
19 |
+
from st_click_detector import click_detector
|
20 |
+
import streamlit.components.v1 as components
|
21 |
+
from streamlit_extras.switch_page_button import switch_page
|
22 |
+
|
23 |
+
from demo_config import HUGGING_FACE
|
24 |
+
from parameter_optimization.parametric_styletransfer import single_optimize
|
25 |
+
from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG
|
26 |
+
from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
|
27 |
+
import helpers.session_state as session_state
|
28 |
+
from helpers import torch_to_np, np_to_torch
|
29 |
+
from effects import get_default_settings, MinimalPipelineEffect
|
30 |
+
|
31 |
+
st.set_page_config(layout="wide")
|
32 |
+
BASE_URL = "https://ivpg.hpi3d.de/wise/wise-demo/images/"
|
33 |
+
LOGGER = get_logger(__name__)
|
34 |
+
|
35 |
+
effect_type = "minimal_pipeline"
|
36 |
+
|
37 |
+
if "click_counter" not in st.session_state:
|
38 |
+
st.session_state.click_counter = 1
|
39 |
+
|
40 |
+
if "action" not in st.session_state:
|
41 |
+
st.session_state["action"] = ""
|
42 |
+
|
43 |
+
content_urls = [
|
44 |
+
{
|
45 |
+
"name": "Portrait", "id": "portrait",
|
46 |
+
"src": BASE_URL + "/content/portrait.jpeg"
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"name": "Tuebingen", "id": "tubingen",
|
50 |
+
"src": BASE_URL + "/content/tubingen.jpeg"
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"name": "Colibri", "id": "colibri",
|
54 |
+
"src": BASE_URL + "/content/colibri.jpeg"
|
55 |
+
}
|
56 |
+
]
|
57 |
+
|
58 |
+
style_urls = [
|
59 |
+
{
|
60 |
+
"name": "Starry Night, Van Gogh", "id": "starry_night",
|
61 |
+
"src": BASE_URL + "/style/starry_night.jpg"
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"name": "The Scream, Edward Munch", "id": "the_scream",
|
65 |
+
"src": BASE_URL + "/style/the_scream.jpg"
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"name": "The Great Wave, Ukiyo-e", "id": "wave",
|
69 |
+
"src": BASE_URL + "/style/wave.jpg"
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"name": "Woman with Hat, Henry Matisse", "id": "woman_with_hat",
|
73 |
+
"src": BASE_URL + "/style/woman_with_hat.jpg"
|
74 |
+
}
|
75 |
+
]
|
76 |
+
|
77 |
+
|
78 |
+
def last_image_clicked(type="content", action=None, ):
    """Getter/setter for the most recently clicked gallery image of one kind.

    When ``action`` is given, it is recorded in the shared session state under
    a per-``type`` key (nothing is returned). Otherwise the stored value is
    returned, or None if nothing has been recorded yet.
    """
    key = "last_image_clicked" + "_" + type
    if action:
        session_state.get(**{key: action})
        return None
    state = session_state.get()
    return state[key] if key in state else None
|
86 |
+
|
87 |
+
|
88 |
+
@st.cache
def _retrieve_from_id(clicked, urls):
    """Resolve a clicked gallery id to its image.

    Looks up the entry whose ``id`` equals ``clicked`` and downloads it.
    Returns a (PIL image, source URL) pair. Cached so repeated clicks on
    the same thumbnail do not re-download.
    """
    matches = [entry["src"] for entry in urls if entry["id"] == clicked]
    src = matches[0]
    response = requests.get(src, stream=True)
    img = Image.open(response.raw)
    return img, src
|
93 |
+
|
94 |
+
|
95 |
+
def store_img_from_id(clicked, urls, imgtype):
    """Fetch the clicked gallery image and persist it in session state.

    Stores the PIL image, its render source URL, and the clicked id under
    ``{imgtype}_im`` / ``{imgtype}_render_src`` / ``{imgtype}_id``.
    """
    img, src = _retrieve_from_id(clicked, urls)
    updates = {
        f"{imgtype}_im": img,
        f"{imgtype}_render_src": src,
        f"{imgtype}_id": clicked,
    }
    session_state.get(**updates)
|
98 |
+
|
99 |
+
|
100 |
+
def img_choice_panel(imgtype, urls, default_choice, expanded):
    """Render the content/style picker: a clickable thumbnail gallery plus an upload form.

    Selection is written into the shared session state (``{imgtype}_im`` etc.)
    and the chosen image is previewed in the sidebar. ``st.session_state["action"]``
    tracks what triggered the current rerun so defaults are not re-applied.
    """
    with st.expander(f"Select {imgtype} image:", expanded=expanded):
        # Build an HTML strip of clickable thumbnails for click_detector.
        html_code = '<div class="column" style="display: flex; flex-wrap: wrap; padding: 0 4px;">'
        for url in urls:
            html_code += f"<a href='#' id='{url['id']}' style='padding: 0px 5px'><img height='160px' style='margin-top: 8px;' src='{url['src']}'></a>"
        html_code += "</div>"
        clicked = click_detector(html_code)

        # Apply the default image only on a fresh run — not when a rerun was
        # caused by an upload, a page switch, a slider move, or a reset.
        if not clicked and st.session_state["action"] not in ("uploaded", "switch_page_from_local_edits", "switch_page_from_presets", "slider_change", "reset"): # default val
            store_img_from_id(default_choice, urls, imgtype)

        st.write("OR: ")

        # Upload path: re-encode the file as base64 JPEG so it can be shown
        # inline in the sidebar via a data: URL.
        with st.form(imgtype + "-form", clear_on_submit=True):
            uploaded_im = st.file_uploader(f"Load {imgtype} image:", type=["png", "jpg"], )
            upload_pressed = st.form_submit_button("Upload")

            if upload_pressed and uploaded_im is not None:
                img = Image.open(uploaded_im)
                buffered = BytesIO()
                img.save(buffered, format="JPEG")
                encoded = base64.b64encode(buffered.getvalue()).decode()
                # session_state.get(uploaded_im=img, content_render_src=f"data:image/jpeg;base64,{encoded}")
                session_state.get(**{f"{imgtype}_im": img, f"{imgtype}_render_src": f"data:image/jpeg;base64,{encoded}",
                                     f"{imgtype}_id": "uploaded"})
                st.session_state["action"] = "uploaded"
                st.write("uploaded.")

        last_clicked = last_image_clicked(type=imgtype)
        print("last_clicked", last_clicked, "clicked", clicked, "action", st.session_state["action"] )
        if not upload_pressed and clicked != "":  # trigger when no file uploaded
            if last_clicked != clicked:  # only activate when content was actually clicked
                store_img_from_id(clicked, urls, imgtype)
                last_image_clicked(type=imgtype, action=clicked)
                st.session_state["action"] = "clicked"
                st.session_state.click_counter += 1  # hack to get page to reload at top

        # Sidebar preview of the currently selected image.
        state = session_state.get()
        st.sidebar.write(f'Selected {imgtype} image:')
        st.sidebar.markdown(f'<img src="{state[f"{imgtype}_render_src"]}" width=240px></img>', unsafe_allow_html=True)
|
140 |
+
|
141 |
+
|
142 |
+
def optimize(effect, preset, result_image_placeholder):
    """Run STROTSS NST plus parameter optimization for uploaded content/style.

    Offers an "Optimize Style Transfer" sidebar button. When pressed, runs the
    full optimization and returns ``(content_img_cuda, vp)``; otherwise returns
    the previously stored result from session state, or halts the script via
    ``st.stop()`` if none exists yet. Disabled on HuggingFace (see demo_config).
    """
    content = st.session_state["Content_im"]
    style = st.session_state["Style_im"]
    result_image_placeholder.text("<- Custom content/style needs to be style transferred")
    optimize_button = st.sidebar.button("Optimize Style Transfer")
    if optimize_button:
        if HUGGING_FACE:
            # The full optimization is too slow for the hosted demo.
            result_image_placeholder.warning("NST optimization is currently disabled in this HuggingFace Space because it takes ~5min to optimize. To try it out, please clone the repo and change the huggingface variable in demo_config.py")
            st.stop()

        result_image_placeholder.text("Executing NST to create reference image..")
        # Per-run output directory, timestamped to avoid collisions.
        base_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
        os.makedirs(base_dir)
        with st.spinner(text="Running NST"):
            # STROTSS produces the style-transfer reference image the
            # parametric optimization will try to match.
            reference = strotss(pil_resize_long_edge_to(content, 1024),
                                pil_resize_long_edge_to(style, 1024), content_weight=16.0,
                                device=torch.device("cuda"), space="uniform")
        progress_bar = result_image_placeholder.progress(0.0)
        ref_save_path = os.path.join(base_dir, "reference.jpg")
        content_save_path = os.path.join(base_dir, "content.jpg")
        resize_to = 720
        reference = pil_resize_long_edge_to(reference, resize_to)
        reference.save(ref_save_path)
        content.save(content_save_path)
        ST_CONFIG["n_iterations"] = 300
        with st.spinner(text="Optimizing parameters.."):
            # l1 loss against the NST reference; progress is reported back
            # into the placeholder via the iteration callback.
            vp, content_img_cuda = single_optimize(effect, preset, "l1", content_save_path, str(ref_save_path),
                                                   write_video=False, base_dir=base_dir,
                                                   iter_callback=lambda i: progress_bar.progress(
                                                       float(i) / ST_CONFIG["n_iterations"]))
        return content_img_cuda.detach(), vp.cuda().detach()
    else:
        if not "result_vp" in st.session_state:
            # Nothing optimized yet and no button press — halt this rerun.
            st.stop()
        else:
            return st.session_state["effect_input"], st.session_state["result_vp"]
|
178 |
+
|
179 |
+
|
180 |
+
@st.cache(hash_funcs={MinimalPipelineEffect: id})
def create_effect():
    """Construct the minimal-pipeline effect once per session (cached by object id).

    Returns the CUDA-enabled effect together with its default preset.
    """
    effect, preset, _param_set = get_default_settings(effect_type)
    effect.enable_checkpoints()
    effect.cuda()
    return effect, preset
|
186 |
+
|
187 |
+
|
188 |
+
def load_visual_params(vp_path: str, img_org: Image.Image, org_cuda: torch.Tensor, effect) -> torch.Tensor:
    """Load precomputed visual parameters, falling back to the preset.

    If ``vp_path`` exists and the stored tensor's channel count matches the
    effect's parameter list, it is resized to the working image resolution and
    returned. Otherwise a preset tensor is generated and saved to ``vp_path``.

    Note: the annotation was ``img_org: Image`` (the PIL *module*), which is
    not a valid type; it is now ``Image.Image``.
    """
    if Path(vp_path).exists():
        vp = torch.load(vp_path).detach().clone()
        # Resize the parameter maps to the working image resolution.
        vp = F.interpolate(vp, (img_org.height, img_org.width))
        # Only reuse the stored tensor if its parameter count still matches
        # the effect (stale files fall through to the preset below).
        if len(effect.vpd.vp_ranges) == vp.shape[1]:
            return vp
    # use preset and save it
    # NOTE(review): `preset` is a module-level global set by create_effect();
    # consider passing it explicitly to remove the hidden dependency.
    vp = effect.vpd.preset_tensor(preset, org_cuda, add_local_dims=True)
    torch.save(vp, vp_path)
    return vp
|
198 |
+
|
199 |
+
|
200 |
+
# @st.cache(hash_funcs={torch.Tensor: id})
@st.experimental_memo
def load_params(content_id, style_id):#, effect):
    """Load the precomputed input image and visual-parameter tensor for a
    content/style combination from the ``precomputed`` directory.
    """
    param_dir = os.path.join("precomputed", effect_type, content_id, style_id)
    img_org = Image.open(os.path.join(param_dir, "input.png"))
    content_cuda = np_to_torch(img_org).cuda()
    vp = load_visual_params(os.path.join(param_dir, "vp.pt"), img_org, content_cuda, effect)
    return content_cuda, vp
|
209 |
+
|
210 |
+
|
211 |
+
def render_effect(effect, content_cuda, vp):
    """Apply the effect (no gradients) and convert the result to a PIL image."""
    with torch.no_grad():
        rendered = effect(content_cuda, vp)
    return Image.fromarray((torch_to_np(rendered) * 255.0).astype(np.uint8))
|
216 |
+
|
217 |
+
|
218 |
+
result_container = st.container()
|
219 |
+
coll1, coll2 = result_container.columns([3,2])
|
220 |
+
coll1.header("Result")
|
221 |
+
coll2.header("Global Edits")
|
222 |
+
result_image_placeholder = coll1.empty()
|
223 |
+
result_image_placeholder.markdown("## loading..")
|
224 |
+
|
225 |
+
img_choice_panel("Content", content_urls, "portrait", expanded=True)
|
226 |
+
img_choice_panel("Style", style_urls, "starry_night", expanded=True)
|
227 |
+
|
228 |
+
state = session_state.get()
|
229 |
+
content_id = state["Content_id"]
|
230 |
+
style_id = state["Style_id"]
|
231 |
+
|
232 |
+
effect, preset = create_effect()
|
233 |
+
|
234 |
+
print("content id, style id", content_id, style_id )
|
235 |
+
if st.session_state["action"] == "uploaded":
|
236 |
+
content_img, _vp = optimize(effect, preset, result_image_placeholder)
|
237 |
+
elif st.session_state["action"] in ("switch_page_from_local_edits", "switch_page_from_presets", "slider_change") or \
|
238 |
+
content_id == "uploaded" or style_id == "uploaded":
|
239 |
+
print("restore param")
|
240 |
+
_vp = st.session_state["result_vp"]
|
241 |
+
content_img = st.session_state["effect_input"]
|
242 |
+
else:
|
243 |
+
print("load_params")
|
244 |
+
content_img, _vp = load_params(content_id, style_id)#, effect)
|
245 |
+
|
246 |
+
vp = torch.clone(_vp)
|
247 |
+
|
248 |
+
|
249 |
+
def reset_params(means, names):
    """Restore each named parameter slider widget to its original mean value."""
    for mean, name in zip(means, names):
        st.session_state["slider_" + name] = mean
|
252 |
+
|
253 |
+
def on_slider():
    # Record that this rerun was caused by a slider move so the page logic
    # restores the previous result instead of re-applying defaults.
    st.session_state["action"] = "slider_change"
|
255 |
+
|
256 |
+
|
257 |
+
with coll2:
|
258 |
+
show_params_names = [ 'bumpScale', "bumpOpacity", "contourOpacity"]
|
259 |
+
display_means = []
|
260 |
+
def create_slider(name):
|
261 |
+
mean = torch.mean(vp[:, effect.vpd.name2idx[name]]).item()
|
262 |
+
display_mean = mean + 0.5
|
263 |
+
display_means.append(display_mean)
|
264 |
+
if "slider_" + name not in st.session_state or st.session_state["action"] != "slider_change":
|
265 |
+
st.session_state["slider_" + name] = display_mean
|
266 |
+
slider = st.slider(f"Mean {name}: ", 0.0, 1.0, step=0.05, key="slider_" + name, on_change=on_slider)
|
267 |
+
vp[:, effect.vpd.name2idx[name]] += slider - display_mean
|
268 |
+
vp.clamp_(-0.5, 0.5)
|
269 |
+
|
270 |
+
for name in show_params_names:
|
271 |
+
create_slider(name)
|
272 |
+
|
273 |
+
others_idx = set(range(len(effect.vpd.vp_ranges))) - set([effect.vpd.name2idx[name] for name in show_params_names])
|
274 |
+
others_names = [effect.vpd.vp_ranges[i][0] for i in sorted(list(others_idx))]
|
275 |
+
other_param = st.selectbox("Other parameters: ", others_names)
|
276 |
+
create_slider(other_param)
|
277 |
+
|
278 |
+
|
279 |
+
reset_button = st.button("Reset Parameters", on_click=reset_params, args=(display_means, show_params_names))
|
280 |
+
if reset_button:
|
281 |
+
st.session_state["action"] = "reset"
|
282 |
+
st.experimental_rerun()
|
283 |
+
|
284 |
+
edit_locally_btn = st.button("Edit Local Parameter Maps")
|
285 |
+
if edit_locally_btn:
|
286 |
+
switch_page("Local_edits")
|
287 |
+
|
288 |
+
img_res = render_effect(effect, content_img, vp)
|
289 |
+
|
290 |
+
st.session_state["result_vp"] = vp
|
291 |
+
st.session_state["effect_input"] = content_img
|
292 |
+
st.session_state["last_result"] = img_res
|
293 |
+
|
294 |
+
with coll1:
|
295 |
+
# width = int(img_res.width * 500 / img_res.height)
|
296 |
+
result_image_placeholder.image(img_res)#, width=width)
|
297 |
+
|
298 |
+
# a bit hacky way to return focus to top of page after clicking on images
|
299 |
+
components.html(
|
300 |
+
f"""
|
301 |
+
<p>{st.session_state.click_counter}</p>
|
302 |
+
<script>
|
303 |
+
window.parent.document.querySelector('section.main').scrollTo(0, 0);
|
304 |
+
</script>
|
305 |
+
""",
|
306 |
+
height=0
|
307 |
+
)
|
demo_config.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
HUGGING_FACE=True # if run in hugging face. Disables some things like full NST optimization
|
pages/Apply_preset.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch
|
5 |
+
|
6 |
+
PACKAGE_PARENT = '../wise/'
|
7 |
+
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
|
8 |
+
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
|
9 |
+
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
import streamlit as st
|
14 |
+
from streamlit_drawable_canvas import st_canvas
|
15 |
+
|
16 |
+
from effects.minimal_pipeline import MinimalPipelineEffect
|
17 |
+
from helpers.visual_parameter_def import minimal_pipeline_presets, minimal_pipeline_bump_mapping_preset, minimal_pipeline_xdog_preset
|
18 |
+
from helpers import torch_to_np, np_to_torch
|
19 |
+
from effects import get_default_settings
|
20 |
+
|
21 |
+
st.set_page_config(page_title="Preset Edit Demo", layout="wide")
|
22 |
+
|
23 |
+
|
24 |
+
# @st.cache(hash_funcs={OilPaintEffect: id})
@st.cache(hash_funcs={MinimalPipelineEffect: id})
def local_edits_create_effect():
    """Create the cached, CUDA-enabled minimal-pipeline effect and its parameter set."""
    effect, _preset, param_set = get_default_settings("minimal_pipeline")
    effect.enable_checkpoints()
    effect.cuda()
    return effect, param_set
|
31 |
+
|
32 |
+
|
33 |
+
effect, param_set = local_edits_create_effect()
|
34 |
+
presets = {
|
35 |
+
"original": minimal_pipeline_presets,
|
36 |
+
"bump mapped": minimal_pipeline_bump_mapping_preset,
|
37 |
+
"contoured": minimal_pipeline_xdog_preset
|
38 |
+
}
|
39 |
+
|
40 |
+
st.session_state["action"] = "switch_page_from_presets" # on switchback, remember effect input
|
41 |
+
|
42 |
+
active_preset = st.sidebar.selectbox("apply preset: ", ["original", "bump mapped", "contoured"])
|
43 |
+
blend_strength = st.sidebar.slider("Parameter blending strength (non-hue) : ", 0.0, 1.0, 1.0, 0.05)
|
44 |
+
hue_blend_strength = st.sidebar.slider("Hue-shift blending strength : ", 0.0, 1.0, 1.0, 0.05)
|
45 |
+
|
46 |
+
st.sidebar.text("Drawing options:")
|
47 |
+
stroke_width = st.sidebar.slider("Stroke width: ", 1, 80, 40)
|
48 |
+
drawing_mode = st.sidebar.selectbox(
|
49 |
+
"Drawing tool:", ("freedraw", "line", "rect", "circle", "transform")
|
50 |
+
)
|
51 |
+
|
52 |
+
st.session_state["preset_canvas_key"] ="preset_canvas"
|
53 |
+
|
54 |
+
vp = torch.clone(st.session_state["result_vp"])
|
55 |
+
org_cuda = st.session_state["effect_input"]
|
56 |
+
|
57 |
+
@st.experimental_memo
def greyscale_original(_org_cuda, content_id): #content_id is used for hashing
    """Produce a greyscale preview of the effect input for the drawing canvas.

    Returns ``(PIL image, hsize, wsize)`` where the size is capped at width 450
    on HuggingFace or a 670px longest edge otherwise.
    """
    # BUG FIX: HUGGING_FACE was referenced but never imported in this page
    # (unlike pages/Local_edits.py), raising NameError at runtime.
    from demo_config import HUGGING_FACE

    if HUGGING_FACE:
        # Fixed display width; height keeps the aspect ratio.
        wsize = 450
        img_org_height, img_org_width = _org_cuda.shape[-2:]
        wpercent = (wsize / float(img_org_width))
        hsize = int((float(img_org_height) * float(wpercent)))
    else:
        # Scale so the longest edge is 670px, preserving aspect ratio.
        longest_edge = 670
        img_org_height, img_org_width = _org_cuda.shape[-2:]
        max_width_height = max(img_org_width, img_org_height)
        hsize = int((float(longest_edge) * float(float(img_org_height) / max_width_height)))
        wsize = int((float(longest_edge) * float(float(img_org_width) / max_width_height)))

    org_img = F.interpolate(_org_cuda, (hsize, wsize), mode="bilinear")
    # Mean over channels, halved to dim the background under the drawn mask.
    org_img = torch.mean(org_img, dim=1, keepdim=True) / 2.0
    org_img = torch_to_np(org_img, multiply_by_255=True)[..., np.newaxis].repeat(3, axis=2)
    org_img = Image.fromarray(org_img.astype(np.uint8))
    return org_img, hsize, wsize
|
76 |
+
|
77 |
+
greyscale_img, hsize, wsize = greyscale_original(org_cuda, st.session_state["Content_id"])
|
78 |
+
|
79 |
+
coll1, coll2 = st.columns(2)
|
80 |
+
coll1.header("Draw Mask")
|
81 |
+
coll2.header("Live Result")
|
82 |
+
|
83 |
+
with coll1:
|
84 |
+
# Create a canvas component
|
85 |
+
canvas_result = st_canvas(
|
86 |
+
fill_color="rgba(0, 0, 0, 1)", # Fixed fill color with some opacity
|
87 |
+
stroke_width=stroke_width,
|
88 |
+
background_image=greyscale_img,
|
89 |
+
width=greyscale_img.width,
|
90 |
+
height=greyscale_img.height,
|
91 |
+
drawing_mode=drawing_mode,
|
92 |
+
key=st.session_state["preset_canvas_key"]
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
res_data = None
if canvas_result.image_data is not None:
    # Collapse RGBA strokes into a single-channel mask.
    # BUG FIX: np.float was removed in NumPy 1.24; use float32, which also
    # matches pages/Local_edits.py.
    abc = np_to_torch(canvas_result.image_data.astype(np.float32)).sum(dim=1, keepdim=True).cuda()

    # Resize the mask from canvas resolution up to the effect input resolution.
    img_org_width = org_cuda.shape[-1]
    img_org_height = org_cuda.shape[-2]
    res_data = F.interpolate(abc, (img_org_height, img_org_width)).squeeze(1)
|
103 |
+
|
104 |
+
preset_tensor = effect.vpd.preset_tensor(presets[active_preset], org_cuda, add_local_dims=True)
|
105 |
+
hue = torch.clone(vp[:,effect.vpd.name2idx["hueShift"]])
|
106 |
+
vp[:] = preset_tensor * res_data * blend_strength + vp[:] * (1 - res_data * blend_strength)
|
107 |
+
vp[:, effect.vpd.name2idx["hueShift"]] = \
|
108 |
+
preset_tensor[:,effect.vpd.name2idx["hueShift"]] * res_data * hue_blend_strength + hue * (1 - res_data * hue_blend_strength)
|
109 |
+
|
110 |
+
with torch.no_grad():
|
111 |
+
result_cuda = effect(org_cuda, vp)
|
112 |
+
|
113 |
+
img_res = Image.fromarray((torch_to_np(result_cuda) * 255.0).astype(np.uint8))
|
114 |
+
coll2.image(img_res)
|
115 |
+
|
116 |
+
apply_btn = st.sidebar.button("Apply")
|
117 |
+
if apply_btn:
|
118 |
+
st.session_state["result_vp"] = vp
|
pages/Local_edits.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib
|
8 |
+
from matplotlib import pyplot as plt
|
9 |
+
import matplotlib.cm
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
import streamlit as st
|
13 |
+
from streamlit_drawable_canvas import st_canvas
|
14 |
+
|
15 |
+
# BUG FIX: Streamlit pages execute as top-level scripts (no parent package),
# so the relative "from .. import demo_config" raised ImportError. Import the
# sibling module absolutely instead; the app root is on sys.path.
import demo_config
from demo_config import HUGGING_FACE
|
17 |
+
|
18 |
+
PACKAGE_PARENT = '../wise/'
|
19 |
+
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
|
20 |
+
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
from effects.gauss2d_xy_separated import Gauss2DEffect
|
25 |
+
from effects.minimal_pipeline import MinimalPipelineEffect
|
26 |
+
from helpers import torch_to_np, np_to_torch
|
27 |
+
from effects import get_default_settings
|
28 |
+
|
29 |
+
st.set_page_config(page_title="Editing Demo", layout="wide")
|
30 |
+
|
31 |
+
# @st.cache(hash_funcs={OilPaintEffect: id})
|
32 |
+
@st.cache(hash_funcs={MinimalPipelineEffect: id})
|
33 |
+
def local_edits_create_effect():
|
34 |
+
effect, preset, param_set = get_default_settings("minimal_pipeline")
|
35 |
+
effect.enable_checkpoints()
|
36 |
+
effect.cuda()
|
37 |
+
return effect, param_set
|
38 |
+
|
39 |
+
|
40 |
+
effect, param_set = local_edits_create_effect()
|
41 |
+
|
42 |
+
@st.experimental_memo
def gen_param_strength_fig():
    """Build the small 'parameter strength' colorbar figure.

    Returns the matplotlib figure (transparent background) and the 'plasma'
    colormap used to overlay parameter maps.
    """
    cmap = matplotlib.cm.get_cmap('plasma')
    # Two stacked rows of a 0..1 ramp rendered as the colorbar image.
    ramp = np.linspace(0, 1, 256)
    ramp = np.vstack((ramp, ramp))
    fig, ax = plt.subplots(figsize=(3, 0.1))
    fig.patch.set_alpha(0.0)
    ax.set_title("parameter strength", fontsize=6.5, loc="left")
    ax.imshow(ramp, aspect='auto', cmap=cmap)
    ax.set_axis_off()
    return fig, cmap
|
54 |
+
|
55 |
+
cmap_fig, cmap = gen_param_strength_fig()
|
56 |
+
|
57 |
+
st.session_state["canvas_key"] = "canvas"
|
58 |
+
try:
|
59 |
+
vp = st.session_state["result_vp"]
|
60 |
+
org_cuda = st.session_state["effect_input"]
|
61 |
+
except KeyError as e:
|
62 |
+
print("init run, certain keys not found. If this happens once its ok.")
|
63 |
+
|
64 |
+
if st.session_state["action"] != "switch_page_from_local_edits":
|
65 |
+
st.session_state.local_edit_action = "init"
|
66 |
+
|
67 |
+
st.session_state["action"] = "switch_page_from_local_edits" # on switchback, remember effect input
|
68 |
+
|
69 |
+
if "mask_edit_counter" not in st.session_state:
|
70 |
+
st.session_state["mask_edit_counter"] = 1
|
71 |
+
if "initial_drawing" not in st.session_state:
|
72 |
+
st.session_state["initial_drawing"] = {"random": st.session_state["mask_edit_counter"], "background": "#eee"}
|
73 |
+
|
74 |
+
def on_slider_change():
    """Tag the rerun as slider-driven; halt entirely while still initializing."""
    if st.session_state.local_edit_action == "init":
        st.stop()
    st.session_state.local_edit_action = "slider"
|
78 |
+
|
79 |
+
def on_param_change():
    # Mark that the active-parameter selectbox triggered this rerun.
    st.session_state.local_edit_action = "param_change"
|
81 |
+
|
82 |
+
active_param = st.sidebar.selectbox("active parameter: ", param_set + ["smooth"], index=2, on_change=on_param_change)
|
83 |
+
|
84 |
+
st.sidebar.text("Drawing options")
|
85 |
+
if active_param != "smooth":
|
86 |
+
plus_or_minus = st.sidebar.slider("Increase or decrease param map: ", -1.0, 1.0, 0.8, 0.05,
|
87 |
+
on_change=on_slider_change)
|
88 |
+
else:
|
89 |
+
sigma = st.sidebar.slider("Sigma: ", 0.1, 10.0, 0.5, 0.1, on_change=on_slider_change)
|
90 |
+
|
91 |
+
stroke_width = st.sidebar.slider("Stroke width: ", 1, 50, 20, on_change=on_slider_change)
|
92 |
+
drawing_mode = st.sidebar.selectbox(
|
93 |
+
"Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"), on_change=on_slider_change,
|
94 |
+
)
|
95 |
+
|
96 |
+
st.sidebar.text("Viewing options")
|
97 |
+
if active_param != "smooth":
|
98 |
+
overlay = st.sidebar.slider("show parameter overlay: ", 0.0, 1.0, 0.8, 0.02, on_change=on_slider_change)
|
99 |
+
st.sidebar.pyplot(cmap_fig, bbox_inches='tight', pad_inches=0)
|
100 |
+
|
101 |
+
st.sidebar.text("Update:")
|
102 |
+
realtime_update = st.sidebar.checkbox("Update in realtime", True)
|
103 |
+
clear_after_draw = st.sidebar.checkbox("Clear Canvas after each Stroke", False)
|
104 |
+
invert_selection = st.sidebar.checkbox("Invert Selection", False)
|
105 |
+
|
106 |
+
|
107 |
+
@st.experimental_memo
def greyscale_org(_org_cuda, content_id): #content_id is used for hashing
    """Downscale the effect input and convert it to a dimmed greyscale array.

    Returns ``(numpy image, hsize, wsize)``; width is capped at 450 on
    HuggingFace, otherwise the longest edge is scaled to 670px.
    """
    img_org_height, img_org_width = _org_cuda.shape[-2:]
    if HUGGING_FACE:
        wsize = 450
        wpercent = (wsize / float(img_org_width))
        hsize = int((float(img_org_height) * float(wpercent)))
    else:
        longest_edge = 670
        max_width_height = max(img_org_width, img_org_height)
        hsize = int((float(longest_edge) * float(float(img_org_height) / max_width_height)))
        wsize = int((float(longest_edge) * float(float(img_org_width) / max_width_height)))

    org_img = F.interpolate(_org_cuda, (hsize, wsize), mode="bilinear")
    # Channel mean, halved so the parameter overlay stays readable on top.
    org_img = torch.mean(org_img, dim=1, keepdim=True) / 2.0
    org_img = torch_to_np(org_img)[..., np.newaxis].repeat(3, axis=2)
    return org_img, hsize, wsize
|
125 |
+
|
126 |
+
def generate_param_mask(vp):
    """Render the canvas background: greyscale input blended with the active
    parameter map (colormapped), or plain greyscale in 'smooth' mode.
    """
    base_img, hsize, wsize = greyscale_org(org_cuda, st.session_state["Content_id"])
    if active_param != "smooth":
        param_map = F.interpolate(vp, (hsize, wsize))[:, effect.vpd.name2idx[active_param]]
        # Shift from [-0.5, 0.5] into [0, 1] before colormapping; keep RGB only.
        colored = cmap((param_map + 0.5).cpu().numpy())[...,:3][0]
        base_img = base_img * (1 - overlay) + colored * overlay
    return Image.fromarray((base_img * 255).astype(np.uint8))
|
133 |
+
|
134 |
+
def compute_results(_vp):
    """Apply the currently drawn mask to the visual parameters and re-render.

    In normal mode the masked region of the active parameter map is shifted by
    the slider amount; in 'smooth' mode the masked region is blended with a
    Gaussian-smoothed copy of all maps. Returns the rendered PIL image (at
    display resolution) and the updated parameter tensor.
    """
    if "cached_canvas" in st.session_state and st.session_state["cached_canvas"].image_data is not None:
        canvas_result = st.session_state["cached_canvas"]
        # Collapse the RGBA canvas strokes into a single-channel mask.
        abc = np_to_torch(canvas_result.image_data.astype(np.float32)).sum(dim=1, keepdim=True).cuda()

        if invert_selection:
            abc = abc * (- 1.0) + 1.0

        # Resize the mask to the effect-input resolution.
        img_org_width = org_cuda.shape[-1]
        img_org_height = org_cuda.shape[-2]
        res_data = F.interpolate(abc, (img_org_height, img_org_width)).squeeze(1)

        if active_param != "smooth":
            # Additive edit on the selected parameter map, clamped to range.
            _vp[:, effect.vpd.name2idx[active_param]] += plus_or_minus * res_data
            _vp.clamp_(-0.5, 0.5)
        else:
            # Separable Gaussian blur over all parameter maps.
            gauss2dx = Gauss2DEffect(dxdy=[1.0, 0.0], dim_kernsize=5)
            gauss2dy = Gauss2DEffect(dxdy=[0.0, 1.0], dim_kernsize=5)

            vp_smoothed = gauss2dx(_vp, torch.tensor(sigma).cuda())
            vp_smoothed = gauss2dy(vp_smoothed, torch.tensor(sigma).cuda())

            print(res_data.shape)
            print(_vp.shape)
            print(vp_smoothed.shape)
            # Blend toward the smoothed maps only inside the drawn mask.
            _vp = torch.lerp(_vp, vp_smoothed, res_data.unsqueeze(1))

    with torch.no_grad():
        result_cuda = effect(org_cuda, _vp)

    # Downscale the render to the same display size as the canvas background.
    _, hsize, wsize = greyscale_org(org_cuda, st.session_state["Content_id"])
    result_cuda = F.interpolate(result_cuda, (hsize, wsize), mode="bilinear")

    return Image.fromarray((torch_to_np(result_cuda) * 255.0).astype(np.uint8)), _vp
|
168 |
+
|
169 |
+
coll1, coll2 = st.columns(2)
|
170 |
+
coll1.header("Draw Mask:")
|
171 |
+
coll2.header("Live Result")
|
172 |
+
|
173 |
+
# The drawable canvas keeps its stroke history until the whole script reruns;
# handing it an initial_drawing that differs from its current state is the
# only way to make it clear its background.
def mark_canvas_for_redraw():
    """Flag the canvas for clearing by changing its initial_drawing state."""
    print("mark for redraw")
    # Bumping the counter makes the initial_drawing differ from the canvas
    # state, which forces a redraw on the next rerun.
    st.session_state["mask_edit_counter"] += 1
    st.session_state["initial_drawing"] = {
        "random": st.session_state["mask_edit_counter"],
        "background": "#eee",
    }
180 |
+
|
181 |
+
|
182 |
+
with coll1:
    print("edit action", st.session_state.local_edit_action)
    # When "clear after draw" is enabled, alternate between a "draw" state
    # (stroke just applied, canvas wiped) and a "redraw" state; slider /
    # parameter-change / init reruns must not wipe the canvas.
    if clear_after_draw and st.session_state.local_edit_action not in ("slider", "param_change", "init"):
        if st.session_state.local_edit_action == "redraw":
            st.session_state.local_edit_action = "draw"
            mark_canvas_for_redraw()
        else:
            st.session_state.local_edit_action = "redraw"

    mask = generate_param_mask(st.session_state["result_vp"])
    st.session_state["last_mask"] = mask

    # Drawing canvas, sized to match the parameter-mask background image.
    canvas_result = st_canvas(
        fill_color="rgba(0, 0, 0, 1)",
        stroke_width=stroke_width,
        background_image=mask,
        update_streamlit=realtime_update,
        width=mask.width,
        height=mask.height,
        initial_drawing=st.session_state["initial_drawing"],
        drawing_mode=drawing_mode,
        key=st.session_state.canvas_key,
    )

    # The canvas reports no JSON data on its very first render — bail out
    # until the component has initialized.
    if canvas_result.json_data is None:
        print("stops")
        st.stop()

    st.session_state["cached_canvas"] = canvas_result

    print("compute result")
    img_res, vp = compute_results(vp)
    st.session_state["last_result"] = img_res
    st.session_state["result_vp"] = vp

    st.markdown("### Mask: " + active_param)

    # A slider/param-change/init rerun has now been handled; the next stroke
    # should go through the clear-after-draw cycle again.
    if st.session_state.local_edit_action in ("slider", "param_change", "init"):
        print("set redraw")
        st.session_state.local_edit_action = "redraw"


print("plot masks")
# Build downscaled (20%) previews of every parameter mask for the strip below.
texts = []
preview_masks = []
img = st.session_state["last_mask"]
for p in param_set:
    idx = effect.vpd.name2idx[p]
    preview = F.interpolate(vp[:, idx:idx + 1] + 0.5, (int(img.height * 0.2), int(img.width * 0.2)))
    texts.append(p[:15])
    preview_masks.append(torch_to_np(preview))

coll2.image(img_res)  # , use_column_width="auto")
preview_cols = st.columns(len(param_set))
for col, txt, im in zip(preview_cols, texts, preview_masks):
    col.text(txt)
    col.image(im, clamp=True)

print("....")
|
pages/Readme.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Static landing page describing the WISE demo and how to use it."""
import streamlit as st

# Page body kept as one markdown blob; edits to the copy only need to touch
# this constant.
README_TEXT = """
This app demonstrates the editing capabilities of the White-box Style Transfer Editing (WISE) framework.
It optimizes the parameters of classical image processing filters to match a given style image.
After optimization, parameters can be tuned by hand to achieve a desired look.

### How does it work?
We provide a small stylization effect that contains several filters such as bump mapping or edge enhancement that can be optimized. The optimization yields so-called parameter masks, which contain per pixel parameter settings of each filter.

### How to use the app ?
- On the first page select existing content/style combinations or upload images to optimize.
- After the effect has been applied, use the parameter sliders to adjust a parameter value globally
- On the "apply preset" page, we defined several parameter presets that can be drawn on the image. Press "Apply" to make the changes permanent
- On the " local editing" page, individual parameter masks can be edited regionally. Choose the parameter on the left sidebar, and use the parameter strength slider to either increase or decrease the strength of the drawn strokes
- Strokes on the drawing canvas (left column) are updated in real-time on the result in the right column.
- Strokes stay on the canvas unless manually deleted by clicking the trash button. To remove them from the canvas after each stroke, tick the corresponding checkbox in the sidebar.

### Links & Paper
[Project page](https://ivpg.hpi3d.de/wise/),
[arxiv link](https://arxiv.org/abs/2207.14606)

"WISE: Whitebox Image Stylization by Example-based Learning", by Winfried Lötzsch*, Max Reimann*, Martin Büßemeyer, Amir Semmo, Jürgen Döllner, Matthias Trapp, in ECCV 2022

### Further notes
Pull Requests and further improvements are very welcome.
Please note that the shown effect is a minimal pipeline in terms of stylization capability, the much more feature-rich oilpaint and watercolor pipelines we show in our ECCV paper cannot be open-sourced due to IP reasons.
"""

st.title("White-box Style Transfer Editing")
st.markdown(README_TEXT)
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
imageio
|
2 |
+
imageio-ffmpeg
|
3 |
+
matplotlib
|
4 |
+
Pillow
|
5 |
+
numpy
|
6 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
7 |
+
torch
|
8 |
+
torchvision
|
9 |
+
streamlit==1.10.0
|
10 |
+
streamlit_drawable_canvas==0.8.0
|
11 |
+
streamlit_extras==0.1.5
|
12 |
+
st_click_detector
|
13 |
+
scipy
|