cocktailpeanut committed on
Commit
d53d73c
1 Parent(s): d9114d9
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -19,6 +19,13 @@ from funcs import (
19
  get_latent_z,
20
  save_videos
21
  )
 
 
 
 
 
 
 
22
 
23
  def download_model():
24
  REPO_ID = 'Doubiiu/DynamiCrafter_1024'
@@ -43,7 +50,7 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
43
  assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
44
  model = load_model_checkpoint(model, ckpt_path)
45
  model.eval()
46
- model = model.cuda()
47
  save_fps = 8
48
 
49
  seed_everything(seed)
@@ -51,7 +58,10 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
51
  transforms.Resize(min(resolution)),
52
  transforms.CenterCrop(resolution),
53
  ])
54
- torch.cuda.empty_cache()
 
 
 
55
  print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
56
  start = time.time()
57
  if steps > 60:
@@ -154,4 +164,4 @@ with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
154
  fn = infer
155
  )
156
 
157
- dynamicrafter_iface.queue(max_size=12).launch(show_api=True)
 
19
  get_latent_z,
20
  save_videos
21
  )
22
+ if torch.cuda.is_available():
23
+ device = "cuda"
24
+ elif torch.backends.mps.is_available():
25
+ device = "mps"
26
+ else:
27
+ device = "cpu"
28
+
29
 
30
  def download_model():
31
  REPO_ID = 'Doubiiu/DynamiCrafter_1024'
 
50
  assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
51
  model = load_model_checkpoint(model, ckpt_path)
52
  model.eval()
53
+ model = model.to(device)
54
  save_fps = 8
55
 
56
  seed_everything(seed)
 
58
  transforms.Resize(min(resolution)),
59
  transforms.CenterCrop(resolution),
60
  ])
61
+ if device == "cuda":
62
+ torch.cuda.empty_cache()
63
+ elif device == "mps":
64
+ torch.mps.empty_cache()
65
  print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
66
  start = time.time()
67
  if steps > 60:
 
164
  fn = infer
165
  )
166
 
167
+ dynamicrafter_iface.queue(max_size=12).launch(show_api=True)