yizhangliu
commited on
Commit
•
f47bc1e
1
Parent(s):
1f359be
Update app.py
Browse files
app.py
CHANGED
@@ -706,13 +706,16 @@ def change_radio_display(task_type, mask_source_radio):
|
|
706 |
return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
|
707 |
|
708 |
def get_model_device(module):
|
709 |
-
|
710 |
-
module
|
711 |
-
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
|
|
|
|
|
|
716 |
|
717 |
if __name__ == "__main__":
|
718 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|
|
|
706 |
return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
|
707 |
|
708 |
def get_model_device(module):
|
709 |
+
try:
|
710 |
+
if isinstance(module, torch.nn.DataParallel):
|
711 |
+
module = module.module
|
712 |
+
for submodule in module.children():
|
713 |
+
if hasattr(submodule, "_parameters"):
|
714 |
+
parameters = submodule._parameters
|
715 |
+
if "weight" in parameters:
|
716 |
+
return parameters["weight"].device
|
717 |
+
except Exception as e:
|
718 |
+
return 'ohoh'
|
719 |
|
720 |
if __name__ == "__main__":
|
721 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|