# NOTE(review): removed non-Python scrape artifacts that preceded this file
# (page header, "Runtime error" labels, file-size line, commit hash, and a
# line-number gutter) — they made the module unparseable.
# dataset.py
import ast
import re
from typing import Dict, List

import numpy as np
import torch
from torch.utils.data import Dataset

from utils import Vocab
class SeqClsDataset(Dataset):
    """Sequence-classification dataset.

    Stores raw instances together with the token vocabulary and the
    intent-label mapping; instances are returned untouched, so any
    encoding/padding is deferred to a collate function.
    """

    def __init__(
        self,
        data: List[Dict],
        vocab: Vocab,
        label_mapping: Dict[str, int],
        max_len: int,
    ):
        self.data = data
        self.vocab = vocab
        self.label_mapping = label_mapping
        # Inverse table for decoding predicted indices back to label names.
        self._idx2label = {i: name for name, i in label_mapping.items()}
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index) -> Dict:
        # Raw instance pass-through; no tensor conversion here.
        return self.data[index]

    @property
    def num_classes(self) -> int:
        return len(self.label_mapping)

    def label2idx(self, label: str):
        return self.label_mapping[label]

    def idx2label(self, idx: int):
        return self._idx2label[idx]
class SeqTaggingClsDataset(SeqClsDataset):
    """Slot-tagging variant: collates per-token tag labels alongside tokens."""

    def collate_fn(self, samples: List[Dict]) -> Dict:
        """Pad a batch of tagging samples to ``self.max_len``.

        Args:
            samples: list of instances, each with a "tokens" sequence (either
                a list of strings or the string repr of such a list) and,
                for labeled splits, a "tags" list of tag names (list[str]).

        Returns:
            Dict with:
              "encoded_tokens": LongTensor (batch, max_len), padded with the
                  [PAD] token id.
              "encoded_tags":   LongTensor (batch, max_len), padded with -1
                  (the ignore index for the loss).
        """
        # BUG FIX: `samples` is a List[Dict], so the batch size is its length;
        # indexing it with string keys (samples["tokens"]) raised TypeError.
        batch_size = len(samples)
        pad_id = self.vocab.token_to_id("[PAD]")
        batch_data = pad_id * np.ones((batch_size, self.max_len))
        batch_labels = -1 * np.ones((batch_size, self.max_len))

        for j, sample in enumerate(samples):
            tokens = sample["tokens"]
            # Some exports store the token list as its string repr; parse it
            # with literal_eval — never eval() on externally supplied data.
            if isinstance(tokens, str):
                tokens = ast.literal_eval(tokens)
            # Truncate so the slice assignments below cannot overflow max_len.
            cur_len = min(len(tokens), self.max_len)
            tokens = tokens[:cur_len]

            tags = sample.get("tags")
            if tags is None:
                # Unlabeled (test) split: fall back to the "O" tag everywhere.
                tag_ids = [self.label_mapping["O"]] * cur_len
            else:
                # BUG FIX: the original overwrote the real tags with "O" for
                # every token; map tag names through label_mapping instead.
                tag_ids = [
                    self.label_mapping[t] if isinstance(t, str) else t
                    for t in tags[:cur_len]
                ]

            batch_data[j, :cur_len] = self.vocab.encode(tokens)
            batch_labels[j, :cur_len] = tag_ids

        # Convert the padded index matrices to PyTorch tensors.
        return {
            "encoded_tokens": torch.LongTensor(batch_data),
            "encoded_tags": torch.LongTensor(batch_labels),
        }