# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import os

import numpy as np

from fairseq import tokenizer, utils
from fairseq.data import ConcatDataset, Dictionary, data_utils, indexed_dataset
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
from fairseq.tasks import LegacyFairseqTask, register_task


logger = logging.getLogger(__name__)
@register_task("legacy_masked_lm")
class LegacyMaskedLMTask(LegacyFairseqTask):
    """
    Task for training a Masked LM (BERT) model.

    Args:
        dictionary (Dictionary): the dictionary for the input of the task
    """
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner",
        )
        parser.add_argument(
            "--tokens-per-sample",
            default=512,
            type=int,
            help="max number of total tokens over all segments"
            " per sample for BERT dataset",
        )
        parser.add_argument(
            "--break-mode", default="doc", type=str, help="mode for breaking sentence"
        )
        parser.add_argument("--shuffle-dataset", action="store_true", default=False)

    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

    @classmethod
    def load_dictionary(cls, filename):
        return BertDictionary.load(filename)

    @classmethod
    def build_dictionary(
        cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
    ):
        d = BertDictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, workers
            )
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d
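
    # Hypothetical usage sketch (not part of the original module): building and saving
    # a BertDictionary from raw text files before binarization. The file names below
    # are placeholders.
    #
    #   d = LegacyMaskedLMTask.build_dictionary(["corpus.txt"], workers=4)
    #   d.save("dict.txt")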

    @property
    def target_dictionary(self):
        return self.dictionary

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task."""
        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        dictionary = BertDictionary.load(os.path.join(paths[0], "dict.txt"))
        logger.info("dictionary: {} types".format(len(dictionary)))
        return cls(args, dictionary)
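
    # Illustrative note (assumption): `args.data` is a colon-separated list of shard
    # directories, and setup_task loads `dict.txt` from the first one, e.g.
    #
    #   args.data = "/data/bert-shard0:/data/bert-shard1"
    #   task = LegacyMaskedLMTask.setup_task(args)  # reads /data/bert-shard0/dict.txt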

    def load_dataset(self, split, epoch=1, combine=False):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        logger.info("data_path %s", data_path)

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else "")
            path = os.path.join(data_path, split_k)
            ds = indexed_dataset.make_dataset(
                path,
                impl=self.args.dataset_impl,
                fix_lua_indexing=True,
                dictionary=self.dictionary,
            )

            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError(
                        "Dataset not found: {} ({})".format(split, data_path)
                    )

            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                        doc_break_size=1,
                    )
                )

            logger.info(
                "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1]))
            )

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=self.args.shuffle_dataset,
            seed=self.seed,
        )
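
# Minimal end-to-end sketch (an assumption, not part of the original source): driving
# the task without the fairseq CLI. `dataset_impl` is normally injected by fairseq's
# global argument parser, so it is set by hand here; paths are placeholders.
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   LegacyMaskedLMTask.add_args(parser)
#   args = parser.parse_args(["/path/to/data"])
#   args.seed = 1
#   args.dataset_impl = "mmap"
#   task = LegacyMaskedLMTask.setup_task(args)
#   task.load_dataset("train")
#   print(len(task.datasets["train"]))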