# NOTE(review): removed paste artifacts ("Spaces:" / "Runtime error" x2) that
# preceded this module — they were judge/log output, not valid Python.
# dataset.py
import ast
import re
from typing import Dict, List

import numpy as np
import torch
from torch.utils.data import Dataset

from utils import Vocab
class SeqClsDataset(Dataset):
    """Dataset of sequence-classification examples.

    Stores the raw example dicts together with the vocabulary and the
    label <-> integer-index mappings needed to encode them later.
    """

    def __init__(
        self,
        data: List[Dict],
        vocab: Vocab,
        label_mapping: Dict[str, int],
        max_len: int,
    ):
        self.data = data
        self.vocab = vocab
        self.label_mapping = label_mapping
        # Inverse mapping (index -> label name), built once up front.
        self._idx2label = {index: name for name, index in label_mapping.items()}
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index) -> Dict:
        # Return the raw example dict; any encoding happens in collate_fn.
        return self.data[index]

    def num_classes(self) -> int:
        return len(self.label_mapping)

    def label2idx(self, label: str):
        # Label name -> integer id.
        return self.label_mapping[label]

    def idx2label(self, idx: int):
        # Integer id -> label name.
        return self._idx2label[idx]
class SeqTaggingClsDataset(SeqClsDataset):
    """Dataset for sequence tagging: one label per token."""

    def collate_fn(self, samples: List[Dict]) -> Dict:
        """Collate raw samples into padded id tensors for a DataLoader.

        Args:
            samples: list of example dicts, each with keys "tokens" and
                "tags". Either field may be a list, or the string repr of
                a list (e.g. when read back from a CSV) — string fields
                are parsed with ast.literal_eval.

        Returns:
            dict with
              "encoded_tokens": LongTensor (batch, max_len), PAD-id filled
              "encoded_tags":   LongTensor (batch, max_len), -1 on padding
                                (so a loss with ignore_index=-1 skips it)
        """
        # BUG FIX: `samples` is a list of dicts, so the batch size is
        # len(samples); the original indexed the list with a string key.
        batch_size = len(samples)
        pad_id = self.vocab.token_to_id("[PAD]")
        batch_data = pad_id * np.ones((batch_size, self.max_len))
        # -1 marks padded positions so the loss can ignore them.
        batch_labels = -1 * np.ones((batch_size, self.max_len))

        for j, sample in enumerate(samples):
            tokens = sample["tokens"]
            tags = sample["tags"]
            # BUG FIX: ast.literal_eval replaces the original eval(),
            # which would execute arbitrary code from the data file.
            if isinstance(tokens, str):
                tokens = ast.literal_eval(tokens)
            if isinstance(tags, str):
                tags = ast.literal_eval(tags)
            # Truncate to max_len so the slice assignments below cannot fail.
            cur_len = min(len(tokens), self.max_len)
            # BUG FIX: map the sample's real tags to ids; the original
            # overwrote every position with the "O" label id.
            tag_ids = [
                self.label_mapping[t] if isinstance(t, str) else t
                for t in tags[:cur_len]
            ]
            batch_data[j, :cur_len] = self.vocab.encode(tokens[:cur_len])
            batch_labels[j, :cur_len] = tag_ids

        # Convert the integer index arrays to PyTorch tensors.
        return {
            "encoded_tokens": torch.LongTensor(batch_data),
            "encoded_tags": torch.LongTensor(batch_labels),
        }