Spaces:
Running
on
T4
Running
on
T4
# Copyright (c) Tencent Inc. All rights reserved.
import json | |
import random | |
from typing import Tuple | |
import numpy as np | |
from mmyolo.registry import TRANSFORMS | |
class RandomLoadText:
    """Randomly sample class texts for a sample and remap its labels.

    The classes present in the ground-truth labels ("positives") are kept
    (sub-sampled if there are more than ``max_num_samples``), a random number
    of absent classes ("negatives") is added, the combined list is shuffled,
    and one caption per sampled class is written to ``results['texts']``.
    Ground-truth labels are rewritten to index into that shuffled text list.

    Args:
        text_path (str, optional): Path to a JSON file holding the class
            texts. The loaded object is indexed by class id and every entry
            must be a non-empty sequence of captions. Only used when
            ``results`` carries no ``'texts'`` key. Defaults to None.
        prompt_format (str): ``str.format`` template applied to the selected
            caption. Defaults to ``'{}'``.
        num_neg_samples (Tuple[int, int]): Inclusive ``(low, high)`` range
            passed to ``random.randint`` to draw the number of negatives.
            Defaults to ``(80, 80)``.
        max_num_samples (int): Upper bound on positives plus negatives.
            Defaults to 80.
        padding_to_max (bool): Whether to append ``padding_value`` entries
            until ``results['texts']`` has ``max_num_samples`` items.
            Defaults to False.
        padding_value (str): Text used for padding. Defaults to ``''``.
    """

    def __init__(self,
                 text_path: str = None,
                 prompt_format: str = '{}',
                 num_neg_samples: Tuple[int, int] = (80, 80),
                 max_num_samples: int = 80,
                 padding_to_max: bool = False,
                 padding_value: str = '') -> None:
        self.prompt_format = prompt_format
        self.num_neg_samples = num_neg_samples
        self.max_num_samples = max_num_samples
        self.padding_to_max = padding_to_max
        self.padding_value = padding_value
        # ``class_texts`` is only defined when a file is given; ``__call__``
        # relies on ``hasattr`` to detect the missing case.
        if text_path is not None:
            # Decode explicitly as UTF-8 so non-ASCII class names do not
            # depend on the platform's default locale encoding.
            with open(text_path, 'r', encoding='utf-8') as f:
                self.class_texts = json.load(f)

    def __call__(self, results: dict) -> dict:
        """Sample texts and remap labels in ``results``.

        Args:
            results (dict): Must contain ``'gt_bboxes'`` and either
                ``'gt_labels'`` or ``'gt_bboxes_labels'``; both the boxes and
                the labels must support boolean-mask indexing. May contain
                ``'texts'`` (takes precedence over the file given at
                construction) and ``'instances'`` (a list of dicts with a
                ``'bbox_label'`` key).

        Returns:
            dict: The same ``results`` object, modified in place:
            boxes/labels/instances whose class was not sampled are dropped,
            the remaining labels index into the new ``results['texts']``.

        Raises:
            AssertionError: If no class texts are available, or a sampled
                class has no captions.
            ValueError: If neither label key is present.
        """
        assert 'texts' in results or hasattr(self, 'class_texts'), (
            'No texts found in results.')
        class_texts = results.get(
            'texts',
            getattr(self, 'class_texts', None))
        num_classes = len(class_texts)

        if 'gt_labels' in results:
            gt_label_tag = 'gt_labels'
        elif 'gt_bboxes_labels' in results:
            gt_label_tag = 'gt_bboxes_labels'
        else:
            raise ValueError('No valid labels found in results.')

        # Positives: every class that occurs in the ground truth, capped at
        # ``max_num_samples`` by random sub-sampling.
        positive_labels = set(results[gt_label_tag])
        if len(positive_labels) > self.max_num_samples:
            positive_labels = set(random.sample(list(positive_labels),
                                                k=self.max_num_samples))

        # Negatives: fill the remaining budget, but never more than the
        # randomly drawn count.
        num_neg_samples = min(
            min(num_classes, self.max_num_samples) - len(positive_labels),
            random.randint(*self.num_neg_samples))
        candidate_neg_labels = [
            idx for idx in range(num_classes) if idx not in positive_labels
        ]
        negative_labels = random.sample(
            candidate_neg_labels, k=num_neg_samples)

        sampled_labels = list(positive_labels) + list(negative_labels)
        random.shuffle(sampled_labels)

        # Map original class id -> position in the shuffled text list.
        label2ids = {label: i for i, label in enumerate(sampled_labels)}

        # Relabel in place and remember which boxes keep a sampled class;
        # boxes of positives dropped by the cap above are filtered out.
        gt_valid_mask = np.zeros(len(results['gt_bboxes']), dtype=bool)
        for idx, label in enumerate(results[gt_label_tag]):
            if label in label2ids:
                gt_valid_mask[idx] = True
                results[gt_label_tag][idx] = label2ids[label]
        results['gt_bboxes'] = results['gt_bboxes'][gt_valid_mask]
        results[gt_label_tag] = results[gt_label_tag][gt_valid_mask]

        if 'instances' in results:
            retaged_instances = []
            for inst in results['instances']:
                label = inst['bbox_label']
                if label in label2ids:
                    inst['bbox_label'] = label2ids[label]
                    retaged_instances.append(inst)
            results['instances'] = retaged_instances

        # Pick one random caption per sampled class.
        texts = []
        for label in sampled_labels:
            cls_caps = class_texts[label]
            assert len(cls_caps) > 0
            cap_id = random.randrange(len(cls_caps))
            texts.append(self.prompt_format.format(cls_caps[cap_id]))

        if self.padding_to_max:
            num_padding = self.max_num_samples - len(sampled_labels)
            if num_padding > 0:
                texts += [self.padding_value] * num_padding

        results['texts'] = texts
        return results
class LoadText:
    """Write one formatted caption per class into ``results['texts']``.

    For every class the first caption is taken and passed through
    ``prompt_format``.

    Args:
        text_path (str, optional): Path to a JSON file holding the class
            texts: an iterable whose entries are non-empty sequences of
            captions. Only used when ``results`` carries no ``'texts'`` key.
            Defaults to None.
        prompt_format (str): ``str.format`` template applied to the selected
            caption. Defaults to ``'{}'``.
        multi_prompt_flag (str): Stored on the instance; not read anywhere
            in this class. Defaults to ``'/'``.
    """

    def __init__(self,
                 text_path: str = None,
                 prompt_format: str = '{}',
                 multi_prompt_flag: str = '/') -> None:
        self.prompt_format = prompt_format
        self.multi_prompt_flag = multi_prompt_flag
        # ``class_texts`` is only defined when a file is given; ``__call__``
        # relies on ``hasattr`` to detect the missing case.
        if text_path is not None:
            # Decode explicitly as UTF-8 so non-ASCII class names do not
            # depend on the platform's default locale encoding.
            with open(text_path, 'r', encoding='utf-8') as f:
                self.class_texts = json.load(f)

    def __call__(self, results: dict) -> dict:
        """Replace ``results['texts']`` with one formatted caption per class.

        Args:
            results (dict): May contain ``'texts'``, which takes precedence
                over the file given at construction.

        Returns:
            dict: The same ``results`` object with ``'texts'`` set to a list
            of formatted strings, one per class, in class order.

        Raises:
            AssertionError: If no class texts are available, or a class has
                no captions.
        """
        assert 'texts' in results or hasattr(self, 'class_texts'), (
            'No texts found in results.')
        class_texts = results.get(
            'texts',
            getattr(self, 'class_texts', None))

        texts = []
        for cls_caps in class_texts:
            assert len(cls_caps) > 0
            # Always use the first caption of the class.
            texts.append(self.prompt_format.format(cls_caps[0]))

        results['texts'] = texts
        return results