TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame
2.76 kB
import os
from typing import List
from mmengine.dataset import Compose
from torch.utils.data import Dataset
from opencompass.registry import DATASETS
@DATASETS.register_module()
class MMEDataset(Dataset):
    """Dataset to load the MME benchmark.

    For every task listed in ``tasks``, image files are paired with ``.txt``
    question files that share the same file stem. Each question file yields
    two samples, read from its first two lines, where every line has the form
    ``<question>\\t<answer>``.

    Two on-disk layouts are accepted per task:

    * ``<data_dir>/<task>/images`` plus
      ``<data_dir>/<task>/questions_answers_YN`` sub-directories, or
    * a flat ``<data_dir>/<task>`` directory in which ``.txt`` files are the
      questions and every other entry is treated as an image.

    Args:
        data_dir (str): The path of the dataset.
        pipeline (List[dict]): The data augmentation.
    """

    # Task sub-directories under ``data_dir`` that are loaded, in this order.
    tasks = [
        'artwork', 'celebrity', 'code_reasoning', 'color',
        'commonsense_reasoning', 'count', 'existence', 'landmark',
        'numerical_calculation', 'OCR', 'position', 'posters', 'scene',
        'text_translation'
    ]
    # (image sub-directory, question sub-directory) for the split layout.
    sub_dir_name = ('images', 'questions_answers_YN')

    def __init__(self, data_dir: str, pipeline: List[dict]) -> None:
        self.pipeline = Compose(pipeline)
        self.load_data(data_dir)

    def _list_task_files(self, data_dir: str, task: str):
        """Return ``(q_prefix, q_list, i_prefix, i_list)`` for one task.

        ``q_list`` and ``i_list`` are sorted file names; the prefixes are the
        directories they live in.
        """
        image_dir, question_dir = self.sub_dir_name
        question_root = os.path.join(data_dir, task, question_dir)
        if os.path.exists(question_root):
            # Split layout: questions and images in separate sub-directories.
            q_prefix = question_root
            i_prefix = os.path.join(data_dir, task, image_dir)
            q_list = os.listdir(q_prefix)
            i_list = os.listdir(i_prefix)
        else:
            # Flat layout: partition one directory by the ``.txt`` suffix.
            q_prefix = i_prefix = os.path.join(data_dir, task)
            fn_list = os.listdir(q_prefix)
            q_list = [fn for fn in fn_list if fn.endswith('.txt')]
            i_list = [fn for fn in fn_list if not fn.endswith('.txt')]
        # Sorting both lists lets questions and images be paired by position.
        q_list.sort()
        i_list.sort()
        return q_prefix, q_list, i_prefix, i_list

    @staticmethod
    def _read_qa_pairs(q_path: str) -> List[tuple]:
        """Read the first two ``question\\tanswer`` lines of ``q_path``.

        Raises:
            ValueError: If a line does not split into exactly two fields.
        """
        pairs = []
        # Explicit encoding so decoding does not depend on the platform's
        # locale default.
        with open(q_path, 'r', encoding='utf-8') as f:
            for _ in range(2):
                question, answer = f.readline().strip().split('\t')
                pairs.append((question, answer))
        return pairs

    def load_data(self, data_dir: str):
        """Populate ``self.data_list`` with one dict per (image, question).

        Each dict has the keys ``img_path``, ``question``, ``answer`` and
        ``task``.

        Raises:
            ValueError: If a task has differing numbers of question and image
                files, or a sorted question/image pair has different stems.
        """
        self.data_list = []
        for task in self.tasks:
            q_prefix, q_list, i_prefix, i_list = self._list_task_files(
                data_dir, task)
            if len(q_list) != len(i_list):
                raise ValueError(
                    f'Task {task!r}: found {len(q_list)} question files but '
                    f'{len(i_list)} image files.')
            for q_fn, i_fn in zip(q_list, i_list):
                # splitext keeps dots inside the stem, unlike split('.')[0].
                if os.path.splitext(q_fn)[0] != os.path.splitext(i_fn)[0]:
                    raise ValueError(
                        f'Task {task!r}: question file {q_fn!r} does not '
                        f'match image file {i_fn!r}.')
                image_path = os.path.join(i_prefix, i_fn)
                q_path = os.path.join(q_prefix, q_fn)
                for question, answer in self._read_qa_pairs(q_path):
                    self.data_list.append({
                        'img_path': image_path,
                        'question': question,
                        'answer': answer,
                        'task': task
                    })

    def __len__(self) -> int:
        return len(self.data_list)

    def __getitem__(self, idx: int) -> dict:
        """Return sample ``idx`` after running it through the pipeline."""
        # Give the pipeline a shallow copy (all stored values are str).
        # Anything the pipeline writes into its input dict therefore does not
        # leak back into self.data_list or persist between epochs.
        data_sample = dict(self.data_list[idx])
        return self.pipeline(data_sample)