import itertools import json import os import re from collections import namedtuple import torch from tqdm import tqdm class InferenceSampler(torch.utils.data.sampler.Sampler): def __init__(self, size): self._size = int(size) assert size > 0 self._rank = torch.distributed.get_rank() self._world_size = torch.distributed.get_world_size() self._local_indices = self._get_local_indices(size, self._world_size, self._rank) @staticmethod def _get_local_indices(total_size, world_size, rank): shard_size = total_size // world_size left = total_size % world_size shard_sizes = [shard_size + int(r < left) for r in range(world_size)] begin = sum(shard_sizes[:rank]) end = min(sum(shard_sizes[:rank + 1]), total_size) return range(begin, end) def __iter__(self): yield from self._local_indices def __len__(self): return len(self._local_indices) def collate_fn_vqa(batches): ''' ''' image_paths = [_['image_path'] for _ in batches] questions = [_['question'] for _ in batches] gt_answers = [_['gt_answers'] for _ in batches] ocr_tokens = [_['ocr_tokens'] if 'ocr_tokens' in _ else None for _ in batches] question_ids = [_['question_id'] if 'question_id' in _ else None for _ in batches] question_type = [_['question_type'] if 'question_type' in _ else None for _ in batches] return image_paths, questions, gt_answers, ocr_tokens, question_ids, question_type def has_word(sentence, word): if word[0].isalnum(): start_pattern = r"\b" else: start_pattern = r"" if word[-1].isalnum(): end_pattern = r"\b" else: end_pattern = r"" pattern = start_pattern + re.escape(word) + end_pattern match = re.search(pattern, sentence) return bool(match) def remove_special_chars(s): pattern = r"[^a-zA-Z0-9\s]" s = re.sub(pattern, "", s) return s def levenshtein_distance(s1, s2): if len(s1) > len(s2): s1, s2 = s2, s1 distances = range(len(s1) + 1) for i2, c2 in enumerate(s2): distances_ = [i2+1] for i1, c1 in enumerate(s1): if c1 == c2: distances_.append(distances[i1]) else: distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) distances = distances_ return distances[-1] class VQAEval: def __init__(self): self.contractions = { "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": "you'll", "youre": "you're", "youve": "you've", } self.manualMap = { "none": "0", "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4", "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10", } self.articles = ["a", "an", "the"] self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") self.commaStrip = re.compile("(\d)(\,)(\d)") self.punct = [ ";", r"/", "[", "]", '"', "{", "}", "(", ")", "=", "+", "\\", "_", "-", ">", "<", "@", "`", ",", "?", "!", ] def clean_text(self, text): text = text.replace("\n", " ").replace("\t", " ").strip() text = self.processPunctuation(text) text = self.processDigitArticle(text) return text def evaluate_vqa_human(self, answer, gt_answers): '''TextVQA, VQAv2, OKVQA, vizwiz''' answer = answer.replace("\n", " ").replace("\t", " ").strip() answer = self.processPunctuation(answer) answer = self.processDigitArticle(answer) gt_answers = [self.processPunctuation(ans) for ans in gt_answers] gt_answers = [self.processDigitArticle(ans) for ans in gt_answers] gtAcc = [] for idx, gtAnsDatum in enumerate(gt_answers): otherGTAns = gt_answers[:idx] + gt_answers[idx+1:] matchingAns = [item for item in otherGTAns if answer == item] acc = min(1, float(len(matchingAns)) / 3) gtAcc.append(acc) avgGTAcc = float(sum(gtAcc)) / len(gtAcc) if gtAcc else 0 return avgGTAcc def evaluate_anls(self, answer, gt_answers, threshold=0.5): '''DOcVQA, InfographicsVQA, STVQA''' answer = ' '.join(answer.strip().lower().split()) if not isinstance(gt_answers, list): gt_answers = [gt_answers] gt_answers = [' '.join(gt_answer.strip().lower().split()) for gt_answer in gt_answers] values = [] for gt_answer in gt_answers: dist = levenshtein_distance(answer, gt_answer) length = max(len(answer), len(gt_answer)) values.append(0.0 if length == 0 else float(dist) / float(length)) score = 1 - min(values) score = 0 if score < threshold else score return score def processPunctuation(self, inText): outText = inText for p in self.punct: if (p + " " in inText or " " + p in inText) or ( re.search(self.commaStrip, inText) != None ): outText = outText.replace(p, "") else: outText = outText.replace(p, " ") outText = self.periodStrip.sub("", outText, re.UNICODE) return outText def processDigitArticle(self, inText): outText = [] tempText = inText.lower().split() for word in tempText: word = self.manualMap.setdefault(word, word) if word not in self.articles: outText.append(word) else: pass for wordId, word in enumerate(outText): if word in self.contractions: outText[wordId] = self.contractions[word] outText = " ".join(outText) return outText def evaluate_dataset(dataset_name, answer_file_path, model_name, method = None): with open(answer_file_path, 'r', encoding='utf-8') as f: predictions = json.load(f) eval = VQAEval() total_accuracy = 0 num = 0 Entry = namedtuple('Entry', ['text', 'bbox']) for item in predictions: gt_answers = item['gt_answers'] answer = item['answer'] if method is not None: pass if dataset_name in ["textVQA"]: if num == 0: print(f"evaluating vqa...") accuracy = eval.evaluate_vqa_human(answer, gt_answers) elif dataset_name in ['docVQA']: if num == 0: print(f"evaluating anls...") accuracy = eval.evaluate_anls(answer, gt_answers) else: accuracy = eval.evaluate_has(answer, gt_answers) item['accuracy'] = accuracy total_accuracy += accuracy num += 1 average_accuracy = total_accuracy / num print(f'{dataset_name}:{average_accuracy}') answer_model_method_path = answer_file_path.replace('.json', f'_{model_name}_{method}.json') with open(answer_model_method_path, "w", encoding='utf-8') as f: json.dump(predictions, f, indent=4, ensure_ascii=False) return average_accuracy def evaluate_VQA( model, dataset, model_name, dataset_name, time, batch_size=1, generate_method="interleave", answer_path='./answers', ): print(f"answer path:{answer_path}") sampler = None if torch.distributed.is_initialized(): sampler=InferenceSampler(len(dataset)) dataloader = torch.utils.data.DataLoader( dataset=dataset, batch_size=batch_size, sampler=sampler, collate_fn=collate_fn_vqa ) now_rank = torch.distributed.get_rank() answer_dir = os.path.join(answer_path, model_name, time) os.makedirs(answer_dir, exist_ok=True) image_list = [] for item in dataset: image_list.append(item["image_path"]) predictions = [] for batch in tqdm(dataloader, desc="Running inference"): image_paths, questions, gt_answers, ocr_tokens_list, question_ids, question_type = batch with torch.no_grad(): if model_name != "minicpm": if model_name != "codellama": outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name) else: outputs = model.generate() elif model_name == "minicpm": if generate_method == "old": outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name) elif generate_method == "interleave": outputs = model.generate_with_interleaved(images=image_paths, questions=questions, datasetname=dataset_name) else: raise Exception(f"Wrong generate paradigm {generate_method}!") for i in range(len(outputs)): answer_dict = { 'question_id': question_ids[i], 'question': questions[i], 'answer': outputs[i], 'gt_answers': gt_answers[i], 'image_path': image_paths[i], 'model_name': model_name, 'question_type': question_type[i] } predictions.append(answer_dict) if torch.distributed.is_initialized(): torch.distributed.barrier() if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() merged_predictions = [None for _ in range(world_size)] torch.distributed.all_gather_object(merged_predictions, predictions) predictions = [_ for _ in itertools.chain.from_iterable(merged_predictions)] if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0: return None answer_file_path = os.path.join(answer_dir, f"{dataset_name}.json") print(f"answer_file_path:{answer_file_path}") with open(answer_file_path, "w", encoding='utf-8') as f: json.dump(predictions, f, indent=4, ensure_ascii=False) if dataset_name in ["docVQATest"]: return -1.0 return evaluate_dataset(answer_file_path=answer_file_path, dataset_name=dataset_name, model_name=model_name)