# dataset.py

import ast
from typing import Dict, List

import numpy as np
import torch
from torch.utils.data import Dataset

from utils import Vocab

class SeqClsDataset(Dataset):
    def __init__(
        self,
        data: List[Dict],
        vocab: Vocab,
        label_mapping: Dict[str, int],
        max_len: int,
    ):
        self.data = data
        self.vocab = vocab
        self.label_mapping = label_mapping
        # Reverse mapping so predicted indices can be decoded back to labels.
        self._idx2label = {idx: intent for intent, idx in self.label_mapping.items()}
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Dict:
        return self.data[index]

    @property
    def num_classes(self) -> int:
        return len(self.label_mapping)

    def label2idx(self, label: str) -> int:
        return self.label_mapping[label]

    def idx2label(self, idx: int) -> str:
        return self._idx2label[idx]


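# A minimal sketch of how the base dataset is typically consumed. The field
# names "text" and "intent" are assumptions (they are not defined in this
# file); only the accessors used below are guaranteed by the class itself:
#
#   dataset = SeqClsDataset(data, vocab, intent2idx, max_len=128)
#   instance = dataset[0]                              # raw dict instance
#   label_id = dataset.label2idx(instance["intent"])
#   assert dataset.idx2label(label_id) == instance["intent"]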
class SeqTaggingClsDataset(SeqClsDataset):
    def collate_fn(self, samples: List[Dict]) -> Dict:
        batch_size = len(samples)

        tokens = [sample["tokens"] for sample in samples]
        tags = [sample["tags"] for sample in samples]  # List[List[str]]

        pad_id = self.vocab.token_to_id("[PAD]")
        batch_data = np.full((batch_size, self.max_len), pad_id, dtype=np.int64)
        # -1 marks padded positions so the loss function can ignore them.
        batch_labels = np.full((batch_size, self.max_len), -1, dtype=np.int64)

        # Copy each sequence (truncated to max_len) into the padded arrays.
        for j in range(batch_size):
            if isinstance(tokens[j], str):
                # Tokens may arrive as stringified lists; parse them safely
                # instead of calling eval() on untrusted input.
                tokens[j] = ast.literal_eval(tokens[j])
            cur_len = min(len(tokens[j]), self.max_len)
            tag_ids = [self.label_mapping[tag] for tag in tags[j][:cur_len]]

            batch_data[j, :cur_len] = self.vocab.encode(tokens[j][:cur_len])
            batch_labels[j, :cur_len] = tag_ids

        # Convert the integer index arrays to PyTorch tensors.
        return {
            "encoded_tokens": torch.from_numpy(batch_data),
            "encoded_tags": torch.from_numpy(batch_labels),
        }
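

# ---------------------------------------------------------------------------
# Smoke test: a runnable sketch of the tagging dataset feeding a DataLoader.
# The sample data, the tag set, and _StubVocab below are illustrative
# assumptions; the real Vocab comes from utils and is built from the corpus.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    class _StubVocab:
        """Minimal stand-in exposing the two Vocab methods used above."""

        def __init__(self, tokens: List[str]):
            self._ids = {"[PAD]": 0, "[UNK]": 1}
            for tok in tokens:
                self._ids.setdefault(tok, len(self._ids))

        def token_to_id(self, token: str) -> int:
            return self._ids.get(token, self._ids["[UNK]"])

        def encode(self, tokens: List[str]) -> List[int]:
            return [self.token_to_id(tok) for tok in tokens]

    data = [
        {"tokens": ["book", "a", "table"], "tags": ["O", "O", "O"]},
        {"tokens": ["tomorrow", "night"], "tags": ["B-date", "I-date"]},
    ]
    tag2idx = {"O": 0, "B-date": 1, "I-date": 2}
    vocab = _StubVocab([tok for sample in data for tok in sample["tokens"]])

    dataset = SeqTaggingClsDataset(data, vocab, tag2idx, max_len=8)
    loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)
    batch = next(iter(loader))
    print(batch["encoded_tokens"].shape)  # torch.Size([2, 8])
    print(batch["encoded_tags"])          # -1 on the padded positions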