hsshin98
committed on
Commit
•
e20de5f
1
Parent(s):
aff8d56
optimize
Browse files
- app.py +3 -4
- demo/predictor.py +12 -1
app.py
CHANGED
@@ -82,10 +82,7 @@ def save_masks(preds, text):
|
|
82 |
cv2.imwrite(dir, mask * 255)
|
83 |
|
84 |
def predict(image, text):
|
85 |
-
|
86 |
-
cfg = setup_cfg(args)
|
87 |
-
demo = VisualizationDemo(cfg, text=text)
|
88 |
-
predictions, visualized_output = demo.run_on_image(image)
|
89 |
#save_masks(predictions, text.split(','))
|
90 |
canvas = fc(visualized_output.fig)
|
91 |
canvas.draw()
|
@@ -96,6 +93,8 @@ def predict(image, text):
|
|
96 |
if __name__ == "__main__":
|
97 |
args = get_parser().parse_args()
|
98 |
cfg = setup_cfg(args)
|
|
|
|
|
99 |
|
100 |
iface = gr.Interface(
|
101 |
fn=predict,
|
|
|
82 |
cv2.imwrite(dir, mask * 255)
|
83 |
|
84 |
def predict(image, text):
|
85 |
+
predictions, visualized_output = demo.run_on_image(image, text)
|
|
|
|
|
|
|
86 |
#save_masks(predictions, text.split(','))
|
87 |
canvas = fc(visualized_output.fig)
|
88 |
canvas.draw()
|
|
|
93 |
if __name__ == "__main__":
|
94 |
args = get_parser().parse_args()
|
95 |
cfg = setup_cfg(args)
|
96 |
+
global demo
|
97 |
+
demo = VisualizationDemo(cfg)
|
98 |
|
99 |
iface = gr.Interface(
|
100 |
fn=predict,
|
demo/predictor.py
CHANGED
@@ -49,7 +49,7 @@ class VisualizationDemo(object):
|
|
49 |
self.metadata = ns()
|
50 |
self.metadata.stuff_classes = pred.test_class_texts
|
51 |
|
52 |
-
def run_on_image(self, image):
|
53 |
"""
|
54 |
Args:
|
55 |
image (np.ndarray): an image of shape (H, W, C) (in BGR order).
|
@@ -59,6 +59,17 @@ class VisualizationDemo(object):
|
|
59 |
vis_output (VisImage): the visualized image output.
|
60 |
"""
|
61 |
vis_output = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
predictions = self.predictor(image)
|
63 |
# Convert image from OpenCV BGR format to Matplotlib RGB format.
|
64 |
image = image[:, :, ::-1]
|
|
|
49 |
self.metadata = ns()
|
50 |
self.metadata.stuff_classes = pred.test_class_texts
|
51 |
|
52 |
+
def run_on_image(self, image, text=None):
|
53 |
"""
|
54 |
Args:
|
55 |
image (np.ndarray): an image of shape (H, W, C) (in BGR order).
|
|
|
59 |
vis_output (VisImage): the visualized image output.
|
60 |
"""
|
61 |
vis_output = None
|
62 |
+
|
63 |
+
if text is not None:
|
64 |
+
pred = self.predictor.model.sem_seg_head.predictor
|
65 |
+
pred.test_class_texts = text.split(',')
|
66 |
+
pred.text_features_test = pred.class_embeddings(pred.test_class_texts,
|
67 |
+
#imagenet_templates.IMAGENET_TEMPLATES,
|
68 |
+
['A photo of a {} in the scene',],
|
69 |
+
pred.clip_model).permute(1, 0, 2).float().repeat(1, 80, 1)
|
70 |
+
self.metadata = ns()
|
71 |
+
self.metadata.stuff_classes = pred.test_class_texts
|
72 |
+
|
73 |
predictions = self.predictor(image)
|
74 |
# Convert image from OpenCV BGR format to Matplotlib RGB format.
|
75 |
image = image[:, :, ::-1]
|