from safetensors.torch import load_model
from transformers import PretrainedConfig, PreTrainedModel
import torch
from torch import nn

class TextEncoder(nn.Module):
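    """A small transformer encoder over the T5 vocabulary (32128 tokens),
    projecting to embedding_width (4096, T5-XXL's hidden size) so it can
    stand in for a full T5 text encoder."""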
    T5_VOCAB_SIZE = 32128

    def __init__(self,
                 embedding_width=4096,
                 num_heads=2,
                 max_seq_len=512):
        super().__init__()
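        # Run the transformer at a quarter of the output width, then project
        # back up to embedding_width with a small MLP at the end of the stack.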
        mdim = embedding_width // 4
        ffdim = mdim * 2
        self.embedding = nn.Embedding(TextEncoder.T5_VOCAB_SIZE, mdim)
        scale = embedding_width ** -0.5
        self.positional_encoding = nn.Parameter(torch.randn(max_seq_len, mdim) * scale)
        self.encoder = nn.Sequential(
            nn.TransformerEncoderLayer(d_model=mdim, nhead=num_heads, dim_feedforward=ffdim, batch_first=True),
            nn.TransformerEncoderLayer(d_model=mdim, nhead=num_heads, dim_feedforward=ffdim, batch_first=True),
            nn.Linear(mdim, embedding_width),
            nn.ReLU(),
            nn.Linear(embedding_width, embedding_width)
        )
        self.max_seq_len = max_seq_len

    def forward(self, input_ids):
        embedding = self.embedding(input_ids)
        # Zero-pad the sequence dimension up to max_seq_len, then add the
        # learned positional encodings (which always span the full length).
        padded_embeds = nn.functional.pad(embedding, (0, 0, 0, self.max_seq_len - input_ids.shape[1]))
        input_to_transformer = padded_embeds + self.positional_encoding

        return self.encoder(input_to_transformer)

class PretrainedTextEncoder(PreTrainedModel):
    # Usage:
    #   t5 = PretrainedTextEncoder(PretrainedConfig())
    #   t5.load_model('text_encoder_2.safetensors')
    #   ...
    #   FluxPipeline.from_pretrained(..., text_encoder_2=t5)
    def __init__(self, config):
        super().__init__(config)
        self.model = TextEncoder()

    def load_model(self, filepath):
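        # Load weights in place from a safetensors checkpoint file.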
        load_model(self.model, filepath)

    def forward(self, x, output_hidden_states=False):
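        # Return a 1-tuple so callers can index the hidden states as output[0],
        # mirroring the Hugging Face model-output convention. output_hidden_states
        # is accepted for pipeline compatibility but is not used here.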
        return (self.model(x),)
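
# Minimal smoke test on random token ids. The checkpoint filename is the
# illustrative one from the usage comment above, and the printed shape assumes
# the default constructor arguments; neither refers to a published artifact.
if __name__ == "__main__":
    t5 = PretrainedTextEncoder(PretrainedConfig())
    # t5.load_model('text_encoder_2.safetensors')  # hypothetical weights file
    input_ids = torch.randint(0, TextEncoder.T5_VOCAB_SIZE, (1, 16))
    (hidden_states,) = t5(input_ids)
    print(hidden_states.shape)  # torch.Size([1, 512, 4096])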