import numpy as np
import torch
from typing import List
from torch.utils.data import Dataset


class MMapIndexDataset(Dataset):
    # datapaths: path prefixes of all the memory-mapped files
    # input_tensor_name: names of the input tensors, e.g. ['input_ids'];
    # each tensor is stored in its own pair of files
    def __init__(self, datapaths: List[str], input_tensor_name: List[str]):
        dict_idx_fp = {}
        dict_bin_fp = {}
        idx_len = []
        for tensor_name in input_tensor_name:
            idx_fp = []
            bin_fp = []
            total_len = 0
            for data_path in datapaths:
                # <prefix>_<tensor>.npy holds an (N, 2) array of [start, end)
                # offsets into the flat token stream in <prefix>_<tensor>.bin
                idx_fp += [np.load(
                    data_path + '_' + tensor_name + '.npy', mmap_mode='r')]
                bin_fp += [np.memmap(
                    data_path + '_' + tensor_name + '.bin',
                    dtype='long',
                    mode='r')]
                total_len += idx_fp[-1].shape[0]
                idx_len += [idx_fp[-1].shape[0]]
            dict_idx_fp[tensor_name] = idx_fp
            dict_bin_fp[tensor_name] = bin_fp
            # The different tensors normally have the same length,
            # so keeping the count from any one of them is enough.
            self._len = total_len

        self._input_tensor_name = input_tensor_name
        self._dict_idx_fp = dict_idx_fp
        self._dict_bin_fp = dict_bin_fp
        self._idx_len = idx_len

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        sample = {}
        # Map the global sample index onto (file i, local index idx)
        # by walking through the per-file sample counts.
        for i in range(len(self._idx_len)):
            if idx >= self._idx_len[i]:
                idx -= self._idx_len[i]
            else:
                break
        for tensor_name in self._input_tensor_name:
            # The index file stores [start, end) offsets into the flat .bin stream.
            sample[tensor_name] = torch.tensor(self._dict_bin_fp[tensor_name][i][
                self._dict_idx_fp[tensor_name][i][idx, 0]:
                self._dict_idx_fp[tensor_name][i][idx, 1]
            ], dtype=torch.long)
        return sample
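

# The sketch below is not part of the original file: it shows one way the
# companion '<prefix>_<tensor>.npy' / '<prefix>_<tensor>.bin' pairs read by
# MMapIndexDataset could be produced. The helper name `build_mmap_files` and
# the exact dtypes are assumptions; only the (N, 2) [start, end) offset layout
# is inferred from __getitem__ above.
def build_mmap_files(prefix: str, tensor_name: str, samples: List[List[int]]):
    """Write samples as a flat int64 .bin stream plus an (N, 2) offset index."""
    flat = []
    offsets = np.zeros((len(samples), 2), dtype=np.int64)
    start = 0
    for i, tokens in enumerate(samples):
        offsets[i] = (start, start + len(tokens))
        flat.extend(tokens)
        start += len(tokens)
    # Flat token stream; int64 matches the torch.long tensors built in __getitem__.
    np.array(flat, dtype=np.int64).tofile(prefix + '_' + tensor_name + '.bin')
    # Per-sample [start, end) offsets into the stream.
    np.save(prefix + '_' + tensor_name + '.npy', offsets)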
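

# A minimal, hedged usage example (the file prefix '/tmp/demo' and the batch
# settings are illustrative): build one small file pair with the helper above,
# then iterate over the dataset with a DataLoader. Samples are variable length,
# so batch_size=1 avoids needing a padding collate_fn.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    build_mmap_files('/tmp/demo', 'input_ids', [[1, 2, 3], [4, 5], [6, 7, 8, 9]])
    dataset = MMapIndexDataset(datapaths=['/tmp/demo'],
                               input_tensor_name=['input_ids'])
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    for batch in loader:
        print(batch['input_ids'].shape)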