import base64
import datetime
import os
import sys
import time
from io import BytesIO
from pathlib import Path

import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
import streamlit as st

from demo_config import HUGGING_FACE, WORKER_URL

PACKAGE_PARENT = 'wise'
SCRIPT_DIR = os.path.dirname(
    os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__)))
)
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

# Project imports come after the sys.path tweak above so they can be resolved
# from the 'wise' directory next to this script.
from parameter_optimization.parametric_styletransfer import single_optimize
from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG
from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
from helpers import torch_to_np, np_to_torch


def retrieve_for_results_from_server():
    """Download the current task's results from the worker into session state.

    Reads the task id from ``st.session_state['current_server_task_id']`` and
    requests ``/get_vp`` (loaded with ``np.load`` and indexed by ``"vp"``) and
    ``/get_image``. On success, both are moved to CUDA and stored as
    ``st.session_state["result_vp"]`` and ``st.session_state["effect_input"]``.
    The task id is cleared on both the success and the failure path.

    Raises:
        requests.HTTPError: if either response is an HTTP error status.
    """
    task_id = st.session_state['current_server_task_id']
    vp_res = requests.get(WORKER_URL + "/get_vp", params={"task_id": task_id})
    image_res = requests.get(WORKER_URL + "/get_image", params={"task_id": task_id})

    if vp_res.status_code != 200 or image_res.status_code != 200:
        st.warning(f"got status for {WORKER_URL}/get_vp: {vp_res.status_code}")
        st.warning(f"got status for {WORKER_URL}/get_image: {image_res.status_code}")
        st.session_state['current_server_task_id'] = None
        vp_res.raise_for_status()
        image_res.raise_for_status()
        # Non-200 but non-error responses (e.g. 204) carry no usable result.
        return

    st.session_state['current_server_task_id'] = None

    vp = np.load(BytesIO(vp_res.content))["vp"]
    print("received vp from server")
    print("got numpy array", vp.shape)
    vp = torch.from_numpy(vp).cuda()

    image = Image.open(BytesIO(image_res.content))
    print("received image from server")
    image = np_to_torch(np.asarray(image)).cuda()

    st.session_state["effect_input"] = image
    st.session_state["result_vp"] = vp


def monitor_task(progress_placeholder):
    """Poll the worker every 2 s for the current task and render its progress.

    Behaviour by worker status:
    - A status other than running/queued with a non-empty ``msg`` shows the
      error, clears the task id and stops the script.
    - ``"finished"`` fetches the results and returns.
    - ``"queued"`` shows the queue length.
    - Otherwise the progress bar is updated. The first half is a time-based
      estimate while the server reports 0.0 progress (assumed ~80 s for
      STROTSS); the second half maps the server-reported progress.

    If ``/get_status`` fails three times in total (the counter is never reset),
    this returns without clearing the task id.

    Args:
        progress_placeholder: Streamlit placeholder the status UI is drawn into.
    """
    task_id = st.session_state['current_server_task_id']
    started_time = time.time()
    retries = 3

    with progress_placeholder.container():
        st.warning("Do not interact with the app until results are shown - otherwise results might be lost.")
        progress_bar = st.empty()

        while True:
            status_res = requests.get(WORKER_URL + "/get_status", params={"task_id": task_id})
            if status_res.status_code != 200:
                print("get_status got status_code", status_res.status_code)
                st.warning(status_res.content)
                retries -= 1
                if retries == 0:
                    return
                time.sleep(2)
                continue

            status = status_res.json()
            print(status)

            if status["status"] != "running" and status["status"] != "queued":
                if status["msg"] != "":
                    print("got error for task", task_id, ":", status["msg"])
                    progress_placeholder.error(status["msg"])
                    st.session_state['current_server_task_id'] = None
                    st.stop()

            # NOTE(review): a status other than running/queued/finished with an
            # empty msg falls through to the progress branches below and keeps
            # polling; confirm the worker never reports such a status.
            if status["status"] == "finished":
                retrieve_for_results_from_server()
                return
            elif status["status"] == "queued":
                # Restart the clock so the time-based estimate below only
                # counts time spent after leaving the queue.
                started_time = time.time()
                progress_bar.write(f"There are {get_queue_length()} tasks in the queue")
            elif status["progress"] == 0.0:
                # estimate 80s for strotss; it fills the first half of the bar
                progressed = min(0.5 * (time.time() - started_time) / 80.0, 0.5)
                progress_bar.progress(progressed)
            else:
                progress_bar.progress(min(0.5 + status["progress"] / 2.0, 1.0))

            time.sleep(2)


def get_queue_length():
    """Return the ``length`` field reported by the worker's ``/queue_length``."""
    queue_length = requests.get(WORKER_URL + "/queue_length").json()
    return queue_length['length']


def optimize_on_server(content, style, result_image_placeholder):
    """Upload a content/style pair to the worker and wait for the result.

    Both images must have an aspect ratio (height / width) within [0.5, 2.0];
    otherwise an error is shown and the script stops. The images are resized
    so their long edge is 1024, written to temporary JPEGs, and posted to
    ``/upload``. The temporary files are removed afterwards. On success, the
    returned task id is stored in session state and ``monitor_task`` polls it.

    Args:
        content: PIL content image.
        style: PIL style image.
        result_image_placeholder: Streamlit placeholder for errors/progress.
    """
    # One timestamp and distinct prefixes: two separate now() calls can yield
    # identical values, which made the style file overwrite the content file.
    timestamp = str(datetime.datetime.timestamp(datetime.datetime.now()))
    content_path = f"/tmp/content-wise-uploaded{timestamp}.jpg"
    style_path = f"/tmp/style-wise-uploaded{timestamp}.jpg"

    asp_c, asp_s = content.height / content.width, style.height / style.width
    if any(a < 0.5 or a > 2.0 for a in (asp_c, asp_s)):
        result_image_placeholder.error('aspect ratio must be <= 2')
        st.stop()

    try:
        content = pil_resize_long_edge_to(content, 1024)
        content.save(content_path)
        style = pil_resize_long_edge_to(style, 1024)
        style.save(style_path)

        print("start-optimizing")
        url = WORKER_URL + "/upload"
        # Context managers ensure the upload handles are closed.
        with open(style_path, "rb") as style_file, open(content_path, "rb") as content_file:
            files = {'style-image': style_file, "content-image": content_file}
            task_id_res = requests.post(url, files=files)
    finally:
        # The temp files are only needed for the upload itself.
        for path in (content_path, style_path):
            try:
                os.remove(path)
            except OSError:
                pass

    if task_id_res.status_code != 200:
        result_image_placeholder.error(task_id_res.content)
        st.stop()

    task_id = task_id_res.json()['task_id']
    st.session_state['current_server_task_id'] = task_id
    monitor_task(result_image_placeholder)


def optimize_params(effect, preset, content, style, result_image_placeholder):
    """Run the parameter optimization locally and store the result in session state.

    Steps:
    1. Run STROTSS on the content/style pair (long edge 1024, on CUDA) to make
       a reference image.
    2. Save the reference (long edge 720) and the content image under
       ``result/<timestamp>``.
    3. Run ``single_optimize`` with an "l1" loss for 300 iterations, reporting
       progress through ``result_image_placeholder``.

    Stores the detached results in ``st.session_state["effect_input"]`` and
    ``st.session_state["result_vp"]``.
    """
    result_image_placeholder.text("Executing NST to create reference image..")
    base_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
    # The directory name has one-second resolution; don't crash if two runs
    # start in the same second.
    os.makedirs(base_dir, exist_ok=True)

    reference = strotss(pil_resize_long_edge_to(content, 1024),
                        pil_resize_long_edge_to(style, 1024), content_weight=16.0,
                        device=torch.device("cuda"), space="uniform")

    progress_bar = result_image_placeholder.progress(0.0)
    ref_save_path = os.path.join(base_dir, "reference.jpg")
    content_save_path = os.path.join(base_dir, "content.jpg")
    resize_to = 720
    reference = pil_resize_long_edge_to(reference, resize_to)
    reference.save(ref_save_path)
    content.save(content_save_path)

    # Mutates the CONFIG shared with parametric_styletransfer; the progress
    # callback below divides by this same value.
    ST_CONFIG["n_iterations"] = 300
    vp, content_img_cuda = single_optimize(
        effect, preset, "l1", content_save_path, str(ref_save_path),
        write_video=False, base_dir=base_dir,
        iter_callback=lambda i: progress_bar.progress(float(i) / ST_CONFIG["n_iterations"]),
    )

    st.session_state["effect_input"], st.session_state["result_vp"] = (
        content_img_cuda.detach(),
        vp.cuda().detach(),
    )