# NOTE: "semgent" is misspelled in the module path below; it presumably matches
# the on-disk module name, so it is intentionally left unchanged.
from models.segment_models.semgent_anything_model import SegmentAnything
from models.segment_models.semantic_segment_anything_model import SemanticSegment


class RegionSemantic:
    """Build a textual "semantic prompt" describing the largest regions of an image.

    Combines a mask generator (``SegmentAnything``) with a semantic classifier
    (``SemanticSegment``) and renders the largest labelled regions as text.
    """

    def __init__(self) -> None:
        self.init_models()

    def init_models(self) -> None:
        """Instantiate the segmentation and semantic-classification models."""
        self.segment_model = SegmentAnything()
        self.semantic_segment_model = SemanticSegment()

    def semantic_prompt_gen(self, anns, *, top_k: int = 10) -> str:
        """Render the ``top_k`` largest annotated regions as a prompt string.

        Regions are ranked by their ``'area'`` value, largest first. No
        filtering by stability score or minimum size is performed here.

        Args:
            anns: iterable of annotation dicts. Each one must provide at least
                ``'area'`` (the sort key), ``'class_name'`` and ``'bbox'``, e.g.
                ``{'class_name': 'person', 'bbox': [0.0, 0.0, 0.0, 0.0], 'area': 0, ...}``.
            top_k: maximum number of regions to include. Defaults to 10, the
                value that was previously hard-coded.

        Returns:
            A string of the form ``"person: [0.0, 0.0, 0.0, 0.0]; ..."``. Every
            entry is terminated by ``"; "``, so a non-empty result ends with a
            trailing ``"; "``. Empty input yields ``""``.

        Raises:
            ValueError: if ``top_k`` is negative.
            KeyError: if an annotation lacks one of the keys listed above.

        Side effects:
            Prints the prompt to stdout between separator lines.
        """
        if top_k < 0:
            # A negative slice bound would silently mean "all but the last |k|".
            raise ValueError("top_k must be non-negative")

        # sorted() is stable, so regions with equal area keep their input order.
        largest = sorted(anns, key=lambda ann: ann['area'], reverse=True)[:top_k]

        print('*' * 100)
        print("\nStep3, Semantic Prompt:")
        # join avoids quadratic string building from repeated +=.
        semantic_prompt = "".join(
            f"{region['class_name']}: {region['bbox']}; " for region in largest
        )
        print(semantic_prompt)
        print('*' * 100)
        return semantic_prompt

    def region_semantic(self, img_src) -> str:
        """Segment ``img_src``, classify each mask, and return the semantic prompt.

        ``img_src`` is passed unchanged to both ``generate_mask`` and
        ``semantic_class_w_mask``. Its expected type (file path vs. image
        array) is defined by those models (TODO: confirm).
        """
        anns = self.segment_model.generate_mask(img_src)
        anns_w_class = self.semantic_segment_model.semantic_class_w_mask(img_src, anns)
        return self.semantic_prompt_gen(anns_w_class)

    def region_semantic_debug(self, img_src) -> str:
        """Debug stub: skips all models and returns a fixed placeholder string."""
        return "region_semantic_debug"