Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
from typing import List | |
from torch.utils.data import Dataset | |
class MMapIndexDataset(Dataset):
    """Dataset that reads variable-length samples from memory-mapped files.

    For every prefix in ``datapaths`` and every name in ``input_tensor_name``
    two files are opened:

    * ``<data_path>_<tensor_name>.npy`` -- an index array loaded with
      ``mmap_mode='r'``; row ``k`` holds the start (column 0) and end
      (column 1) offsets of sample ``k`` inside the matching ``.bin`` file.
    * ``<data_path>_<tensor_name>.bin`` -- a flat array of ``'long'`` values
      opened read-only with ``np.memmap``.

    Samples of the individual data paths are concatenated in the order in
    which the paths are given.
    """

    def __init__(self, datapaths: List[str], input_tensor_name: List[str]):
        """Open all index/data files.

        Args:
            datapaths: Path prefixes of the memory-mapped files.
            input_tensor_name: Names of the stored tensors, e.g.
                ``['input_ids']``; each name selects its own pair of files.
        """
        dict_idx_fp = {}
        dict_bin_fp = {}
        idx_len = []
        for tensor_name in input_tensor_name:
            idx_fp = []
            bin_fp = []
            for data_path in datapaths:
                prefix = data_path + '_' + tensor_name
                idx_fp.append(np.load(prefix + '.npy', mmap_mode='r'))
                bin_fp.append(np.memmap(
                    prefix + '.bin',
                    dtype='long',
                    mode='r'))
            dict_idx_fp[tensor_name] = idx_fp
            dict_bin_fp[tensor_name] = bin_fp
            # Different tensors normally hold the same number of samples per
            # file, so one list of per-file counts is kept (rebuilt for each
            # tensor rather than appended to, which would duplicate entries).
            idx_len = [index.shape[0] for index in idx_fp]

        self._len = sum(idx_len)
        self._input_tensor_name = input_tensor_name
        self._dict_idx_fp = dict_idx_fp
        self._dict_bin_fp = dict_bin_fp
        self._idx_len = idx_len

    def __len__(self):
        """Return the total number of samples across all data paths."""
        return self._len

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict mapping tensor name to a tensor.

        Negative indices count from the end of the dataset.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        requested = idx
        if idx < 0:
            idx += self._len
        if not 0 <= idx < self._len:
            raise IndexError(
                f'index {requested} is out of range for dataset of '
                f'length {self._len}')

        # Walk the per-file counts to turn the global index into a
        # (file number, index inside that file) pair.
        file_no = 0
        for file_no, count in enumerate(self._idx_len):
            if idx < count:
                break
            idx -= count

        sample = {}
        for tensor_name in self._input_tensor_name:
            index = self._dict_idx_fp[tensor_name][file_no]
            data = self._dict_bin_fp[tensor_name][file_no]
            start, end = index[idx, 0], index[idx, 1]
            # torch.tensor copies the slice out of the read-only memory map.
            sample[tensor_name] = torch.tensor(data[start:end],
                                               dtype=torch.long)
        return sample