Spaces:
Running
on
T4
Running
on
T4
import base64 | |
import datetime | |
import os | |
import sys | |
from io import BytesIO | |
from pathlib import Path | |
import numpy as np | |
import requests | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
import time | |
PACKAGE_PARENT = 'wise' | |
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__)))) | |
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))) | |
import streamlit as st | |
from streamlit.logger import get_logger | |
from st_click_detector import click_detector | |
import streamlit.components.v1 as components | |
from streamlit.source_util import get_pages | |
from streamlit_extras.switch_page_button import switch_page | |
from demo_config import HUGGING_FACE | |
from parameter_optimization.parametric_styletransfer import single_optimize | |
from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG | |
from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to | |
import helpers.session_state as session_state | |
from helpers import torch_to_np, np_to_torch | |
from effects import get_default_settings, MinimalPipelineEffect | |
st.set_page_config(layout="wide") | |
BASE_URL = "https://ivpg.hpi3d.de/wise/wise-demo/images/" | |
LOGGER = get_logger(__name__) | |
effect_type = "minimal_pipeline" | |
if "click_counter" not in st.session_state: | |
st.session_state.click_counter = 1 | |
if "action" not in st.session_state: | |
st.session_state["action"] = "" | |
content_urls = [ | |
{ | |
"name": "Portrait", "id": "portrait", | |
"src": BASE_URL + "/content/portrait.jpeg" | |
}, | |
{ | |
"name": "Tuebingen", "id": "tubingen", | |
"src": BASE_URL + "/content/tubingen.jpeg" | |
}, | |
{ | |
"name": "Colibri", "id": "colibri", | |
"src": BASE_URL + "/content/colibri.jpeg" | |
} | |
] | |
style_urls = [ | |
{ | |
"name": "Starry Night, Van Gogh", "id": "starry_night", | |
"src": BASE_URL + "/style/starry_night.jpg" | |
}, | |
{ | |
"name": "The Scream, Edward Munch", "id": "the_scream", | |
"src": BASE_URL + "/style/the_scream.jpg" | |
}, | |
{ | |
"name": "The Great Wave, Ukiyo-e", "id": "wave", | |
"src": BASE_URL + "/style/wave.jpg" | |
}, | |
{ | |
"name": "Woman with Hat, Henry Matisse", "id": "woman_with_hat", | |
"src": BASE_URL + "/style/woman_with_hat.jpg" | |
} | |
] | |
def last_image_clicked(type="content", action=None, ): | |
kw = "last_image_clicked" + "_" + type | |
if action: | |
session_state.get(**{kw: action}) | |
elif kw not in session_state.get(): | |
return None | |
else: | |
return session_state.get()[kw] | |
def _retrieve_from_id(clicked, urls): | |
src = [x["src"] for x in urls if x["id"] == clicked][0] | |
img = Image.open(requests.get(src, stream=True).raw) | |
return img, src | |
def store_img_from_id(clicked, urls, imgtype): | |
img, src = _retrieve_from_id(clicked, urls) | |
session_state.get(**{f"{imgtype}_im": img, f"{imgtype}_render_src": src, f"{imgtype}_id": clicked}) | |
def img_choice_panel(imgtype, urls, default_choice, expanded): | |
with st.expander(f"Select {imgtype} image:", expanded=expanded): | |
html_code = '<div class="column" style="display: flex; flex-wrap: wrap; padding: 0 4px;">' | |
for url in urls: | |
html_code += f"<a href='#' id='{url['id']}' style='padding: 0px 5px'><img height='160px' style='margin-top: 8px;' src='{url['src']}'></a>" | |
html_code += "</div>" | |
clicked = click_detector(html_code) | |
if not clicked and st.session_state["action"] not in ("uploaded", "switch_page_from_local_edits", "switch_page_from_presets", "slider_change", "reset"): # default val | |
store_img_from_id(default_choice, urls, imgtype) | |
st.write("OR: ") | |
with st.form(imgtype + "-form", clear_on_submit=True): | |
uploaded_im = st.file_uploader(f"Load {imgtype} image:", type=["png", "jpg"], ) | |
upload_pressed = st.form_submit_button("Upload") | |
if upload_pressed and uploaded_im is not None: | |
img = Image.open(uploaded_im) | |
img = img.convert('RGB') | |
buffered = BytesIO() | |
img.save(buffered, format="JPEG") | |
encoded = base64.b64encode(buffered.getvalue()).decode() | |
# session_state.get(uploaded_im=img, content_render_src=f"data:image/jpeg;base64,{encoded}") | |
session_state.get(**{f"{imgtype}_im": img, f"{imgtype}_render_src": f"data:image/jpeg;base64,{encoded}", | |
f"{imgtype}_id": "uploaded"}) | |
st.session_state["action"] = "uploaded" | |
st.write("uploaded.") | |
last_clicked = last_image_clicked(type=imgtype) | |
print("last_clicked", last_clicked, "clicked", clicked, "action", st.session_state["action"] ) | |
if not upload_pressed and clicked != "": # trigger when no file uploaded | |
if last_clicked != clicked: # only activate when content was actually clicked | |
store_img_from_id(clicked, urls, imgtype) | |
last_image_clicked(type=imgtype, action=clicked) | |
st.session_state["action"] = "clicked" | |
st.session_state.click_counter += 1 # hack to get page to reload at top | |
state = session_state.get() | |
st.sidebar.write(f'Selected {imgtype} image:') | |
st.sidebar.markdown(f'<img src="{state[f"{imgtype}_render_src"]}" width=240px></img>', unsafe_allow_html=True) | |
def optimize(effect, preset, result_image_placeholder): | |
content = st.session_state["Content_im"] | |
style = st.session_state["Style_im"] | |
st.session_state["optimize_next"] = False | |
with st.spinner(text="Optimizing parameters.."): | |
if HUGGING_FACE: | |
optimize_on_server(content, style, result_image_placeholder) | |
else: | |
optimize_params(effect, preset, content, style, result_image_placeholder) | |
def optimize_next(result_image_placeholder): | |
result_image_placeholder.text("<- Custom content/style needs to be style transferred") | |
queue_length = 0 if not HUGGING_FACE else get_queue_length() | |
if queue_length > 0: | |
st.sidebar.warning(f"WARNING: Already {queue_length} tasks in the queue. It will take approx {(queue_length+1) * 5} min for your image to be completed.") | |
else: | |
st.sidebar.warning("Note: Optimizing takes up to 5 minutes.") | |
optimize_button = st.sidebar.button("Optimize Style Transfer") | |
if optimize_button: | |
st.session_state["optimize_next"] = True | |
st.experimental_rerun() | |
else: | |
if not "result_vp" in st.session_state: | |
st.stop() | |
else: | |
return st.session_state["effect_input"], st.session_state["result_vp"] | |
def create_effect(): | |
effect, preset, param_set = get_default_settings(effect_type) | |
effect.enable_checkpoints() | |
effect.cuda() | |
return effect, preset | |
def load_visual_params(vp_path: str, img_org: Image, org_cuda: torch.Tensor, effect) -> torch.Tensor: | |
if Path(vp_path).exists(): | |
vp = torch.load(vp_path).detach().clone() | |
vp = F.interpolate(vp, (img_org.height, img_org.width)) | |
if len(effect.vpd.vp_ranges) == vp.shape[1]: | |
return vp | |
# use preset and save it | |
vp = effect.vpd.preset_tensor(preset, org_cuda, add_local_dims=True) | |
torch.save(vp, vp_path) | |
return vp | |
# @st.cache(hash_funcs={torch.Tensor: id}) | |
def load_params(content_id, style_id):#, effect): | |
preoptim_param_path = os.path.join("precomputed", effect_type, content_id, style_id) | |
img_org = Image.open(os.path.join(preoptim_param_path, "input.png")) | |
content_cuda = np_to_torch(img_org).cuda() | |
vp_path = os.path.join(preoptim_param_path, "vp.pt") | |
vp = load_visual_params(vp_path, img_org, content_cuda, effect) | |
return content_cuda, vp | |
def render_effect(effect, content_cuda, vp): | |
with torch.no_grad(): | |
result_cuda = effect(content_cuda, vp) | |
img_res = Image.fromarray((torch_to_np(result_cuda) * 255.0).astype(np.uint8)) | |
return img_res | |
result_container = st.container() | |
coll1, coll2 = result_container.columns([3,2]) | |
coll1.header("Result") | |
coll2.header("Global Edits") | |
result_image_placeholder = coll1.empty() | |
result_image_placeholder.markdown("## loading..") | |
from tasks import optimize_on_server, optimize_params, monitor_task, get_queue_length | |
if "current_server_task_id" not in st.session_state: | |
st.session_state['current_server_task_id'] = None | |
if "optimize_next" not in st.session_state: | |
st.session_state['optimize_next'] = False | |
effect, preset = create_effect() | |
if HUGGING_FACE and st.session_state['current_server_task_id'] is not None: | |
with st.spinner(text="Optimizing parameters.."): | |
monitor_task(result_image_placeholder) | |
if st.session_state["optimize_next"]: | |
print("optimize now") | |
optimize(effect, preset, result_image_placeholder) | |
img_choice_panel("Content", content_urls, "portrait", expanded=True) | |
img_choice_panel("Style", style_urls, "starry_night", expanded=True) | |
state = session_state.get() | |
content_id = state["Content_id"] | |
style_id = state["Style_id"] | |
print("content id, style id", content_id, style_id ) | |
if st.session_state["action"] == "uploaded": | |
content_img, _vp = optimize_next(result_image_placeholder) | |
elif st.session_state["action"] in ("switch_page_from_local_edits", "switch_page_from_presets", "slider_change") or \ | |
content_id == "uploaded" or style_id == "uploaded": | |
print("restore param") | |
_vp = st.session_state["result_vp"] | |
content_img = st.session_state["effect_input"] | |
else: | |
print("load_params") | |
content_img, _vp = load_params(content_id, style_id)#, effect) | |
vp = torch.clone(_vp) | |
def reset_params(means, names): | |
for i, name in enumerate(names): | |
st.session_state["slider_" + name] = means[i] | |
def on_slider(): | |
st.session_state["action"] = "slider_change" | |
with coll2: | |
show_params_names = [ 'bumpScale', "bumpOpacity", "contourOpacity"] | |
display_means = [] | |
def create_slider(name): | |
mean = torch.mean(vp[:, effect.vpd.name2idx[name]]).item() | |
display_mean = mean + 0.5 | |
display_means.append(display_mean) | |
if "slider_" + name not in st.session_state or st.session_state["action"] != "slider_change": | |
st.session_state["slider_" + name] = display_mean | |
slider = st.slider(f"Mean {name}: ", 0.0, 1.0, step=0.05, key="slider_" + name, on_change=on_slider) | |
vp[:, effect.vpd.name2idx[name]] += slider - display_mean | |
vp.clamp_(-0.5, 0.5) | |
for name in show_params_names: | |
create_slider(name) | |
others_idx = set(range(len(effect.vpd.vp_ranges))) - set([effect.vpd.name2idx[name] for name in show_params_names]) | |
others_names = [effect.vpd.vp_ranges[i][0] for i in sorted(list(others_idx))] | |
other_param = st.selectbox("Other parameters: ", others_names) | |
create_slider(other_param) | |
reset_button = st.button("Reset Parameters", on_click=reset_params, args=(display_means, show_params_names)) | |
if reset_button: | |
st.session_state["action"] = "reset" | |
st.experimental_rerun() | |
apply_presets = st.button("Paint Presets") | |
if apply_presets: | |
switch_page("Apply_preset") | |
edit_locally_btn = st.button("Edit Local Parameter Maps") | |
if edit_locally_btn: | |
switch_page('️ local edits') | |
img_res = render_effect(effect, content_img, vp) | |
st.session_state["result_vp"] = vp | |
st.session_state["effect_input"] = content_img | |
st.session_state["last_result"] = img_res | |
with coll1: | |
# width = int(img_res.width * 500 / img_res.height) | |
result_image_placeholder.image(img_res)#, width=width) | |
# a bit hacky way to return focus to top of page after clicking on images | |
components.html( | |
f""" | |
<p>{st.session_state.click_counter}</p> | |
<script> | |
window.parent.document.querySelector('section.main').scrollTo(0, 0); | |
</script> | |
""", | |
height=0 | |
) | |