segment-anything-ui / feedback.py
Peng Shiya
fix: feadback.py
e18db8b
import csv
import os
import uuid
from typing import Dict, List, Optional

import numpy as np
from PIL import Image
def write_row(filepath: str, row: Dict):
    """Append ``row`` as one line to the CSV file at ``filepath``.

    The CSV columns are ``row``'s keys, in insertion order. A header line is
    written first whenever the file is new *or still empty*. This way a file
    left empty by an earlier interrupted write still ends up with a header.
    Missing parent directories are created.

    Args:
        filepath: Path of the CSV file to append to.
        row: Mapping of column name to value. Values are converted by the
            ``csv`` module (``None`` is written as an empty field).
    """
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Use UTF-8 explicitly so free-text values do not depend on the platform
    # locale encoding (e.g. cp1252 on Windows cannot encode CJK text).
    with open(filepath, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.DictWriter(file, fieldnames=list(row))
        # In append mode the stream starts at end-of-file. Position 0
        # therefore means the file is new or empty and needs a header.
        if file.tell() == 0:
            writer.writeheader()
        writer.writerow(row)
class Feedback:
    """Persist segmentation inferences and user feedback to disk.

    Each :meth:`save_inference` call writes the input image as a PNG and the
    mask as a ``.npy`` file, both named by a fresh UUID. It also appends a
    row to ``inference_csv``. :meth:`save_feedback` appends a row to
    ``feedback_csv`` that references an inference id.
    """

    def __init__(self,
                 image_dir='./data/image',
                 mask_dir='./data/mask',
                 inference_csv='./data/inference.csv',
                 feedback_csv='./data/feedback.csv',
                 ):
        """Create the output directories and remember the output paths.

        Args:
            image_dir: Directory that receives one PNG per inference.
            mask_dir: Directory that receives one ``.npy`` mask per inference.
            inference_csv: CSV file that gets one row per ``save_inference``.
            feedback_csv: CSV file that gets one row per ``save_feedback``.
        """
        os.makedirs(image_dir, exist_ok=True)
        os.makedirs(mask_dir, exist_ok=True)
        # The CSV logs may be configured to live in a directory that neither
        # of the calls above creates. Ensure appends cannot fail on a missing
        # parent directory.
        for csv_path in (inference_csv, feedback_csv):
            parent = os.path.dirname(csv_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.inference_csv = inference_csv
        self.feedback_csv = feedback_csv
        # Id of the most recent save_inference call; None until one happens.
        self.inference_id: Optional[uuid.UUID] = None

    def save_inference(self, pt_coords: List, pt_labels: List,
                       image: Image.Image, mask: np.ndarray) -> uuid.UUID:
        """Save one inference (image, mask, prompt points) and log it.

        Generates a new UUID and stores it as ``self.inference_id``. Writes
        ``<image_dir>/<uuid>.png`` and ``<mask_dir>/<uuid>.npy``, then appends
        a row with both paths and the stringified prompt points to
        ``inference_csv``.

        Args:
            pt_coords: Point prompt coordinates; stored via ``str()``.
            pt_labels: Point prompt labels; stored via ``str()``.
            image: Image to save as PNG.
            mask: Mask array to save with ``np.save``.

        Returns:
            The UUID identifying this inference.
        """
        self.inference_id = uuid.uuid4()
        image_path = os.path.join(self.image_dir, f'{self.inference_id}.png')
        mask_path = os.path.join(self.mask_dir, f'{self.inference_id}.npy')
        image.save(image_path)
        np.save(mask_path, mask)
        write_row(
            filepath=self.inference_csv,
            row={
                "inference_id": self.inference_id,
                "image": image_path,
                "mask": mask_path,
                "pt_coords": str(pt_coords),
                "pt_labels": str(pt_labels),
            },
        )
        return self.inference_id

    def save_feedback(self, cutout_idx: Optional[int] = None,
                      feedback_str: Optional[str] = None,
                      like: Optional[int] = None,
                      inference_id=None):
        """Append one feedback row to ``feedback_csv``.

        Args:
            cutout_idx: Cutout index the feedback refers to, if any.
            feedback_str: Free-text feedback, if any.
            like: Rating value as supplied by the caller; stored verbatim.
            inference_id: Inference to attach the feedback to. Defaults to
                the id from the most recent ``save_inference`` call.

        Raises:
            RuntimeError: if no ``inference_id`` is given and
                ``save_inference`` has not been called yet.
        """
        if inference_id is None:
            # getattr also tolerates instances that skipped __init__.
            inference_id = getattr(self, 'inference_id', None)
        if inference_id is None:
            raise RuntimeError(
                "save_feedback() called before save_inference(); "
                "no inference_id to attach the feedback to"
            )
        write_row(
            filepath=self.feedback_csv,
            row={
                "inference_id": inference_id,
                "cutout_idx": cutout_idx,
                "feedback_str": feedback_str,
                "like": like,
            },
        )