# event_detection_app/event_detection_dataset.py
import numpy as np
import torch
from torch.utils.data import Dataset
from collections import defaultdict


class event_detection_data(Dataset):
    """Serves classification examples, or whole-word-masked examples for
    domain-adaptive MLM pretraining when domain_adaption is enabled."""

    def __init__(self, raw_data, tokenizer, max_len, domain_adaption=False, wwm_prob=0.1):
        self.len = len(raw_data)
        self.data = raw_data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.domain_adaption = domain_adaption
        self.wwm_prob = wwm_prob
    def __getitem__(self, index):
        # tokenize the pre-split text of one example
        tokenized_inputs = self.tokenizer(
            self.data[index]["text"],
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            is_split_into_words=True
        )
        ids = tokenized_inputs['input_ids']
        mask = tokenized_inputs['attention_mask']

        if self.domain_adaption:
            # whole-word masking relies on word_ids(), which only fast tokenizers provide
            if not self.tokenizer.is_fast:
                raise ValueError("domain adaption requires a fast tokenizer for word_ids")
            input_ids, labels = self._whole_word_masking(self.tokenizer, tokenized_inputs, self.wwm_prob)
            return {
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(mask),
                'labels': torch.tensor(labels, dtype=torch.long)
            }
        else:
            return {
                'input_ids': torch.tensor(ids),
                'attention_mask': torch.tensor(mask),
                'targets': torch.tensor(self.data[index]["text_tag_id"][0], dtype=torch.long)
            }
    def __len__(self):
        return self.len
    def _whole_word_masking(self, tokenizer, tokenized_inputs, wwm_prob):
        word_ids = tokenized_inputs.word_ids(0)

        # map each word index to the token positions that belong to it
        mapping = defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # randomly select whole words to mask
        mask = np.random.binomial(1, wwm_prob, (len(mapping),))
        input_ids = tokenized_inputs["input_ids"]
        # labels keep the original token ids only at masked positions; -100 is ignored by the loss
        labels = [-100] * len(input_ids)
        for word_id in np.where(mask == 1)[0]:
            for idx in mapping[word_id]:
                # record the original token as the label, then replace it with [MASK]
                labels[idx] = input_ids[idx]
                input_ids[idx] = tokenizer.mask_token_id
        return input_ids, labels
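

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original file). It assumes a fast
# Hugging Face tokenizer such as "bert-base-uncased" and raw_data entries with
# a pre-split "text" field and a "text_tag_id" label list; the example data
# below is illustrative only.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    raw_data = [{"text": ["markets", "rallied", "after", "the", "announcement"],
                 "text_tag_id": [3]}]

    # classification mode: yields input_ids / attention_mask / targets
    dataset = event_detection_data(raw_data, tokenizer, max_len=32)
    print(next(iter(DataLoader(dataset, batch_size=1))))

    # domain-adaption mode: yields whole-word-masked input_ids and MLM labels
    mlm_dataset = event_detection_data(raw_data, tokenizer, max_len=32,
                                        domain_adaption=True, wwm_prob=0.15)
    print(mlm_dataset[0]["labels"])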