mazpie commited on
Commit
97df429
1 Parent(s): e6f6d44

Update demo/t2v.py

Browse files
Files changed (1) hide show
  1. demo/t2v.py +5 -3
demo/t2v.py CHANGED
@@ -52,12 +52,11 @@ class Text2Video():
52
  if not os.path.exists(self.result_dir):
53
  os.mkdir(self.result_dir)
54
 
55
- self.agent.to('cuda')
56
- self.clip.to('cuda')
57
-
58
  @spaces.GPU
59
  def get_prompt(self, prompt, duration):
60
  torch.cuda.empty_cache()
 
 
61
 
62
  print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
63
  start = time.time()
@@ -94,6 +93,9 @@ class Text2Video():
94
 
95
  save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
96
  print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
 
 
 
97
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
98
 
99
  def download_model(self, model_folder, model_filename):
 
52
  if not os.path.exists(self.result_dir):
53
  os.mkdir(self.result_dir)
54
 
 
 
 
55
  @spaces.GPU
56
  def get_prompt(self, prompt, duration):
57
  torch.cuda.empty_cache()
58
+ self.agent.to('cuda')
59
+ self.clip.to('cuda')
60
 
61
  print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
62
  start = time.time()
 
93
 
94
  save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
95
  print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
96
+ # Offload GPU
97
+ self.agent.to('cpu')
98
+ self.clip.to('cpu')
99
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
100
 
101
  def download_model(self, model_folder, model_filename):