Peng Shiya commited on
Commit
e18db8b
1 Parent(s): 773d6b2

fix: feadback.py

Browse files
Files changed (2) hide show
  1. app_configs.py +1 -1
  2. feedback.py +9 -3
app_configs.py CHANGED
@@ -2,4 +2,4 @@ model_type = r'vit_b'
2
  # model_ckpt_path = None
3
  model_ckpt_path = "checkpoints/sam_vit_b_01ec64.pth"
4
  device = None
5
- enable_segment_all = True
 
2
  # model_ckpt_path = None
3
  model_ckpt_path = "checkpoints/sam_vit_b_01ec64.pth"
4
  device = None
5
+ enable_segment_all = False
feedback.py CHANGED
@@ -17,11 +17,13 @@ def write_row(filepath:str, row: Dict):
17
 
18
  class Feedback():
19
  def __init__(self,
20
- image_dir = './data/input',
21
  mask_dir = './data/mask',
22
  inference_csv = './data/inference.csv',
23
  feedback_csv = './data/feedback.csv',
24
  ):
 
 
25
  self.image_dir = image_dir
26
  self.mask_dir = mask_dir
27
  self.inference_csv = inference_csv
@@ -29,12 +31,16 @@ class Feedback():
29
 
30
  def save_inference(self, pt_coords:List, pt_labels:List, image: Image.Image, mask: np.ndarray):
31
  self.inference_id = uuid.uuid4()
 
 
 
 
32
  write_row(
33
  filepath=self.inference_csv,
34
  row = {
35
  "inference_id": self.inference_id,
36
- "image": image.tobytes(),
37
- "mask": mask.tobytes(),
38
  "pt_coords": str(pt_coords),
39
  "pt_labels": str(pt_labels),
40
  }
 
17
 
18
  class Feedback():
19
  def __init__(self,
20
+ image_dir = './data/image',
21
  mask_dir = './data/mask',
22
  inference_csv = './data/inference.csv',
23
  feedback_csv = './data/feedback.csv',
24
  ):
25
+ os.makedirs(image_dir, exist_ok=True)
26
+ os.makedirs(mask_dir, exist_ok=True)
27
  self.image_dir = image_dir
28
  self.mask_dir = mask_dir
29
  self.inference_csv = inference_csv
 
31
 
32
  def save_inference(self, pt_coords:List, pt_labels:List, image: Image.Image, mask: np.ndarray):
33
  self.inference_id = uuid.uuid4()
34
+ image_path = os.path.join(self.image_dir,f'{self.inference_id}.png')
35
+ mask_path = os.path.join(self.mask_dir, f'{self.inference_id}.npy')
36
+ image.save(image_path)
37
+ np.save(mask_path, mask)
38
  write_row(
39
  filepath=self.inference_csv,
40
  row = {
41
  "inference_id": self.inference_id,
42
+ "image": image_path,
43
+ "mask": mask_path,
44
  "pt_coords": str(pt_coords),
45
  "pt_labels": str(pt_labels),
46
  }