Max Reimann commited on
Commit
89dbdbc
1 Parent(s): 08fedb8

Improve waiting behaviour in app during optimization

Browse files
Whitebox_style_transfer.py CHANGED
@@ -140,20 +140,27 @@ def img_choice_panel(imgtype, urls, default_choice, expanded):
140
  st.sidebar.write(f'Selected {imgtype} image:')
141
  st.sidebar.markdown(f'<img src="{state[f"{imgtype}_render_src"]}" width=240px></img>', unsafe_allow_html=True)
142
 
143
-
144
  def optimize(effect, preset, result_image_placeholder):
145
  content = st.session_state["Content_im"]
146
  style = st.session_state["Style_im"]
 
 
 
 
 
 
 
 
147
  result_image_placeholder.text("<- Custom content/style needs to be style transferred")
148
- st.sidebar.warning("Note: Optimizing takes up to 5 minutes.")
 
 
 
 
149
  optimize_button = st.sidebar.button("Optimize Style Transfer")
150
  if optimize_button:
151
- with st.spinner(text="Optimizing parameters.."):
152
- if HUGGING_FACE:
153
- optimize_on_server(content, style, result_image_placeholder)
154
- else:
155
- optimize_params(effect, preset, content, style, result_image_placeholder)
156
- return st.session_state["effect_input"], st.session_state["result_vp"]
157
  else:
158
  if not "result_vp" in st.session_state:
159
  st.stop()
@@ -206,15 +213,24 @@ coll2.header("Global Edits")
206
  result_image_placeholder = coll1.empty()
207
  result_image_placeholder.markdown("## loading..")
208
 
209
- from tasks import optimize_on_server, optimize_params, monitor_task
210
 
211
  if "current_server_task_id" not in st.session_state:
212
  st.session_state['current_server_task_id'] = None
213
 
 
 
 
 
 
214
  if HUGGING_FACE and st.session_state['current_server_task_id'] is not None:
215
  with st.spinner(text="Optimizing parameters.."):
216
  monitor_task(result_image_placeholder)
217
 
 
 
 
 
218
  img_choice_panel("Content", content_urls, "portrait", expanded=True)
219
  img_choice_panel("Style", style_urls, "starry_night", expanded=True)
220
 
@@ -222,11 +238,10 @@ state = session_state.get()
222
  content_id = state["Content_id"]
223
  style_id = state["Style_id"]
224
 
225
- effect, preset = create_effect()
226
 
227
  print("content id, style id", content_id, style_id )
228
  if st.session_state["action"] == "uploaded":
229
- content_img, _vp = optimize(effect, preset, result_image_placeholder)
230
  elif st.session_state["action"] in ("switch_page_from_local_edits", "switch_page_from_presets", "slider_change") or \
231
  content_id == "uploaded" or style_id == "uploaded":
232
  print("restore param")
 
140
  st.sidebar.write(f'Selected {imgtype} image:')
141
  st.sidebar.markdown(f'<img src="{state[f"{imgtype}_render_src"]}" width=240px></img>', unsafe_allow_html=True)
142
 
 
143
  def optimize(effect, preset, result_image_placeholder):
144
  content = st.session_state["Content_im"]
145
  style = st.session_state["Style_im"]
146
+ st.session_state["optimize_next"] = False
147
+ with st.spinner(text="Optimizing parameters.."):
148
+ if HUGGING_FACE:
149
+ optimize_on_server(content, style, result_image_placeholder)
150
+ else:
151
+ optimize_params(effect, preset, content, style, result_image_placeholder)
152
+
153
+ def optimize_next(result_image_placeholder):
154
  result_image_placeholder.text("<- Custom content/style needs to be style transferred")
155
+ queue_length = 0 if not HUGGING_FACE else get_queue_length()
156
+ if queue_length > 0:
157
+ st.sidebar.warning(f"WARNING: Already {queue_length} tasks in the queue. It will take approx {(queue_length+1) * 5} min for your image to be completed.")
158
+ else:
159
+ st.sidebar.warning("Note: Optimizing takes up to 5 minutes.")
160
  optimize_button = st.sidebar.button("Optimize Style Transfer")
161
  if optimize_button:
162
+ st.session_state["optimize_next"] = True
163
+ st.experimental_rerun()
 
 
 
 
164
  else:
165
  if not "result_vp" in st.session_state:
166
  st.stop()
 
213
  result_image_placeholder = coll1.empty()
214
  result_image_placeholder.markdown("## loading..")
215
 
216
+ from tasks import optimize_on_server, optimize_params, monitor_task, get_queue_length
217
 
218
  if "current_server_task_id" not in st.session_state:
219
  st.session_state['current_server_task_id'] = None
220
 
221
+ if "optimize_next" not in st.session_state:
222
+ st.session_state['optimize_next'] = False
223
+
224
+ effect, preset = create_effect()
225
+
226
  if HUGGING_FACE and st.session_state['current_server_task_id'] is not None:
227
  with st.spinner(text="Optimizing parameters.."):
228
  monitor_task(result_image_placeholder)
229
 
230
+ if st.session_state["optimize_next"]:
231
+ print("optimize now")
232
+ optimize(effect, preset, result_image_placeholder)
233
+
234
  img_choice_panel("Content", content_urls, "portrait", expanded=True)
235
  img_choice_panel("Style", style_urls, "starry_night", expanded=True)
236
 
 
238
  content_id = state["Content_id"]
239
  style_id = state["Style_id"]
240
 
 
241
 
242
  print("content id, style id", content_id, style_id )
243
  if st.session_state["action"] == "uploaded":
244
+ content_img, _vp = optimize_next(result_image_placeholder)
245
  elif st.session_state["action"] in ("switch_page_from_local_edits", "switch_page_from_presets", "slider_change") or \
246
  content_id == "uploaded" or style_id == "uploaded":
247
  print("restore param")
pages/1_🎨_Apply_preset.py CHANGED
@@ -42,7 +42,7 @@ presets = {
42
 
43
  st.session_state["action"] = "switch_page_from_presets" # on switchback, remember effect input
44
 
45
- active_preset = st.sidebar.selectbox("apply preset: ", ["original", "bump mapped", "contoured"])
46
  blend_strength = st.sidebar.slider("Parameter blending strength (non-hue) : ", 0.0, 1.0, 1.0, 0.05)
47
  hue_blend_strength = st.sidebar.slider("Hue-shift blending strength : ", 0.0, 1.0, 1.0, 0.05)
48
 
@@ -118,4 +118,6 @@ coll2.image(img_res)
118
 
119
  apply_btn = st.sidebar.button("Apply")
120
  if apply_btn:
121
- st.session_state["result_vp"] = vp
 
 
 
42
 
43
  st.session_state["action"] = "switch_page_from_presets" # on switchback, remember effect input
44
 
45
+ active_preset = st.sidebar.selectbox("apply preset: ", ["bump mapped", "contoured", "original"])
46
  blend_strength = st.sidebar.slider("Parameter blending strength (non-hue) : ", 0.0, 1.0, 1.0, 0.05)
47
  hue_blend_strength = st.sidebar.slider("Hue-shift blending strength : ", 0.0, 1.0, 1.0, 0.05)
48
 
 
118
 
119
  apply_btn = st.sidebar.button("Apply")
120
  if apply_btn:
121
+ st.session_state["result_vp"] = vp
122
+
123
+ st.info("Note: Press apply to make changes permanent")
tasks.py CHANGED
@@ -53,55 +53,63 @@ def monitor_task(progress_placeholder):
53
 
54
  started_time = time.time()
55
  retries = 3
56
- while True:
57
- status = requests.get(WORKER_URL+"/get_status", params={"task_id": task_id})
58
- if status.status_code != 200:
59
- print("get_status got status_code", status.status_code)
60
- st.warning(status.content)
61
- retries -= 1
62
- if retries == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  return
 
 
 
 
 
 
 
64
  else:
65
- time.sleep(2)
66
- continue
67
- status = status.json()
68
- print(status)
69
- if status["status"] != "running" and status["status"] != "queued" :
70
- if status["msg"] != "":
71
- print("got error for task", task_id, ":", status["msg"])
72
- progress_placeholder.error(status["msg"])
73
- st.session_state['current_server_task_id'] = None
74
- st.stop()
75
- if status["status"] == "finished":
76
- retrieve_for_results_from_server()
77
- return
78
- elif status["status"] == "queued":
79
- started_time = time.time()
80
- queue_length = requests.get(WORKER_URL+"/queue_length").json()
81
- progress_placeholder.write(f"There are {queue_length['length']} tasks in the queue")
82
- elif status["progress"] == 0.0:
83
- progressed = min(0.5 * (time.time() - started_time) / 80.0, 0.5) #estimate 80s for strotts
84
- progress_placeholder.progress(progressed)
85
- else:
86
- progress_placeholder.progress(min(0.5 + status["progress"] / 2.0, 1.0))
87
-
88
- time.sleep(2)
89
 
90
 
91
  def optimize_on_server(content, style, result_image_placeholder):
92
- url = WORKER_URL + "/upload"
93
  content_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
94
  style_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
95
  asp_c, asp_s = content.height / content.width, style.height / style.width
96
  if any(a < 0.5 or a > 2.0 for a in (asp_c, asp_s)):
97
  result_image_placeholder.error('aspect ratio must be <= 2')
98
  st.stop()
 
99
  content = pil_resize_long_edge_to(content, 1024)
100
  content.save(content_path)
101
  style = pil_resize_long_edge_to(style, 1024)
102
  style.save(style_path)
103
  files = {'style-image': open(style_path, "rb"), "content-image": open(content_path, "rb")}
104
  print("start-optimizing")
 
105
  task_id_res = requests.post(url, files=files)
106
  if task_id_res.status_code != 200:
107
  result_image_placeholder.error(task_id_res.content)
 
53
 
54
  started_time = time.time()
55
  retries = 3
56
+ with progress_placeholder.container():
57
+ st.warning("Do not interact with the app until results are shown - otherwise results might be lost.")
58
+ progress_bar = st.empty()
59
+ while True:
60
+ status = requests.get(WORKER_URL+"/get_status", params={"task_id": task_id})
61
+ if status.status_code != 200:
62
+ print("get_status got status_code", status.status_code)
63
+ st.warning(status.content)
64
+ retries -= 1
65
+ if retries == 0:
66
+ return
67
+ else:
68
+ time.sleep(2)
69
+ continue
70
+ status = status.json()
71
+ print(status)
72
+ if status["status"] != "running" and status["status"] != "queued" :
73
+ if status["msg"] != "":
74
+ print("got error for task", task_id, ":", status["msg"])
75
+ progress_placeholder.error(status["msg"])
76
+ st.session_state['current_server_task_id'] = None
77
+ st.stop()
78
+ if status["status"] == "finished":
79
+ retrieve_for_results_from_server()
80
  return
81
+ elif status["status"] == "queued":
82
+ started_time = time.time()
83
+ queue_length = requests.get(WORKER_URL+"/queue_length").json()
84
+ progress_bar.write(f"There are {queue_length['length']} tasks in the queue")
85
+ elif status["progress"] == 0.0:
86
+ progressed = min(0.5 * (time.time() - started_time) / 80.0, 0.5) #estimate 80s for strotts
87
+ progress_bar.progress(progressed)
88
  else:
89
+ progress_bar.progress(min(0.5 + status["progress"] / 2.0, 1.0))
90
+
91
+ time.sleep(2)
92
+
93
+ def get_queue_length():
94
+ queue_length = requests.get(WORKER_URL+"/queue_length").json()
95
+ return queue_length['length']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
  def optimize_on_server(content, style, result_image_placeholder):
 
99
  content_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
100
  style_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
101
  asp_c, asp_s = content.height / content.width, style.height / style.width
102
  if any(a < 0.5 or a > 2.0 for a in (asp_c, asp_s)):
103
  result_image_placeholder.error('aspect ratio must be <= 2')
104
  st.stop()
105
+
106
  content = pil_resize_long_edge_to(content, 1024)
107
  content.save(content_path)
108
  style = pil_resize_long_edge_to(style, 1024)
109
  style.save(style_path)
110
  files = {'style-image': open(style_path, "rb"), "content-image": open(content_path, "rb")}
111
  print("start-optimizing")
112
+ url = WORKER_URL + "/upload"
113
  task_id_res = requests.post(url, files=files)
114
  if task_id_res.status_code != 200:
115
  result_image_placeholder.error(task_id_res.content)