|
import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel

|
class TextEncoder(nn.Module):
    # Vocabulary size of the T5 tokenizer; the embedding table is sized to match.
    T5_VOCAB_SIZE = 32128

    def __init__(self,
                 embedding_width=4096,
                 num_heads=2,
                 max_seq_len=512):
        super().__init__()
        # The transformer trunk runs at a quarter of the output width; the
        # projection head below maps back up to embedding_width.
        mdim = embedding_width // 4
        ffdim = mdim * 2
        self.embedding = nn.Embedding(TextEncoder.T5_VOCAB_SIZE, mdim)
        # Learned positional encodings with small-scale random initialization.
        scale = embedding_width ** -0.5
        self.positional_encoding = nn.Parameter(torch.randn(max_seq_len, mdim) * scale)
        # nn.Sequential works here because each TransformerEncoderLayer maps a
        # single (batch, seq, mdim) tensor to one of the same shape; the
        # trailing MLP then projects up to embedding_width.
        self.encoder = nn.Sequential(
            nn.TransformerEncoderLayer(d_model=mdim, nhead=num_heads, dim_feedforward=ffdim, batch_first=True),
            nn.TransformerEncoderLayer(d_model=mdim, nhead=num_heads, dim_feedforward=ffdim, batch_first=True),
            nn.Linear(mdim, embedding_width),
            nn.ReLU(),
            nn.Linear(embedding_width, embedding_width),
        )
        self.max_seq_len = max_seq_len
|
    def forward(self, input_ids):
        embedding = self.embedding(input_ids)

        # Zero-pad the sequence dimension out to max_seq_len so the full
        # positional-encoding table can be added in a single broadcast.
        padded_embeds = nn.functional.pad(embedding, (0, 0, 0, self.max_seq_len - input_ids.shape[1]))
        input_to_transformer = padded_embeds + self.positional_encoding

        return self.encoder(input_to_transformer)
|
|
class PretrainedTextEncoder(PreTrainedModel):
    """Thin wrapper exposing TextEncoder through the Hugging Face
    PreTrainedModel interface."""

    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = TextEncoder()

    def load_model(self, filepath):
        # Assumes the checkpoint at filepath holds a plain TextEncoder state dict.
        self.model.load_state_dict(torch.load(filepath, map_location="cpu"))

    def forward(self, x, output_hidden_states=False):
        # output_hidden_states is accepted for interface compatibility only;
        # intermediate activations are not exposed. The trailing comma wraps
        # the result in a 1-tuple, matching transformers' tuple outputs.
        return self.model(x),
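

# Minimal usage sketch: a quick smoke test assuming the default constructor
# arguments above. The batch size and sequence length here are arbitrary.
if __name__ == "__main__":
    dummy_ids = torch.randint(0, TextEncoder.T5_VOCAB_SIZE, (2, 77))  # (batch, seq)

    encoder = TextEncoder()
    out = encoder(dummy_ids)
    print(out.shape)  # torch.Size([2, 512, 4096]): padded to max_seq_len, projected to embedding_width

    wrapped = PretrainedTextEncoder(PretrainedConfig())
    (wrapped_out,) = wrapped(dummy_ids)  # forward returns a 1-tuple
    print(wrapped_out.shape)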
|
|