Zhibinhong
commited on
Commit
•
6b683e4
1
Parent(s):
f1779a1
Update visual_chatgpt.py
Browse files- visual_chatgpt.py +4 -4
visual_chatgpt.py
CHANGED
@@ -797,7 +797,7 @@ class Segmenting:
|
|
797 |
print(f"Inintializing Segmentation to {device}")
|
798 |
self.device = device
|
799 |
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
800 |
-
self.model_checkpoint_path = "/repository/checkpoints/sam"
|
801 |
|
802 |
self.download_parameters()
|
803 |
self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
|
@@ -813,9 +813,9 @@ class Segmenting:
|
|
813 |
print("finddir",os.system("find /repository -type d -iname 'checkpoints'"))
|
814 |
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
815 |
if not os.path.exists(path):
|
816 |
-
print("
|
817 |
# wget.download(url,out=self.model_checkpoint_path)
|
818 |
-
wget.download(url,out=
|
819 |
|
820 |
def show_mask(self, mask, ax, random_color=False):
|
821 |
if random_color:
|
@@ -917,7 +917,7 @@ class Text2Box:
|
|
917 |
print(f"Initializing ObjectDetection to {device}")
|
918 |
self.device = device
|
919 |
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
920 |
-
self.model_checkpoint_path = "repository/checkpoints/groundingdino"
|
921 |
self.model_config_path = "repository/checkpoints/grounding_config.py"
|
922 |
self.download_parameters()
|
923 |
self.box_threshold = 0.3
|
|
|
797 |
print(f"Inintializing Segmentation to {device}")
|
798 |
self.device = device
|
799 |
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
800 |
+
self.model_checkpoint_path = os.path.abspath("/repository/checkpoints/sam")
|
801 |
|
802 |
self.download_parameters()
|
803 |
self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
|
|
|
813 |
print("finddir",os.system("find /repository -type d -iname 'checkpoints'"))
|
814 |
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
815 |
if not os.path.exists(path):
|
816 |
+
print("I'm in!")
|
817 |
# wget.download(url,out=self.model_checkpoint_path)
|
818 |
+
wget.download(url,out=self.model_checkpoint_path)
|
819 |
|
820 |
def show_mask(self, mask, ax, random_color=False):
|
821 |
if random_color:
|
|
|
917 |
print(f"Initializing ObjectDetection to {device}")
|
918 |
self.device = device
|
919 |
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
920 |
+
self.model_checkpoint_path = os.path.abspath("repository/checkpoints/groundingdino")
|
921 |
self.model_config_path = "repository/checkpoints/grounding_config.py"
|
922 |
self.download_parameters()
|
923 |
self.box_threshold = 0.3
|