skytnt commited on
Commit
6c802d8
1 Parent(s): a2917b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -230,8 +230,8 @@ class Model:
230
  video.close()
231
 
232
 
233
- def gen_fn(use_seed, seed, psi):
234
- z = RandomState(int(seed) + 2 ** 31).randn(1, 512) if use_seed else np.random.randn(1, 512)
235
  w = model.get_w(z.astype(dtype=np.float32), psi)
236
  img_out = model.get_img(w)
237
  return img_out, json.dumps(w.tolist()), img_out
@@ -247,7 +247,7 @@ def encode_img_fn(img):
247
 
248
 
249
  def gen_video_fn(w1, w2, frame):
250
- if w1 is None or w2 is None:
251
  return None
252
  model.gen_video(np.array(json.loads(w1), dtype=np.float32), np.array(json.loads(w2), dtype=np.float32), "video.mp4",
253
  int(frame))
@@ -266,7 +266,7 @@ if __name__ == '__main__':
266
  with gr.Row():
267
  with gr.Column():
268
  gr.Markdown("generate image randomly or by seed")
269
- gen_input1 = gr.Checkbox(value=False, label="use seed")
270
  gen_input2 = gr.Number(value=1, label="seed")
271
  gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.6, label="truncation psi")
272
  with gr.Group():
 
230
  video.close()
231
 
232
 
233
+ def gen_fn(method, seed, psi):
234
+ z = RandomState(int(seed) + 2 ** 31).randn(1, 512) if method == 1 else np.random.randn(1, 512)
235
  w = model.get_w(z.astype(dtype=np.float32), psi)
236
  img_out = model.get_img(w)
237
  return img_out, json.dumps(w.tolist()), img_out
 
247
 
248
 
249
  def gen_video_fn(w1, w2, frame):
250
+ if w1 is None or w2 is None or w1 == "" or w2 == "":
251
  return None
252
  model.gen_video(np.array(json.loads(w1), dtype=np.float32), np.array(json.loads(w2), dtype=np.float32), "video.mp4",
253
  int(frame))
 
266
  with gr.Row():
267
  with gr.Column():
268
  gr.Markdown("generate image randomly or by seed")
269
+ gen_input1 = gr.Radio(label="method", choices=["random", "use seed"], type="index")
270
  gen_input2 = gr.Number(value=1, label="seed")
271
  gen_input3 = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.6, label="truncation psi")
272
  with gr.Group():