GrayShine commited on
Commit
a88163e
β€’
1 Parent(s): 019f472

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -4
app.py CHANGED
@@ -154,7 +154,7 @@ vae = None
154
  text_encoder = None
155
  image_encoder = None
156
  clip_image_processor = None
157
- @spaces.GPU
158
  def init_model():
159
  global device
160
  global output_path
@@ -215,7 +215,7 @@ init_model()
215
  # ========================================
216
  @spaces.GPU
217
  def video_generation(text, image, scfg_scale, tcfg_scale, img_cfg_scale, diffusion):
218
- global device
219
  global output_path
220
  global use_fp16
221
  global model
@@ -223,6 +223,25 @@ def video_generation(text, image, scfg_scale, tcfg_scale, img_cfg_scale, diffusi
223
  global text_encoder
224
  global image_encoder
225
  global clip_image_processor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  with torch.no_grad():
227
  print("begin generation", flush=True)
228
  transform_video = transforms.Compose([
@@ -253,14 +272,33 @@ def video_generation(text, image, scfg_scale, tcfg_scale, img_cfg_scale, diffusi
253
  # ========================================
254
  @spaces.GPU
255
  def video_prediction(text, image, scfg_scale, tcfg_scale, img_cfg_scale, preframe, diffusion):
256
- global device
257
  global output_path
258
  global use_fp16
259
  global model
260
  global vae
261
  global text_encoder
262
  global image_encoder
263
- global clip_image_processor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  with torch.no_grad():
265
  print("begin generation", flush=True)
266
  transform_video = transforms.Compose([
 
154
  text_encoder = None
155
  image_encoder = None
156
  clip_image_processor = None
157
+ # @spaces.GPU
158
  def init_model():
159
  global device
160
  global output_path
 
215
  # ========================================
216
  @spaces.GPU
217
  def video_generation(text, image, scfg_scale, tcfg_scale, img_cfg_scale, diffusion):
218
+ device = "cuda" if torch.cuda.is_available() else "cpu"
219
  global output_path
220
  global use_fp16
221
  global model
 
223
  global text_encoder
224
  global image_encoder
225
  global clip_image_processor
226
+ vae = vae.to(device)
227
+ text_encoder = text_encoder.to(device)
228
+ image_encoder = image_encoder.to(device)
229
+ model = model.to(device)
230
+ if args.enable_xformers_memory_efficient_attention and device=="cuda":
231
+ if is_xformers_available():
232
+ model.enable_xformers_memory_efficient_attention()
233
+ print("xformer!")
234
+ else:
235
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
236
+ if args.use_fp16:
237
+ print('Warnning: using half percision for inferencing!')
238
+ vae.to(dtype=torch.float16)
239
+ model.to(dtype=torch.float16)
240
+ text_encoder.to(dtype=torch.float16)
241
+ image_encoder.to(dtype=torch.float16)
242
+ use_fp16 = True
243
+ print('Initialization Finished')
244
+
245
  with torch.no_grad():
246
  print("begin generation", flush=True)
247
  transform_video = transforms.Compose([
 
272
  # ========================================
273
  @spaces.GPU
274
  def video_prediction(text, image, scfg_scale, tcfg_scale, img_cfg_scale, preframe, diffusion):
275
+ device = "cuda" if torch.cuda.is_available() else "cpu"
276
  global output_path
277
  global use_fp16
278
  global model
279
  global vae
280
  global text_encoder
281
  global image_encoder
282
+ global clip_image_processor
283
+ vae = vae.to(device)
284
+ text_encoder = text_encoder.to(device)
285
+ image_encoder = image_encoder.to(device)
286
+ model = model.to(device)
287
+ if args.enable_xformers_memory_efficient_attention and device=="cuda":
288
+ if is_xformers_available():
289
+ model.enable_xformers_memory_efficient_attention()
290
+ print("xformer!")
291
+ else:
292
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
293
+ if args.use_fp16:
294
+ print('Warnning: using half percision for inferencing!')
295
+ vae.to(dtype=torch.float16)
296
+ model.to(dtype=torch.float16)
297
+ text_encoder.to(dtype=torch.float16)
298
+ image_encoder.to(dtype=torch.float16)
299
+ use_fp16 = True
300
+ print('Initialization Finished')
301
+
302
  with torch.no_grad():
303
  print("begin generation", flush=True)
304
  transform_video = transforms.Compose([