FinLLaVA / llava_llama3 /eval /eval_textvqa.py
TobyYang7's picture
Upload 96 files
6a83074 verified
raw
history blame
2.23 kB
import os
import argparse
import json
import re
from llava_llama3.eval.m4c_evaluator import TextVQAAccuracyEvaluator
def get_args():
    """Parse the command-line options of this evaluation script.

    Three optional string options are recognised: ``--annotation-file``,
    ``--result-file`` and ``--result-dir``. An option that is not given on
    the command line is left as ``None`` on the returned namespace.

    Returns:
        argparse.Namespace: namespace exposing ``annotation_file``,
        ``result_file`` and ``result_dir``.
    """
    cli = argparse.ArgumentParser()
    # All three options share the same shape, so register them in one pass.
    for option in ('--annotation-file', '--result-file', '--result-dir'):
        cli.add_argument(option, type=str)
    return cli.parse_args()
def prompt_processor(prompt):
    """Extract the lower-cased question text from a model prompt.

    The prompt layouts are tried in this order:

    1. The prompt starts with ``'OCR tokens: '``: the question is the text
       between ``'Question: '`` and ``' Short answer:'``.
    2. The prompt contains ``'Reference OCR token: '`` and has exactly three
       newline-separated lines: the question is the second line when the
       prompt starts with ``'Reference OCR token:'``, otherwise the first.
    3. The prompt has exactly two newline-separated lines: the question is
       the first line.

    Args:
        prompt: The full prompt string that was fed to the model.

    Returns:
        The extracted question, lower-cased.

    Raises:
        ValueError: If the prompt matches none of the layouts above, or if
            an ``'OCR tokens: '`` prompt lacks the
            ``'Question: ... Short answer:'`` section.
    """
    if prompt.startswith('OCR tokens: '):
        # DOTALL lets the question span several lines.
        match = re.search(r"Question: (.*?) Short answer:", prompt, re.DOTALL)
        if match is None:
            raise ValueError(
                f"Prompt starts with 'OCR tokens: ' but has no "
                f"'Question: ... Short answer:' section: {prompt!r}"
            )
        return match.group(1).lower()

    # Split once; every remaining layout is decided by the line count.
    lines = prompt.split('\n')
    if 'Reference OCR token: ' in prompt and len(lines) == 3:
        if prompt.startswith('Reference OCR token:'):
            question = lines[1]
        else:
            question = lines[0]
    elif len(lines) == 2:
        question = lines[0]
    else:
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        raise ValueError(f"Unrecognised prompt format: {prompt!r}")
    return question.lower()
def eval_single(annotation_file, result_file):
    """Score one result file against the annotations and print the accuracy.

    The annotation file is a JSON document whose ``'data'`` entry is a list
    of records; each record is indexed by ``(image_id, lower-cased question)``.
    The result file is read as JSON Lines; each record's ``'question_id'``
    and the question recovered from its ``'prompt'`` form the lookup key.
    Whitespace-only lines in the result file are skipped.

    The experiment name (the result file's base name without extension),
    the sample count and the accuracy percentage are printed to stdout.

    Args:
        annotation_file: Path to the annotation JSON file.
        result_file: Path to the ``.jsonl`` file with model outputs.

    Raises:
        KeyError: If a result has no matching annotation.
    """
    experiment_name = os.path.splitext(os.path.basename(result_file))[0]
    print(experiment_name)

    # Context managers close the files promptly instead of leaking handles.
    with open(annotation_file, encoding='utf-8') as ann_fp:
        annotation_records = json.load(ann_fp)['data']
    annotations = {
        (record['image_id'], record['question'].lower()): record
        for record in annotation_records
    }

    with open(result_file, encoding='utf-8') as res_fp:
        # Tolerate blank lines (e.g. a trailing newline at end of file).
        results = [json.loads(line) for line in res_fp if line.strip()]

    pred_list = []
    for result in results:
        annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
        pred_list.append({
            "pred_answer": result['text'],
            "gt_answers": annotation['answers'],
        })

    evaluator = TextVQAAccuracyEvaluator()
    print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
if __name__ == "__main__":
    # Script entry point: evaluate a single result file and/or every
    # ``.jsonl`` file found in a result directory.
    args = get_args()
    annotation_path = args.annotation_file

    if args.result_file is not None:
        eval_single(annotation_path, args.result_file)

    if args.result_dir is not None:
        # Sorted so that experiments are reported in a stable order.
        for result_file in sorted(os.listdir(args.result_dir)):
            if result_file.endswith('.jsonl'):
                eval_single(annotation_path, os.path.join(args.result_dir, result_file))
            else:
                print(f'Skipping {result_file}')