import numpy as np
import torch
from torch.utils.data import Dataset
from collections import defaultdict


class event_detection_data(Dataset):
    """Dataset for event detection.

    Two modes:
      * classification (default): returns input_ids, attention_mask and the
        sentence-level target taken from ``text_tag_id``.
      * domain adaption (``domain_adaption=True``): returns whole-word-masked
        input_ids and MLM labels for continued pre-training.
    """

    def __init__(self, raw_data, tokenizer, max_len, domain_adaption=False, wwm_prob=0.1):
        self.len = len(raw_data)
        self.data = raw_data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.domain_adaption = domain_adaption
        self.wwm_prob = wwm_prob

    def __getitem__(self, index):
        tokenized_inputs = self.tokenizer(
            self.data[index]["text"],
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            is_split_into_words=True
        )

        ids = tokenized_inputs['input_ids']
        mask = tokenized_inputs['attention_mask']

        if self.domain_adaption:
            # Whole-word masking needs word_ids(), which only fast tokenizers expose.
            if not self.tokenizer.is_fast:
                raise ValueError("domain_adaption=True requires a fast tokenizer (word_ids is unavailable otherwise)")
            input_ids, labels = self._whole_word_masking(self.tokenizer, tokenized_inputs, self.wwm_prob)
            return {
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(mask),
                'labels': torch.tensor(labels, dtype=torch.long)
            }
        else:
            return {
                'input_ids': torch.tensor(ids),
                'attention_mask': torch.tensor(mask),
                'targets': torch.tensor(self.data[index]["text_tag_id"][0], dtype=torch.long)
            }

    def __len__(self):
        return self.len

    def _whole_word_masking(self, tokenizer, tokenized_inputs, wwm_prob):
        """Mask all sub-word tokens of randomly chosen words with probability wwm_prob."""
        word_ids = tokenized_inputs.word_ids(0)

        # map each whole word (in order of appearance) to the indices of its sub-word tokens
        mapping = defaultdict(list)
        current_word_index = -1
        current_word = None

        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # randomly select whole words to mask
        mask = np.random.binomial(1, wwm_prob, (len(mapping),))
        input_ids = tokenized_inputs["input_ids"]

        # labels keep the original token id only at masked positions; -100 is ignored by the loss
        labels = [-100] * len(input_ids)

        # mask every sub-word token of each selected word and record its original id as the label
        for word_id in np.where(mask == 1)[0]:
            for idx in mapping[word_id]:
                labels[idx] = input_ids[idx]
                input_ids[idx] = tokenizer.mask_token_id
        return input_ids, labels
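

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration): it assumes a Hugging Face fast
# tokenizer and raw_data records shaped like
# {"text": [token, ...], "text_tag_id": [label, ...]}, which matches the access
# pattern above. The checkpoint name and the sample record are placeholders,
# not part of this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer
    from torch.utils.data import DataLoader

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
    sample = [{"text": ["An", "earthquake", "struck", "the", "city"],
               "text_tag_id": [3]}]

    # Classification mode: items carry input_ids / attention_mask / targets.
    clf_ds = event_detection_data(sample, tokenizer, max_len=128)

    # Domain-adaption mode: items carry whole-word-masked input_ids and MLM labels.
    mlm_ds = event_detection_data(sample, tokenizer, max_len=128,
                                  domain_adaption=True, wwm_prob=0.15)

    loader = DataLoader(clf_ds, batch_size=8, shuffle=True)
    print(next(iter(loader))["input_ids"].shape)  # (1, 128): one padded sample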