# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import Dictionary, FairseqDataset, data_utils


def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )
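
    # Note: merge relies on data_utils.collate_tokens, which pads a list of
    # 1-D tensors into a single [batch, max_len] tensor (e.g. lengths 3 and 5
    # become a 2 x 5 batch), filling shorter rows with pad_idx, on the left
    # when left_pad is True.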

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": id,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        # true (unpadded) target lengths, reordered to match the sorted batch
        "target_lengths": torch.LongTensor(
            [s["target"].numel() for s in samples]
        ).index_select(0, sort_order)
        if target is not None
        else None,
        "nsentences": len(samples),
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens
    return batch
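

# A rough sketch of the collated batch for two toy samples (hypothetical ids;
# assumes the fairseq Dictionary defaults pad=1, eos=2 and right padding):
#   samples = [{"id": 0, "source": [4, 5, 2],    "target": [6, 2]},
#              {"id": 1, "source": [7, 8, 9, 2], "target": [10, 11, 2]}]
#   batch["net_input"]["src_tokens"] -> [[7, 8, 9, 2],
#                                        [4, 5, 2, 1]]   # sorted by length, padded
#   batch["net_input"]["prev_output_tokens"] -> padded targets with eos moved to the front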


class SpanMaskedTokensDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset that produces T5-style span-masked samples.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to wrap
        vocab (~fairseq.data.Dictionary): vocabulary
        noise_density (float): fraction of the tokens to select as noise.
        mean_noise_span_length (float): mean noise span length.
        shuffle (bool, optional): shuffle the elements before batching.
            Default: ``True``
        seed: Seed for random number generator for reproducibility.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        noise_density: float,
        mean_noise_span_length: float,
        shuffle: bool,
        seed: int = 1,
    ):
        self.dataset = dataset
        self.vocab = vocab
        self.seed = seed
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        self.shuffle = shuffle
        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch

    def __getitem__(self, index):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            assert item[-1] == self.vocab.eos()

            noise_mask = self.random_spans_noise_mask(len(item))

            source_sentinel_ids = self.create_sentinel_ids(noise_mask.astype(np.int8))
            source = self.filter_input_ids(item, source_sentinel_ids)

            target_sentinel_ids = self.create_sentinel_ids(
                (~noise_mask).astype(np.int8)
            )
            target = self.filter_input_ids(item, target_sentinel_ids)

        return {
            "id": index,
            "source": torch.from_numpy(source),
            "target": torch.from_numpy(target),
        }

    def random_spans_noise_mask(self, length):
        """
        This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

        Noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
            num_noise_tokens = round(length * noise_density)
            num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.

        Args:
            length: an int32 scalar (length of the incoming token sequence)
        Returns:
            a boolean tensor with shape [length]
        """
        orig_length = length

        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))

        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """
            Partition a sequence of items randomly into non-empty segments.

            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count length of subsegments assuming that list is sorted
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(
            num_nonnoise_tokens, num_noise_spans
        )

        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
            [num_noise_spans * 2],
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        is_noise = np.equal(span_num % 2, 1)
        return is_noise[:orig_length]

    def create_sentinel_ids(self, mask_indices):
        """
        Create sentinel ids given the indices that should be masked.
        The start index of each masked span is replaced by a sentinel id (one per
        span, counting down from ``len(vocab) - 1``); the remaining, consecutive
        masked indices are set to `-1` so they can be deleted later.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices

        sentinel_ids = np.where(
            start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
        )
        # making sure all sentinel tokens are unique over the example
        sentinel_ids = np.where(sentinel_ids != 0, len(self.vocab) - sentinel_ids, 0)
        sentinel_ids -= mask_indices - start_indices
        return sentinel_ids

    @staticmethod
    def filter_input_ids(input_ids, sentinel_ids):
        """
        Put the sentinel mask on `input_ids` and fuse each run of consecutive mask
        tokens into a single sentinel token by deleting the rest.
        This shortens the sequence from its expanded (unmasked) length to the final
        input length.
        """
        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
        # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        return input_ids_full[input_ids_full >= 0]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples, pad_to_length=None):
        """
        Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples,
            self.vocab.pad(),
            self.vocab.eos(),
            self.vocab,
            pad_to_length=pad_to_length,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.dataset.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.dataset.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
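        # The stable mergesort below keeps the random permutation as the
        # tie-breaker among equally sized examples, so same-length batches
        # still vary when shuffle is enabled.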
        return indices[np.argsort(self.dataset.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        self.dataset.prefetch(indices)

    @property
    def supports_prefetch(self):
        return (
            hasattr(self.dataset, "supports_prefetch")
            and self.dataset.supports_prefetch
        )
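

# A hypothetical construction sketch (the wrapped dataset, dictionary path, and
# parameter values below are illustrative assumptions, not taken from this file):
#
#   vocab = Dictionary.load("dict.txt")
#   dataset = SpanMaskedTokensDataset(
#       token_block_dataset,   # e.g. a TokenBlockDataset whose items end with eos
#       vocab,
#       noise_density=0.15,
#       mean_noise_span_length=3.0,
#       shuffle=True,
#       seed=1,
#   )
#   batch = dataset.collater([dataset[i] for i in range(8)])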