anvilarth commited on
Commit
c99b3ce
1 Parent(s): e866777

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -10,7 +10,7 @@ import urllib.request
10
  from PIL import Image, ImageDraw
11
  import matplotlib.pyplot as plt
12
 
13
- from Garage.models.GroundedSegmentAnything.segment_anything.segment_anything import SamPredictor, sam_model_registry
14
  from Garage.models.GroundedSegmentAnything.GroundingDINO.groundingdino.util.inference import Model
15
  from Garage import Augmenter
16
 
@@ -22,7 +22,7 @@ MODEL_DICT = dict(
22
  )
23
 
24
  GROUNDING_DINO_CONFIG_PATH = "Garage/models/GroundedSegmentAnything/GroundingDINO_SwinT_OGC.py"
25
- GROUNDING_DINO_CHECKPOINT_PATH = "Garage/models/checkpoints/GroundedSegmentAnything/groundingdino_swint_ogc.pth"
26
  SAM_CHECKPOINT_PATH = "Garage/models/checkpoints/GroundedSegmentAnything/sam_vit_h_4b8939.pth"
27
  SAM_ENCODER_VERSION = "vit_h"
28
 
@@ -210,7 +210,7 @@ class GradioWindow():
210
  print(f"Model {model} already exists")
211
 
212
  def setup_model(self) -> SamPredictor:
213
- self.sam = sam_model_registry[self.model_type](checkpoint=self.SAM_CHECKPOINT_PATH)
214
  self.sam.to(device=self.device)
215
  self.sam_predictor = SamPredictor(self.sam)
216
 
@@ -219,6 +219,8 @@ class GradioWindow():
219
  model_checkpoint_path=self.GROUNDING_DINO_CHECKPOINT_PATH,
220
  device=self.device
221
  )
 
 
222
 
223
  def change_mask_type(self, image, is_segmmask):
224
  self.selected_mask = None
 
10
  from PIL import Image, ImageDraw
11
  import matplotlib.pyplot as plt
12
 
13
+ from Garage.models.GroundedSegmentAnything.segment_anything.segment_anything import SamPredictor, build_sam
14
  from Garage.models.GroundedSegmentAnything.GroundingDINO.groundingdino.util.inference import Model
15
  from Garage import Augmenter
16
 
 
22
  )
23
 
24
  GROUNDING_DINO_CONFIG_PATH = "Garage/models/GroundedSegmentAnything/GroundingDINO_SwinT_OGC.py"
25
+ GROUNDING_DINO_CHECKPOINT_PATH = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
26
  SAM_CHECKPOINT_PATH = "Garage/models/checkpoints/GroundedSegmentAnything/sam_vit_h_4b8939.pth"
27
  SAM_ENCODER_VERSION = "vit_h"
28
 
 
210
  print(f"Model {model} already exists")
211
 
212
  def setup_model(self) -> SamPredictor:
213
+ self.sam = build_sam(checkpoint='sam_vit_h_4b8939.pth')
214
  self.sam.to(device=self.device)
215
  self.sam_predictor = SamPredictor(self.sam)
216
 
 
219
  model_checkpoint_path=self.GROUNDING_DINO_CHECKPOINT_PATH,
220
  device=self.device
221
  )
222
+
223
+ print("MODELS LOADED!")
224
 
225
  def change_mask_type(self, image, is_segmmask):
226
  self.selected_mask = None