ziyangmai commited on
Commit
1ad8288
1 Parent(s): a487160

add spaces decorator

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -6,11 +6,25 @@ import random
6
  import string
7
  import json
8
  from omegaconf import OmegaConf,ListConfig
 
9
 
10
 
11
  from train import main as train_main
12
  from inference import inference as inference_main
13
- # 模拟训练函数
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def train_model(video, config):
15
  output_dir = 'results'
16
  os.makedirs(output_dir, exist_ok=True)
@@ -28,7 +42,7 @@ def train_model(video, config):
28
  # cur_save_dir = 'results/06'
29
  return cur_save_dir
30
 
31
- # 模拟推理函数
32
  def inference_model(text, checkpoint, inference_steps, video_type,seed):
33
 
34
  checkpoint = os.path.join('results',checkpoint)
@@ -36,7 +50,7 @@ def inference_model(text, checkpoint, inference_steps, video_type,seed):
36
  embedding_dir = '/'.join(checkpoint.split('/')[:-1])
37
  video_round = checkpoint.split('/')[-1]
38
 
39
- video_path = inference_main(
40
  embedding_dir=embedding_dir,
41
  prompt=text,
42
  video_round=video_round,
@@ -49,7 +63,6 @@ def inference_model(text, checkpoint, inference_steps, video_type,seed):
49
  return video_path
50
 
51
 
52
- # 获取checkpoint文件列表
53
  def get_checkpoints(checkpoint_dir):
54
 
55
  checkpoints = []
@@ -135,7 +148,6 @@ if __name__ == "__main__":
135
  ]
136
 
137
  gradio_theme = gr.themes.Default()
138
- # 创建Gradio界面
139
  with gr.Blocks(
140
  theme=gradio_theme,
141
  title="Motion Inversion",
@@ -249,5 +261,5 @@ if __name__ == "__main__":
249
  checkpoint_output.change(update_checkpoints, inputs=checkpoint_output, outputs=checkpoint_dropdown)
250
  inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown,inference_steps,motion_type, seed], outputs=output_video)
251
  output_video.change(fn=update_generated_prompt, inputs=[text_input], outputs=generated_prompt)
252
- # 启动Gradio界面
253
  demo.launch()
 
6
  import string
7
  import json
8
  from omegaconf import OmegaConf,ListConfig
9
+ import spaces
10
 
11
 
12
  from train import main as train_main
13
  from inference import inference as inference_main
14
+
15
+ @spaces.GPU()
16
+ def inference_app(
17
+ embedding_dir,
18
+ prompt,
19
+ video_round,
20
+ save_dir,
21
+ motion_type,
22
+ seed,
23
+ inference_steps):
24
+
25
+ return inference_main(embedding_dir, prompt, video_round, save_dir, motion_type, seed, inference_steps)
26
+
27
+
28
  def train_model(video, config):
29
  output_dir = 'results'
30
  os.makedirs(output_dir, exist_ok=True)
 
42
  # cur_save_dir = 'results/06'
43
  return cur_save_dir
44
 
45
+
46
  def inference_model(text, checkpoint, inference_steps, video_type,seed):
47
 
48
  checkpoint = os.path.join('results',checkpoint)
 
50
  embedding_dir = '/'.join(checkpoint.split('/')[:-1])
51
  video_round = checkpoint.split('/')[-1]
52
 
53
+ video_path = inference_app(
54
  embedding_dir=embedding_dir,
55
  prompt=text,
56
  video_round=video_round,
 
63
  return video_path
64
 
65
 
 
66
  def get_checkpoints(checkpoint_dir):
67
 
68
  checkpoints = []
 
148
  ]
149
 
150
  gradio_theme = gr.themes.Default()
 
151
  with gr.Blocks(
152
  theme=gradio_theme,
153
  title="Motion Inversion",
 
261
  checkpoint_output.change(update_checkpoints, inputs=checkpoint_output, outputs=checkpoint_dropdown)
262
  inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown,inference_steps,motion_type, seed], outputs=output_video)
263
  output_video.change(fn=update_generated_prompt, inputs=[text_input], outputs=generated_prompt)
264
+
265
  demo.launch()