Update gradio_app.py
Browse files- gradio_app.py +3 -3
gradio_app.py
CHANGED
@@ -154,8 +154,8 @@ def gen_mvimg(
|
|
154 |
prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
|
155 |
prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
|
156 |
|
157 |
-
imgs_in = imgs_in.to(
|
158 |
-
prompt_embeddings = prompt_embeddings.to(
|
159 |
|
160 |
mv_imgs = era3d_pipeline(
|
161 |
imgs_in,
|
@@ -259,7 +259,7 @@ if __name__=="__main__":
|
|
259 |
)
|
260 |
# enable xformers
|
261 |
# era3d_pipeline.unet.enable_xformers_memory_efficient_attention()
|
262 |
-
era3d_pipeline.to(device)
|
263 |
elif "CRM" in mvimg_model_config_list:
|
264 |
stage1_config = OmegaConf.load(f"apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml").config
|
265 |
stage1_sampler_config = stage1_config.sampler
|
|
|
154 |
prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
|
155 |
prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
|
156 |
|
157 |
+
imgs_in = imgs_in.to(dtype=torch.float16)
|
158 |
+
prompt_embeddings = prompt_embeddings.to(dtype=torch.float16)
|
159 |
|
160 |
mv_imgs = era3d_pipeline(
|
161 |
imgs_in,
|
|
|
259 |
)
|
260 |
# enable xformers
|
261 |
# era3d_pipeline.unet.enable_xformers_memory_efficient_attention()
|
262 |
+
# era3d_pipeline.to(device)
|
263 |
elif "CRM" in mvimg_model_config_list:
|
264 |
stage1_config = OmegaConf.load(f"apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml").config
|
265 |
stage1_sampler_config = stage1_config.sampler
|