flux-text-encoder-neutered / text_encoder_pico.py
import torch
from safetensors.torch import load_model
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class TextEncoder(nn.Module):
    """Tiny transformer encoder mapping T5 token ids to 4096-dim embeddings."""
    T5_VOCAB_SIZE = 32128

    def __init__(self,
                 embedding_width=4096,
                 num_heads=2,
                 max_seq_len=512):
        super().__init__()
        # Work in a reduced model dimension, then project up to embedding_width.
        mdim = embedding_width // 4
        ffdim = mdim * 2
        self.embedding = nn.Embedding(TextEncoder.T5_VOCAB_SIZE, mdim)
        scale = embedding_width ** -0.5
        # Learned positional encoding for the padded (max-length) sequence.
        self.positional_encoding = nn.Parameter(torch.randn(max_seq_len, mdim) * scale)
        self.encoder = nn.Sequential(
            nn.TransformerEncoderLayer(d_model=mdim, nhead=num_heads, dim_feedforward=ffdim, batch_first=True),
            nn.TransformerEncoderLayer(d_model=mdim, nhead=num_heads, dim_feedforward=ffdim, batch_first=True),
            nn.Linear(mdim, embedding_width),
            nn.ReLU(),
            nn.Linear(embedding_width, embedding_width)
        )
        self.max_seq_len = max_seq_len

    def forward(self, input_ids):
        embedding = self.embedding(input_ids)
        # Pad the sequence dimension up to max length.
        padded_embeds = nn.functional.pad(embedding, (0, 0, 0, self.max_seq_len - input_ids.shape[1]))
        positional_embeddings = self.positional_encoding[:self.max_seq_len, :]
        input_to_transformer = padded_embeds + positional_embeddings
        return self.encoder(input_to_transformer)


class PretrainedTextEncoder(PreTrainedModel):
    # Call by:
    # t5 = PretrainedTextEncoder(PretrainedConfig())
    # t5.load_model('text_encoder_2.safetensors')
    # ...
    # FluxPipeline.from_pretrained(..., text_encoder_2=t5)
    def __init__(self, config):
        super().__init__(config)
        self.model = TextEncoder()

    def load_model(self, filepath):
        # Load safetensors weights into the wrapped TextEncoder.
        load_model(self.model, filepath)

    def forward(self, x, output_hidden_states=False):
        # output_hidden_states is accepted only for interface compatibility.
        # Return a 1-tuple so callers can index [0] for the hidden states,
        # as they would with the T5 encoder output this module replaces.
        return (self.model(x),)
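

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module): a minimal, hedged example of wiring
# the pico encoder into Flux, following the "Call by:" comment above. It
# assumes the diffusers library is installed, that the checkpoint named
# 'text_encoder_2.safetensors' in that comment exists locally, and that
# 'black-forest-labs/FLUX.1-dev' is the base pipeline you want (substitute
# your own Flux repo id; the original comment leaves it as "...").
if __name__ == '__main__':
    from diffusers import FluxPipeline

    t5 = PretrainedTextEncoder(PretrainedConfig())

    # Shape check: the encoder pads every prompt to max_seq_len (512) and
    # projects to 4096 channels, the hidden size Flux expects from T5.
    dummy_ids = torch.randint(0, TextEncoder.T5_VOCAB_SIZE, (1, 20))
    print(t5(dummy_ids)[0].shape)  # torch.Size([1, 512, 4096])

    # Swap the pico encoder in for the full T5-XXL text_encoder_2.
    t5.load_model('text_encoder_2.safetensors')
    pipe = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev',
                                        text_encoder_2=t5)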