|
import os |
|
from typing import List |
|
|
|
from mmengine.dataset import Compose |
|
from torch.utils.data import Dataset |
|
|
|
from opencompass.registry import DATASETS |
|
|
|
|
|
@DATASETS.register_module()
class MMEDataset(Dataset):
    """Dataset to load MME dataset.

    Each task listed in ``tasks`` is read from ``<data_dir>/<task>``. Two
    on-disk layouts are accepted per task:

    * ``<task>/questions_answers_YN`` holds the question files and
      ``<task>/images`` holds the images, or
    * questions (``*.txt``) and images live side by side in ``<task>``.

    Question files and images are paired by sorted file name. The first two
    lines of every question file are read as tab-separated
    ``question<TAB>answer`` records, so each image contributes two samples.

    Args:
        data_dir (str): The path of the dataset.
        pipeline (List[dict]): Configs of the transforms; they are wrapped
            with ``Compose`` and applied to every sample in
            ``__getitem__``.
    """
    tasks = [
        'artwork', 'celebrity', 'code_reasoning', 'color',
        'commonsense_reasoning', 'count', 'existence', 'landmark',
        'numerical_calculation', 'OCR', 'position', 'posters', 'scene',
        'text_translation'
    ]
    # (image sub-directory, question sub-directory) of the split layout.
    sub_dir_name = ('images', 'questions_answers_YN')

    def __init__(self, data_dir: str, pipeline: List[dict]) -> None:
        self.pipeline = Compose(pipeline)
        self.load_data(data_dir)

    @staticmethod
    def _read_qa_pairs(q_path: str) -> list:
        """Return the two ``(question, answer)`` pairs stored in ``q_path``.

        Raises:
            ValueError: If one of the first two lines is not exactly one
                tab-separated ``question<TAB>answer`` record.
        """
        pairs = []
        # Decode explicitly so parsing does not depend on the platform's
        # default locale encoding.
        with open(q_path, 'r', encoding='utf-8') as f:
            for line_no in (1, 2):
                fields = f.readline().strip().split('\t')
                if len(fields) != 2:
                    raise ValueError(
                        f'Expected "question<TAB>answer" on line {line_no} '
                        f'of {q_path}, got {len(fields)} field(s).')
                pairs.append((fields[0], fields[1]))
        return pairs

    def load_data(self, data_dir: str) -> None:
        """Scan ``data_dir`` and fill ``self.data_list`` with samples.

        Every sample is a dict with the keys ``img_path``, ``question``,
        ``answer`` and ``task``.

        Args:
            data_dir (str): The path of the dataset.

        Raises:
            AssertionError: If a task has a different number of question
                files and images, or a question file and the image it is
                paired with do not share the same file stem.
            ValueError: If a question file is malformed.
        """
        self.data_list = []
        image_dir, question_dir = self.sub_dir_name
        for task in self.tasks:
            task_dir = os.path.join(data_dir, task)
            split_q_dir = os.path.join(task_dir, question_dir)
            if os.path.exists(split_q_dir):
                q_prefix = split_q_dir
                i_prefix = os.path.join(task_dir, image_dir)
                q_list = os.listdir(q_prefix)
                i_list = os.listdir(i_prefix)
            else:
                # Flat layout: separate questions from images in one pass
                # using the real extension rather than a substring match.
                q_prefix = i_prefix = task_dir
                q_list, i_list = [], []
                for fn in os.listdir(task_dir):
                    if fn.endswith('.txt'):
                        q_list.append(fn)
                    else:
                        i_list.append(fn)

            # Sorting both lists lines every question file up with its image.
            q_list.sort()
            i_list.sort()
            assert len(q_list) == len(i_list), (
                f'Task "{task}": {len(q_list)} question file(s) but '
                f'{len(i_list)} image(s).')

            for q_fn, i_fn in zip(q_list, i_list):
                # Compare the whole stem so names containing extra dots are
                # not matched on a shared prefix only.
                assert os.path.splitext(q_fn)[0] == os.path.splitext(
                    i_fn)[0], (f'Task "{task}": question file "{q_fn}" '
                               f'does not match image "{i_fn}".')
                q_path = os.path.join(q_prefix, q_fn)
                image_path = os.path.join(i_prefix, i_fn)
                for question, answer in self._read_qa_pairs(q_path):
                    self.data_list.append({
                        'img_path': image_path,
                        'question': question,
                        'answer': answer,
                        'task': task
                    })

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.data_list)

    def __getitem__(self, idx: int) -> dict:
        """Return the sample at ``idx`` after running the pipeline on it."""
        data_sample = self.data_list[idx]
        data_sample = self.pipeline(data_sample)
        return data_sample
|
|