|
import base64 |
|
import datetime |
|
import os |
|
import sys |
|
from io import BytesIO |
|
from pathlib import Path |
|
import numpy as np |
|
import requests |
|
import torch |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
|
|
PACKAGE_PARENT = 'wise'
# Absolute, symlink-resolved directory that contains this script file.
SCRIPT_DIR = os.path.dirname(
    os.path.realpath(
        os.path.join(os.getcwd(), os.path.expanduser(__file__))
    )
)
# Put "<script dir>/wise" on the import path -- presumably so the project modules
# imported below (demo_config, parameter_optimization, helpers, effects) resolve.
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
|
|
|
import streamlit as st |
|
from streamlit.logger import get_logger |
|
from st_click_detector import click_detector |
|
import streamlit.components.v1 as components |
|
from streamlit.source_util import get_pages |
|
from streamlit_extras.switch_page_button import switch_page |
|
|
|
from demo_config import HUGGING_FACE |
|
from parameter_optimization.parametric_styletransfer import single_optimize |
|
from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG |
|
from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to |
|
import helpers.session_state as session_state |
|
from helpers import torch_to_np, np_to_torch |
|
from effects import get_default_settings, MinimalPipelineEffect |
|
|
|
# Wide layout so the result column and the edit column (see columns([3, 2])) fit side by side.
st.set_page_config(layout="wide")

BASE_URL = "https://ivpg.hpi3d.de/wise/wise-demo/images/"
LOGGER = get_logger(__name__)

# Effect pipeline name; also a path component of the precomputed parameter directory.
effect_type = "minimal_pipeline"

# Per-session defaults, initialised only on the first run of a session.
if "click_counter" not in st.session_state:
    st.session_state["click_counter"] = 1
if "action" not in st.session_state:
    st.session_state.action = ""
|
|
|
# Gallery entries. "id" is the click id used by the HTML gallery and is also stored as
# "<Content|Style>_id", which load_params uses as a directory name under precomputed/.
# "src" is the image URL. BASE_URL already ends with "/", so the paths below must not
# start with one (that previously produced ".../images//content/...").
content_urls = [
    {
        "name": "Portrait", "id": "portrait",
        "src": BASE_URL + "content/portrait.jpeg",
    },
    {
        "name": "Tuebingen", "id": "tubingen",
        "src": BASE_URL + "content/tubingen.jpeg",
    },
    {
        "name": "Colibri", "id": "colibri",
        "src": BASE_URL + "content/colibri.jpeg",
    },
]

style_urls = [
    {
        "name": "Starry Night, Van Gogh", "id": "starry_night",
        "src": BASE_URL + "style/starry_night.jpg",
    },
    {
        "name": "The Scream, Edvard Munch", "id": "the_scream",
        "src": BASE_URL + "style/the_scream.jpg",
    },
    {
        "name": "The Great Wave, Ukiyo-e", "id": "wave",
        "src": BASE_URL + "style/wave.jpg",
    },
    {
        "name": "Woman with Hat, Henri Matisse", "id": "woman_with_hat",
        "src": BASE_URL + "style/woman_with_hat.jpg",
    },
]
|
|
|
|
|
def last_image_clicked(type="content", action=None, ):
    """Store or look up the id of the last gallery image clicked for one image slot.

    Args:
        type: slot label (e.g. "Content" or "Style"); the value lives under the
            key "last_image_clicked_<type>".
        action: if truthy, it is stored under that key via ``session_state.get``
            and the function returns None.

    Returns:
        The stored id, or None if nothing was stored yet (or when storing).
    """
    key = "last_image_clicked_" + type
    if action:
        session_state.get(**{key: action})
        return None
    if key in session_state.get():
        return session_state.get()[key]
    return None
|
|
|
|
|
@st.cache
def _retrieve_from_id(clicked, urls):
    """Download the gallery image whose ``id`` equals ``clicked``.

    Args:
        clicked: gallery id to look up.
        urls: gallery entries, each a dict with "id" and "src" keys.

    Returns:
        ``(img, src)``: the fully decoded PIL image and the URL it came from.

    Raises:
        ValueError: if no entry in ``urls`` has id ``clicked``.
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond in time.
    """
    src = next((entry["src"] for entry in urls if entry["id"] == clicked), None)
    if src is None:
        raise ValueError(f"Unknown image id {clicked!r}")

    # Bounded wait so a stalled server cannot hang the page indefinitely.
    response = requests.get(src, timeout=30)
    # Fail loudly on 4xx/5xx instead of trying to decode an error page as an image.
    response.raise_for_status()
    # .content applies any Content-Encoding and releases the connection, unlike .raw.
    img = Image.open(BytesIO(response.content))
    # Decode now so the cached object does not depend on lazy loading later.
    img.load()
    return img, src
|
|
|
|
|
def store_img_from_id(clicked, urls, imgtype):
    """Fetch gallery image ``clicked`` and record it as the selection for slot ``imgtype``.

    Writes "<imgtype>_im" (PIL image), "<imgtype>_render_src" (its URL, used for the
    sidebar preview) and "<imgtype>_id" (the gallery id) through ``session_state.get``.
    """
    image, source_url = _retrieve_from_id(clicked, urls)
    updates = {
        f"{imgtype}_im": image,
        f"{imgtype}_render_src": source_url,
        f"{imgtype}_id": clicked,
    }
    session_state.get(**updates)
|
|
|
|
|
def _gallery_html(urls):
    """Return the clickable thumbnail-gallery markup passed to ``click_detector``.

    Each thumbnail is wrapped in an ``<a>`` whose ``id`` is the entry's gallery id; that
    id is what ``img_choice_panel`` later passes to ``store_img_from_id`` on a click.
    """
    thumbnails = "".join(
        f"<a href='#' id='{url['id']}' style='padding: 0px 5px'>"
        f"<img height='160px' style='margin-top: 8px;' src='{url['src']}'></a>"
        for url in urls
    )
    return ('<div class="column" style="display: flex; flex-wrap: wrap; padding: 0 4px;">'
            + thumbnails + "</div>")


def img_choice_panel(imgtype, urls, default_choice, expanded):
    """Render the gallery/upload panel for one image slot and record the selection.

    Shows a clickable thumbnail gallery and an upload form inside an expander. The chosen
    image is stored via ``session_state.get`` under "<imgtype>_im", "<imgtype>_render_src"
    and "<imgtype>_id", and the selection is previewed in the sidebar.

    Args:
        imgtype: slot label, e.g. "Content" or "Style"; used in labels and state keys.
        urls: gallery entries, each a dict with "id" and "src" keys.
        default_choice: gallery id stored when nothing was clicked and no earlier
            action (upload, page switch, slider edit, reset) already set the selection.
        expanded: whether the expander starts open.

    Side effects:
        Sets ``st.session_state["action"]`` to "uploaded" or "clicked" and increments
        ``st.session_state.click_counter`` when a new gallery image is clicked.
    """
    with st.expander(f"Select {imgtype} image:", expanded=expanded):
        clicked = click_detector(_gallery_html(urls))

        if not clicked and st.session_state["action"] not in ("uploaded", "switch_page_from_local_edits", "switch_page_from_presets", "slider_change", "reset"):
            store_img_from_id(default_choice, urls, imgtype)

        st.write("OR: ")

        with st.form(imgtype + "-form", clear_on_submit=True):
            uploaded_im = st.file_uploader(f"Load {imgtype} image:", type=["png", "jpg", "jpeg"])
            upload_pressed = st.form_submit_button("Upload")

            if upload_pressed and uploaded_im is not None:
                # JPEG cannot store alpha or palette images (e.g. RGBA / "P" PNGs), so
                # normalise to RGB before encoding; downstream code also saves as .jpg.
                img = Image.open(uploaded_im).convert("RGB")
                buffered = BytesIO()
                img.save(buffered, format="JPEG")
                encoded = base64.b64encode(buffered.getvalue()).decode()

                session_state.get(**{f"{imgtype}_im": img, f"{imgtype}_render_src": f"data:image/jpeg;base64,{encoded}",
                                     f"{imgtype}_id": "uploaded"})
                st.session_state["action"] = "uploaded"
                st.write("uploaded.")

        last_clicked = last_image_clicked(type=imgtype)
        LOGGER.debug("last_clicked %s clicked %s action %s",
                     last_clicked, clicked, st.session_state["action"])
        # click_detector keeps returning the last clicked id on reruns, so only treat it
        # as a new click when it differs from the id recorded last time.
        if not upload_pressed and clicked and last_clicked != clicked:
            store_img_from_id(clicked, urls, imgtype)
            last_image_clicked(type=imgtype, action=clicked)
            st.session_state["action"] = "clicked"
            # Presumably changes the scroll-to-top component's content at the end of the
            # script so it re-renders (and re-runs its script) after each click.
            st.session_state.click_counter += 1

    state = session_state.get()
    st.sidebar.write(f'Selected {imgtype} image:')
    st.sidebar.markdown(f'<img src="{state[f"{imgtype}_render_src"]}" width=240px></img>', unsafe_allow_html=True)
|
|
|
|
|
def optimize(effect, preset, result_image_placeholder):
    """Produce effect parameters for a custom (uploaded) content/style pair.

    Shows an "Optimize Style Transfer" sidebar button. When it is pressed, STROTSS neural
    style transfer creates a reference image. ``single_optimize`` then fits the effect
    parameters to that reference, and progress is reported in ``result_image_placeholder``.
    Intermediate files go to a timestamped directory under ``result/``.

    Args:
        effect: effect pipeline (as returned by ``create_effect``).
        preset: preset forwarded to ``single_optimize``.
        result_image_placeholder: Streamlit placeholder for status/progress output.

    Returns:
        ``(effect_input, vp)``: the effect input tensor and the optimized parameters.
        If the button was not pressed, the previously stored
        ``st.session_state["effect_input"]`` / ``["result_vp"]`` are returned instead.
        Execution is halted with ``st.stop()`` when no stored result exists yet, or when
        running on HuggingFace (optimization is disabled there).
    """
    content = st.session_state["Content_im"]
    style = st.session_state["Style_im"]
    result_image_placeholder.text("<- Custom content/style needs to be style transferred")
    optimize_button = st.sidebar.button("Optimize Style Transfer")

    if not optimize_button:
        # Reuse an earlier optimization result if there is one; otherwise wait for the user.
        if "result_vp" not in st.session_state:
            st.stop()
        return st.session_state["effect_input"], st.session_state["result_vp"]

    if HUGGING_FACE:
        result_image_placeholder.warning("NST optimization is currently disabled in this HuggingFace Space because it takes ~5min to optimize. To try it out, please clone the repo and change the huggingface variable in demo_config.py")
        st.stop()

    # The content image is written as .jpg below, and JPEG cannot hold alpha/palette
    # modes (e.g. RGBA PNG uploads); normalise both inputs to RGB.
    content = content.convert("RGB")
    style = style.convert("RGB")

    result_image_placeholder.text("Executing NST to create reference image..")
    base_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
    os.makedirs(base_dir)
    with st.spinner(text="Running NST"):
        reference = strotss(pil_resize_long_edge_to(content, 1024),
                            pil_resize_long_edge_to(style, 1024), content_weight=16.0,
                            device=torch.device("cuda"), space="uniform")

    progress_bar = result_image_placeholder.progress(0.0)
    ref_save_path = os.path.join(base_dir, "reference.jpg")
    content_save_path = os.path.join(base_dir, "content.jpg")
    resize_to = 720
    reference = pil_resize_long_edge_to(reference, resize_to)
    reference.save(ref_save_path)
    content.save(content_save_path)

    ST_CONFIG["n_iterations"] = 300
    with st.spinner(text="Optimizing parameters.."):
        vp, content_img_cuda = single_optimize(
            effect, preset, "l1", content_save_path, ref_save_path,
            write_video=False, base_dir=base_dir,
            # st.progress only accepts values in [0, 1]; clamp in case the callback's
            # final index reaches/exceeds n_iterations.
            iter_callback=lambda i: progress_bar.progress(
                min(1.0, float(i) / ST_CONFIG["n_iterations"])))
    return content_img_cuda.detach(), vp.cuda().detach()
|
|
|
|
|
@st.cache(hash_funcs={MinimalPipelineEffect: id})
def create_effect():
    """Build the default effect for ``effect_type``, prepared for CUDA.

    Cached by Streamlit; ``MinimalPipelineEffect`` objects are hashed by identity
    (``id``) rather than by content.

    Returns:
        ``(effect, preset)`` from ``get_default_settings``; the parameter set it also
        returns is discarded.
    """
    effect, preset, _unused_param_set = get_default_settings(effect_type)
    effect.enable_checkpoints()
    effect.cuda()
    return effect, preset
|
|
|
|
|
def load_visual_params(vp_path: str, img_org: "Image.Image", org_cuda: torch.Tensor, effect,
                       preset=None) -> torch.Tensor:
    """Load cached visual parameters for an image, or create and cache them from a preset.

    If ``vp_path`` exists, the stored tensor is loaded onto ``org_cuda``'s device and
    resized with ``F.interpolate`` (default nearest mode) to ``img_org``'s height/width.
    It is returned if its dim 1 matches the number of effect parameters. Otherwise,
    including when the file is missing, fresh parameters are built with
    ``effect.vpd.preset_tensor`` and written to ``vp_path``, replacing any stale file.

    Args:
        vp_path: path of the cached parameter tensor.
        img_org: image whose ``height``/``width`` give the target spatial size.
        org_cuda: the image as a tensor; sets the device and is passed to ``preset_tensor``.
        effect: effect whose ``vpd`` describes the parameters.
        preset: preset used when regenerating. Defaults to the module-level ``preset``
            (assigned from ``create_effect()`` at script level) for backward compatibility.

    Returns:
        The parameter tensor (4-D, since it is resized with a 2-D size; dim 1 indexes
        parameters).
    """
    if Path(vp_path).exists():
        # map_location keeps the parameters on the same device as the image, whatever
        # device they were saved from. NOTE: torch.load unpickles -- only use trusted files.
        vp = torch.load(vp_path, map_location=org_cuda.device).detach().clone()
        vp = F.interpolate(vp, (img_org.height, img_org.width))
        if len(effect.vpd.vp_ranges) == vp.shape[1]:
            return vp

    if preset is None:
        # Backward compatibility: callers that do not pass ``preset`` rely on the
        # module-level ``preset`` defined by the page script.
        preset = globals()["preset"]
    vp = effect.vpd.preset_tensor(preset, org_cuda, add_local_dims=True)
    torch.save(vp, vp_path)
    return vp
|
|
|
|
|
|
|
@st.experimental_memo
def load_params(content_id, style_id):
    """Load the precomputed effect input and parameters for a gallery content/style pair.

    Reads "precomputed/<effect_type>/<content_id>/<style_id>/input.png" and the parameters
    in "vp.pt" next to it; ``load_visual_params`` creates that file from the preset when it
    is missing or has the wrong parameter count. The module-level ``effect`` is read
    from global scope, so it is not one of the memoized arguments.

    Returns:
        ``(content_cuda, vp)``: the input image as a CUDA tensor and its parameters.
    """
    pair_dir = os.path.join("precomputed", effect_type, content_id, style_id)
    input_image = Image.open(os.path.join(pair_dir, "input.png"))
    input_cuda = np_to_torch(input_image).cuda()
    params = load_visual_params(os.path.join(pair_dir, "vp.pt"), input_image, input_cuda, effect)
    return input_cuda, params
|
|
|
|
|
def render_effect(effect, content_cuda, vp):
    """Apply ``effect`` with parameters ``vp`` to ``content_cuda`` and return a PIL image.

    Runs without gradient tracking. The effect output, converted with ``torch_to_np``, is
    scaled by 255 (presumably from a [0, 1] range -- TODO confirm), clipped to [0, 255]
    and truncated to uint8.
    """
    with torch.no_grad():
        result_cuda = effect(content_cuda, vp)
    # Clip before the uint8 cast: casting out-of-range floats (e.g. slightly negative or
    # >1.0 outputs) to uint8 wraps around / is undefined and shows up as speckle artifacts.
    pixels = np.clip(torch_to_np(result_cuda) * 255.0, 0.0, 255.0).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
|
|
# Page layout: result image on the left (3/5), global edit controls on the right (2/5).
result_container = st.container()
coll1, coll2 = result_container.columns([3, 2])
coll1.header("Result")
coll2.header("Global Edits")
result_image_placeholder = coll1.empty()
result_image_placeholder.markdown("## loading..")

img_choice_panel("Content", content_urls, "portrait", expanded=True)
img_choice_panel("Style", style_urls, "starry_night", expanded=True)

state = session_state.get()
content_id = state["Content_id"]
style_id = state["Style_id"]

effect, preset = create_effect()

LOGGER.debug("content id %s, style id %s", content_id, style_id)
# Choose where the effect input and parameters come from on this rerun:
# - a fresh upload needs optimization (or an earlier optimization result);
# - slider edits / returning from another page, or any uploaded image, keep the
#   parameters stored in session state at the end of the previous run;
# - otherwise (gallery pair, click, reset) load the precomputed parameters.
if st.session_state["action"] == "uploaded":
    content_img, _vp = optimize(effect, preset, result_image_placeholder)
elif (st.session_state["action"] in ("switch_page_from_local_edits", "switch_page_from_presets", "slider_change")
        or content_id == "uploaded" or style_id == "uploaded"):
    LOGGER.debug("restore param")
    _vp = st.session_state["result_vp"]
    content_img = st.session_state["effect_input"]
else:
    LOGGER.debug("load_params")
    content_img, _vp = load_params(content_id, style_id)

# The sliders below modify ``vp`` in place; work on a copy so the source tensor
# (e.g. the one stored in session state) is not mutated.
vp = _vp.clone()
|
|
|
|
|
def reset_params(means, names):
    """Reset-button callback: write ``means[i]`` into the slider state key of ``names[i]``.

    Only the first ``len(names)`` entries of ``means`` are used.
    """
    for position, param_name in enumerate(names):
        slider_key = "slider_" + param_name
        st.session_state[slider_key] = means[position]
|
|
|
def on_slider():
    """Slider ``on_change`` callback: flag this rerun as caused by a slider edit.

    With this flag set, the slider code keeps the user's slider values instead of
    re-syncing them from the parameter tensor.
    """
    st.session_state.action = "slider_change"
|
|
|
|
|
with coll2:
    # Parameters that always get a slider; one further parameter is picked via a selectbox.
    show_params_names = ['bumpScale', "bumpOpacity", "contourOpacity"]
    # Slider "centre" values (before the user's offset), in slider-creation order.
    display_means = []

    def create_slider(name):
        """Render a slider for the mean of parameter ``name`` and apply its offset to ``vp``.

        The slider shows the channel's mean (over all other dims) shifted by +0.5. The
        difference between the slider value and that shown mean is added to the whole
        channel in place, after which all of ``vp`` is clamped to [-0.5, 0.5].
        """
        channel = effect.vpd.name2idx[name]
        slider_key = "slider_" + name
        display_mean = torch.mean(vp[:, channel]).item() + 0.5
        display_means.append(display_mean)

        # Re-sync the widget with the tensor unless this rerun was triggered by a slider
        # edit, in which case the user's value already in session state must be kept.
        if slider_key not in st.session_state or st.session_state["action"] != "slider_change":
            st.session_state[slider_key] = display_mean
        slider = st.slider(f"Mean {name}: ", 0.0, 1.0, step=0.05, key=slider_key, on_change=on_slider)

        vp[:, channel] += slider - display_mean
        vp.clamp_(-0.5, 0.5)

    for name in show_params_names:
        create_slider(name)

    # Every remaining parameter, in index order, can be chosen for one extra slider.
    others_idx = set(range(len(effect.vpd.vp_ranges))).difference(
        effect.vpd.name2idx[shown] for shown in show_params_names)
    others_names = [effect.vpd.vp_ranges[i][0] for i in sorted(others_idx)]
    other_param = st.selectbox("Other parameters: ", others_names)
    create_slider(other_param)

    # The "reset" action makes the next rerun skip the restore branch, so gallery pairs
    # reload their precomputed parameters via load_params.
    if st.button("Reset Parameters", on_click=reset_params, args=(display_means, show_params_names)):
        st.session_state["action"] = "reset"
        st.experimental_rerun()

    # NOTE(review): keep this page name byte-for-byte -- it starts with an invisible
    # character, presumably from the emoji-prefixed page file name.
    if st.button("Edit Local Parameter Maps"):
        switch_page('️ local edits')

    if st.button("Paint Presets"):
        switch_page("Apply_preset")
|
|
|
# Render with the (possibly slider-adjusted) parameters and persist the state that the
# restore branch reads on the next rerun ("last_result" is presumably read by other pages).
img_res = render_effect(effect, content_img, vp)
st.session_state.update(result_vp=vp, effect_input=content_img, last_result=img_res)
|
|
|
# The placeholder was created by coll1.empty(), so it renders inside coll1 without
# entering that column's context.
result_image_placeholder.image(img_res)

# Zero-height component whose script scrolls the main section back to the top. The
# click counter changes its content after each gallery click, presumably so the
# component re-renders and the script runs again.
components.html(
    f"""
    <p>{st.session_state.click_counter}</p>
    <script>
    window.parent.document.querySelector('section.main').scrollTo(0, 0);
    </script>
    """,
    height=0,
)
|
|