Max Reimann
commited on
Commit
•
dc6a058
1
Parent(s):
591e364
Improve parameters
Browse files
Whitebox_style_transfer.py
CHANGED
@@ -264,22 +264,25 @@ def on_slider():
|
|
264 |
|
265 |
|
266 |
with coll2:
|
267 |
-
show_params_names = [ '
|
268 |
display_means = []
|
|
|
269 |
def create_slider(name):
|
270 |
-
|
271 |
-
|
|
|
272 |
display_means.append(display_mean)
|
273 |
if "slider_" + name not in st.session_state or st.session_state["action"] != "slider_change":
|
274 |
st.session_state["slider_" + name] = display_mean
|
275 |
slider = st.slider(f"Mean {name}: ", 0.0, 1.0, step=0.05, key="slider_" + name, on_change=on_slider)
|
276 |
-
|
277 |
-
|
|
|
278 |
|
279 |
for name in show_params_names:
|
280 |
create_slider(name)
|
281 |
|
282 |
-
others_idx = set(range(len(effect.vpd.vp_ranges))) - set([effect.vpd.name2idx[name] for name in
|
283 |
others_names = [effect.vpd.vp_ranges[i][0] for i in sorted(list(others_idx))]
|
284 |
other_param = st.selectbox("Other parameters: ", others_names)
|
285 |
create_slider(other_param)
|
|
|
264 |
|
265 |
|
266 |
with coll2:
|
267 |
+
show_params_names = [ 'bumpiness',"bumpSpecular", "contours"]
|
268 |
display_means = []
|
269 |
+
params_mapping = {"bumpiness": ['bumpScale', "bumpOpacity"], "bumpSpecular": ["bumpSpecular"], "contours": [ "contourOpacity", "contour"]}
|
270 |
def create_slider(name):
|
271 |
+
params = params_mapping[name] if name in params_mapping else [name]
|
272 |
+
means = [torch.mean(vp[:, effect.vpd.name2idx[n]]).item() for n in params]
|
273 |
+
display_mean = np.average(means) + 0.5
|
274 |
display_means.append(display_mean)
|
275 |
if "slider_" + name not in st.session_state or st.session_state["action"] != "slider_change":
|
276 |
st.session_state["slider_" + name] = display_mean
|
277 |
slider = st.slider(f"Mean {name}: ", 0.0, 1.0, step=0.05, key="slider_" + name, on_change=on_slider)
|
278 |
+
for i, param_name in enumerate(params):
|
279 |
+
vp[:, effect.vpd.name2idx[param_name]] += slider - (means[i] + 0.5)
|
280 |
+
vp.clamp_(-0.5, 0.5)
|
281 |
|
282 |
for name in show_params_names:
|
283 |
create_slider(name)
|
284 |
|
285 |
+
others_idx = set(range(len(effect.vpd.vp_ranges))) - set([effect.vpd.name2idx[name] for name in sum(params_mapping.values(), [])])
|
286 |
others_names = [effect.vpd.vp_ranges[i][0] for i in sorted(list(others_idx))]
|
287 |
other_param = st.selectbox("Other parameters: ", others_names)
|
288 |
create_slider(other_param)
|