import bisect
from typing import List

import numpy as np
import torch
from torch.utils.data import Dataset


class MMapIndexDataset(Dataset):
    """Map-style dataset over memory-mapped, variable-length sample files.

    For every prefix ``data_path`` in ``datapaths`` and every ``tensor_name``
    in ``input_tensor_name`` two files are opened read-only:

    * ``{data_path}_{tensor_name}.npy`` -- index array loaded with
      ``mmap_mode='r'``; row ``k`` holds the ``start`` (column 0) and ``end``
      (column 1) offsets of sample ``k`` inside the matching ``.bin`` file.
    * ``{data_path}_{tensor_name}.bin`` -- flat data memory-mapped with numpy
      dtype ``'long'``.

    The samples of all files are concatenated in ``datapaths`` order, so a
    global index is resolved to a (file, local row) pair.

    Args:
        datapaths: Path prefixes of all memory-mapped files.
        input_tensor_name: Names of the tensors to load, e.g. ``['input_ids']``;
            each name becomes a key of the dict returned by ``__getitem__``.
    """

    def __init__(self, datapaths: List[str], input_tensor_name: List[str]):
        dict_idx_fp = {}
        dict_bin_fp = {}
        for tensor_name in input_tensor_name:
            idx_fp = []
            bin_fp = []
            for data_path in datapaths:
                # f-strings also accept pathlib.Path prefixes.
                idx_fp.append(np.load(
                    f'{data_path}_{tensor_name}.npy', mmap_mode='r'))
                bin_fp.append(np.memmap(
                    f'{data_path}_{tensor_name}.bin', dtype='long', mode='r'))
            dict_idx_fp[tensor_name] = idx_fp
            dict_bin_fp[tensor_name] = bin_fp

        # Number of samples per file. Different tensors normally have the
        # same per-file counts; the first tensor name defines the indexing.
        if input_tensor_name:
            idx_len = [int(fp.shape[0])
                       for fp in dict_idx_fp[input_tensor_name[0]]]
        else:
            idx_len = []

        # Global index at which each file's samples begin.
        file_starts = []
        total = 0
        for n_samples in idx_len:
            file_starts.append(total)
            total += n_samples

        self._len = total
        self._input_tensor_name = input_tensor_name
        self._dict_idx_fp = dict_idx_fp
        self._dict_bin_fp = dict_bin_fp
        self._idx_len = idx_len
        self._file_starts = file_starts

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        """Return ``{tensor_name: torch.long tensor}`` for global sample ``idx``.

        Negative indices count from the end of the dataset.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        requested = idx
        if idx < 0:
            idx += self._len
        if not 0 <= idx < self._len:
            raise IndexError(
                f'index {requested} out of range for dataset of size {self._len}')

        # Last file whose start offset is <= idx. An empty file shares its
        # start with the next file, so bisect_right skips past it.
        file_no = bisect.bisect_right(self._file_starts, idx) - 1
        local = idx - self._file_starts[file_no]

        sample = {}
        for tensor_name in self._input_tensor_name:
            bounds = self._dict_idx_fp[tensor_name][file_no]
            data = self._dict_bin_fp[tensor_name][file_no]
            start, end = bounds[local, 0], bounds[local, 1]
            sample[tensor_name] = torch.tensor(data[start:end], dtype=torch.long)
        return sample