import torch
import torch.nn as nn
import torch.nn.functional as F


class SeqClassifier(nn.Module):
    def __init__(
        self,
        embeddings: torch.Tensor,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        bidirectional: bool,
        num_class: int,
    ) -> None:
        super(SeqClassifier, self).__init__()
        # Embedding layer initialized from pretrained vectors, fine-tuned during training
        self.embed = nn.Embedding.from_pretrained(embeddings, freeze=False)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.num_class = num_class
        # Model architecture
        self.rnn = nn.GRU(
            input_size=embeddings.size(1),
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )
        self.dropout_layer = nn.Dropout(p=self.dropout)
        self.fc = nn.Linear(self.encoder_output_size, num_class)

    @property
    def encoder_output_size(self) -> int:
        # Output dimension of the GRU: doubled when bidirectional (forward + backward states)
        if self.bidirectional:
            return self.hidden_size * 2
        else:
            return self.hidden_size


class SeqTagger(SeqClassifier):
    def __init__(self, embeddings, hidden_size, num_layers, dropout, bidirectional, num_class):
        super(SeqTagger, self).__init__(embeddings, hidden_size, num_layers, dropout, bidirectional, num_class)

    def forward(self, batch) -> torch.Tensor:
        # Apply the embedding layer that maps each token to its embedding
        batch = self.embed(batch)  # dim: batch_size x max_len x embed_dim
        # Run the GRU along the sentences of length max_len
        batch, _ = self.rnn(batch)  # dim: batch_size x max_len x num_directions*hidden_size
        batch = self.dropout_layer(batch)
        # Flatten so every token position becomes a row, then score each token
        # (the training and prediction paths were identical, so they are combined here)
        batch = batch.reshape(-1, batch.shape[2])  # dim: batch_size*max_len x num_directions*hidden_size
        # Pass through the fully connected layer
        batch = self.fc(batch)
        return F.log_softmax(batch, dim=1)  # dim: batch_size*max_len x num_class
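

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how SeqTagger might be instantiated and called.
# The embedding matrix, vocabulary size, sequence length, and tag count below
# are made-up placeholders; real values would come from the preprocessed dataset.
if __name__ == "__main__":
    vocab_size, embed_dim, num_tags = 1000, 300, 9  # assumed toy sizes
    dummy_embeddings = torch.randn(vocab_size, embed_dim)  # stand-in for pretrained embeddings
    model = SeqTagger(
        embeddings=dummy_embeddings,
        hidden_size=128,
        num_layers=2,
        dropout=0.1,
        bidirectional=True,
        num_class=num_tags,
    )
    token_ids = torch.randint(0, vocab_size, (4, 16))  # batch_size=4, max_len=16
    log_probs = model(token_ids)  # dim: (4*16) x num_tags
    # Reshape back to per-sentence, per-token predictions if needed
    preds = log_probs.argmax(dim=1).reshape(4, 16)
    print(log_probs.shape, preds.shape)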