wyysf commited on
Commit
1656878
β€’
1 Parent(s): 50ad329

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +3 -3
gradio_app.py CHANGED
@@ -154,8 +154,8 @@ def gen_mvimg(
154
  prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
155
  prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
156
 
157
- imgs_in = imgs_in.to(device=device, dtype=torch.float16)
158
- prompt_embeddings = prompt_embeddings.to(device=device, dtype=torch.float16)
159
 
160
  mv_imgs = era3d_pipeline(
161
  imgs_in,
@@ -259,7 +259,7 @@ if __name__=="__main__":
259
  )
260
  # enable xformers
261
  # era3d_pipeline.unet.enable_xformers_memory_efficient_attention()
262
- era3d_pipeline.to(device)
263
  elif "CRM" in mvimg_model_config_list:
264
  stage1_config = OmegaConf.load(f"apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml").config
265
  stage1_sampler_config = stage1_config.sampler
 
154
  prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
155
  prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
156
 
157
+ imgs_in = imgs_in.to(dtype=torch.float16)
158
+ prompt_embeddings = prompt_embeddings.to(dtype=torch.float16)
159
 
160
  mv_imgs = era3d_pipeline(
161
  imgs_in,
 
259
  )
260
  # enable xformers
261
  # era3d_pipeline.unet.enable_xformers_memory_efficient_attention()
262
+ # era3d_pipeline.to(device)
263
  elif "CRM" in mvimg_model_config_list:
264
  stage1_config = OmegaConf.load(f"apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml").config
265
  stage1_sampler_config = stage1_config.sampler