# dataset.py
import ast
from typing import Dict, List

import numpy as np
import torch
from torch.utils.data import Dataset

from utils import Vocab
class SeqClsDataset(Dataset):
    def __init__(
        self,
        data: List[Dict],
        vocab: Vocab,
        label_mapping: Dict[str, int],
        max_len: int,
    ):
        self.data = data
        self.vocab = vocab
        self.label_mapping = label_mapping
        self._idx2label = {idx: intent for intent, idx in self.label_mapping.items()}
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index) -> Dict:
        instance = self.data[index]
        return instance

    @property
    def num_classes(self) -> int:
        return len(self.label_mapping)

    def label2idx(self, label: str):
        return self.label_mapping[label]

    def idx2label(self, idx: int):
        return self._idx2label[idx]
class SeqTaggingClsDataset(SeqClsDataset):
    def collate_fn(self, samples: List[Dict]) -> Dict:
        # `samples` is a list of instances, each a dict with "tokens" and
        # (for train/eval splits) "tags".
        batch_size = len(samples)
        tokens = [sample["tokens"] for sample in samples]
        tags = [sample.get("tags") for sample in samples]

        # Pre-fill token ids with [PAD] and label ids with -1 (ignored positions).
        pad_id = self.vocab.token_to_id("[PAD]")
        batch_data = pad_id * np.ones((batch_size, self.max_len))
        batch_labels = -1 * np.ones((batch_size, self.max_len))
        # Copy each sequence (truncated to max_len) into the numpy arrays.
        for j in range(batch_size):
            # Tokens/tags may arrive as string literals (e.g. "['a', 'b']") when
            # the instances were loaded from a CSV dump; parse them if needed.
            if isinstance(tokens[j], str):
                tokens[j] = ast.literal_eval(tokens[j])
            if isinstance(tags[j], str):
                tags[j] = ast.literal_eval(tags[j])
            cur_len = min(len(tokens[j]), self.max_len)
            # Map tag names to ids; fall back to "O" when no gold tags are given.
            tag_ids = ([self.label_mapping[t] for t in tags[j][:cur_len]]
                       if tags[j] else [self.label_mapping["O"]] * cur_len)
            batch_data[j][:cur_len] = self.vocab.encode(tokens[j][:cur_len])
            batch_labels[j][:cur_len] = tag_ids
        # Convert the index matrices to PyTorch tensors and pack them into a batch dict.
        return {
            "encoded_tokens": torch.LongTensor(batch_data),
            "encoded_tags": torch.LongTensor(batch_labels),
        }
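

# Usage sketch (illustrative, not part of the original module): wires the dataset
# into a DataLoader via collate_fn. The cache/data paths, max_len, and batch size
# below are assumptions, not values taken from this repository.
if __name__ == "__main__":
    import json
    import pickle
    from torch.utils.data import DataLoader

    with open("cache/slot/vocab.pkl", "rb") as f:   # assumed cache path
        vocab: Vocab = pickle.load(f)
    with open("cache/slot/tag2idx.json") as f:      # assumed cache path
        tag2idx: Dict[str, int] = json.load(f)
    with open("data/slot/train.json") as f:         # assumed data path
        data: List[Dict] = json.load(f)

    dataset = SeqTaggingClsDataset(data, vocab, tag2idx, max_len=32)
    loader = DataLoader(dataset, batch_size=8, shuffle=True,
                        collate_fn=dataset.collate_fn)

    batch = next(iter(loader))
    # Both tensors have shape (batch_size, max_len).
    print(batch["encoded_tokens"].shape, batch["encoded_tags"].shape)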