File size: 3,907 Bytes
2cc518e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Datasets for converting to MDS Shards."""
import os
import warnings
from typing import Dict, Iterable, Union
import datasets as hf_datasets
import numpy as np
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

class NoConcatDataset(IterableDataset):
    """Iterable wrapper that emits raw text samples in the form MDSWriter consumes.

    Every item of the wrapped dataset is converted to a dict of the form
    ``{'text': bytes}``, where the bytes are the UTF-8 encoding of the item's
    ``'text'`` field.
    """

    def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset]):
        """Store the dataset to iterate over.

        Args:
            hf_dataset: Iterable of samples; each sample is indexed with
                ``'text'`` during iteration.
        """
        self.hf_dataset = hf_dataset

    def __iter__(self) -> Iterable[Dict[str, bytes]]:
        """Yield one ``{'text': bytes}`` dict per sample of the wrapped dataset."""
        yield from (
            {'text': record['text'].encode('utf-8')}
            for record in self.hf_dataset
        )

class ConcatTokensDataset(IterableDataset):
    """An IterableDataset that returns token samples for MDSWriter.

    Returns dicts of {'tokens': bytes}, where the bytes are the raw buffer of
    an ``np.int64`` array holding exactly ``max_length`` token ids.

    The tokens of every text sample are surrounded by the tokenized
    ``bos_text`` / ``eos_text`` and appended to a running buffer. Whenever the
    buffer holds at least ``max_length`` tokens, fixed-size samples are cut
    from its front. Tokens still in the buffer when the wrapped dataset is
    exhausted do not fill a whole sample and are never emitted.

    To use data created by this class and written to MDS format:

    ```python
        import numpy as np
        import torch
        from streaming.base import StreamingDataset
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
        ds = StreamingDataset(local='mds-data-folder', split='val')

        # note, you need to copy the numpy array because the original is non-writeable
        # and torch does not support non-writeable tensors, so you get a scary warning and
        # if you do try to write to the tensor you get undefined behavior
        tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
        print(tokenizer.decode(tokens))
    ```
    """

    def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset], tokenizer: PreTrainedTokenizerBase, max_length: int, bos_text: str, eos_text: str, no_wrap: bool):
        """Configure the dataset and pre-tokenize the BOS/EOS text.

        Args:
            hf_dataset: Iterable of samples; each sample is indexed with
                ``'text'`` during iteration.
            tokenizer: Callable applied to text; its result is indexed with
                ``'input_ids'``.
            max_length: Number of tokens in every emitted sample. Must be >= 1.
            bos_text: Text whose tokens are placed before each sample's tokens.
            eos_text: Text whose tokens are placed after each sample's tokens.
            no_wrap: If True, tokens left over after a sample has been cut
                from the buffer are discarded instead of being carried into
                the next sample.

        Raises:
            ValueError: If ``max_length`` is smaller than 1 (iteration would
                otherwise never terminate).
        """
        if max_length < 1:
            raise ValueError(f'max_length must be a positive integer, got {max_length}.')

        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        # Set process-wide so the tokenizers library does not use its own parallelism.
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
        self.max_length = max_length
        self.bos_text = bos_text
        self.eos_text = eos_text
        self.should_wrap = not no_wrap

        self.bos_tokens = self._tokenize_marker(self.bos_text, 'bos')
        self.eos_tokens = self._tokenize_marker(self.eos_text, 'eos')

        eos_text_provided = self.eos_text != ''
        bos_text_provided = self.bos_text != ''
        # If tokenizing the empty string yields tokens, the tokenizer inserts
        # special tokens by itself and explicit BOS/EOS text may duplicate them.
        test_text = self.tokenizer('')
        if len(test_text['input_ids']) > 0 and (eos_text_provided or bos_text_provided):
            if eos_text_provided and bos_text_provided:
                message = 'both eos and bos'
            elif eos_text_provided:
                message = 'eos_text'
            else:
                message = 'bos_text'
            warnings.warn(f'The provided tokenizer adds special tokens, but you also specified {message}. This may result ' + 'in duplicated special tokens. Please be sure this is what you intend.')

    def _tokenize_marker(self, text: str, kind: str):
        """Tokenize BOS/EOS ``text`` without special tokens; warn if it is more than one token."""
        tokens = self.tokenizer(text, truncation=False, padding=False, add_special_tokens=False)['input_ids']
        if len(tokens) > 1:
            warnings.warn(f'You specified --concat_tokens with --{kind}_text, but your {kind.upper()} text is not tokenizing to one token, instead we got {tokens}. Quit if this was in error.')
        return tokens

    def __iter__(self) -> Iterable[Dict[str, bytes]]:
        """Yield ``{'tokens': bytes}`` samples of exactly ``max_length`` int64 token ids."""
        buffer = []
        for sample in self.hf_dataset:
            encoded = self.tokenizer(sample['text'], truncation=False, padding=False)
            buffer.extend(self.bos_tokens)
            buffer.extend(encoded['input_ids'])
            buffer.extend(self.eos_tokens)

            if len(buffer) < self.max_length:
                continue

            if self.should_wrap:
                # Emit every complete window in one pass and keep only the
                # remainder, instead of re-slicing the buffer per window.
                full_end = len(buffer) - len(buffer) % self.max_length
            else:
                # Without wrapping only the first window is used; the rest of
                # the buffer is thrown away.
                full_end = self.max_length

            for start in range(0, full_end, self.max_length):
                concat_sample = buffer[start:start + self.max_length]
                # Fix the dtype so the byte layout does not depend on the
                # platform's default integer width and matches the int64
                # decoding shown in the class docstring.
                yield {'tokens': np.asarray(concat_sample, dtype=np.int64).tobytes()}

            buffer = buffer[full_end:] if self.should_wrap else []