from transformers import PretrainedConfig, PreTrainedModel
from safetensors.torch import load_model
import torch
from torch import nn


class TextEncoder(nn.Module):
    # Vocabulary size of the T5 tokenizer used by the Flux pipeline.
    T5_VOCAB_SIZE = 32128

    def __init__(self, embedding_width=4096, num_heads=2, max_seq_len=512):
        super().__init__()
        mdim = embedding_width // 4
        ffdim = mdim * 2
        self.embedding = nn.Embedding(TextEncoder.T5_VOCAB_SIZE, mdim)
        # Learned positional encoding, one row per position up to max_seq_len.
        scale = embedding_width ** -0.5
        self.positional_encoding = nn.Parameter(torch.randn(max_seq_len, mdim) * scale)
        self.encoder = nn.Sequential(
            nn.TransformerEncoderLayer(d_model=mdim, nhead=num_heads,
                                       dim_feedforward=ffdim, batch_first=True),
            nn.TransformerEncoderLayer(d_model=mdim, nhead=num_heads,
                                       dim_feedforward=ffdim, batch_first=True),
            # Project back up to the width the downstream model expects.
            nn.Linear(mdim, embedding_width),
            nn.ReLU(),
            nn.Linear(embedding_width, embedding_width),
        )
        self.max_seq_len = max_seq_len

    def forward(self, input_ids):
        embedding = self.embedding(input_ids)
        # Pad the sequence dimension out to max_seq_len so every batch has a fixed shape.
        padded_embeds = nn.functional.pad(
            embedding, (0, 0, 0, self.max_seq_len - input_ids.shape[1]))
        positional_embeddings = self.positional_encoding[:self.max_seq_len, :]
        input_to_transformer = padded_embeds + positional_embeddings
        return self.encoder(input_to_transformer)


class PretrainedTextEncoder(PreTrainedModel):
    # Call by:
    #   t5 = PretrainedTextEncoder(PretrainedConfig())
    #   t5.load_model('text_encoder_2.safetensors')
    #   ...
    #   FluxPipeline.from_pretrained(..., text_encoder_2=t5)
    def __init__(self, config):
        super().__init__(config)
        self.model = TextEncoder()

    def load_model(self, filepath):
        # safetensors.torch.load_model loads the weights into self.model in place.
        load_model(self.model, filepath)

    def forward(self, x, output_hidden_states=False):
        # Return a one-element tuple so callers can index the embeddings with [0],
        # matching the output convention of the T5 encoder this module replaces.
        return self.model(x),
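

# A minimal shape-check sketch (not part of the model code above): it feeds
# random token ids from the assumed T5 vocabulary through TextEncoder and
# confirms that short sequences are padded to max_seq_len and projected to
# embedding_width.
if __name__ == "__main__":
    encoder = TextEncoder()
    # Batch of two prompts, each 20 tokens long (values chosen arbitrarily).
    dummy_ids = torch.randint(0, TextEncoder.T5_VOCAB_SIZE, (2, 20))
    with torch.no_grad():
        out = encoder(dummy_ids)
    print(out.shape)  # torch.Size([2, 512, 4096])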