SunderAli17 committed on
Commit be0d2f2
1 Parent(s): 2a12457

Create toonmage/fluxencoders.py

Files changed (1)
  1. toonmage/fluxencoders.py +207 -0
toonmage/fluxencoders.py ADDED
@@ -0,0 +1,207 @@
+import math
+
+import torch
+import torch.nn as nn
+
+
+# FFN
+def FeedForward(dim, mult=4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
+
+
+def reshape_tensor(x, heads):
+    bs, length, width = x.shape
+    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+    x = x.view(bs, length, heads, -1)
+    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+    x = x.transpose(1, 2)
+    # keep the (bs, n_heads, length, dim_per_head) layout; reshape makes the transposed tensor contiguous
+    x = x.reshape(bs, heads, length, -1)
+    return x
+
+
+class PerceiverAttentionCA(nn.Module):
+    def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.dim_head = dim_head
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, n1, D)
+            latents (torch.Tensor): latent features
+                shape (b, n2, D)
+        """
+        x = self.norm1(x)
+        latents = self.norm2(latents)
+
+        b, seq_len, _ = latents.shape
+
+        q = self.to_q(latents)
+        k, v = self.to_kv(x).chunk(2, dim=-1)
+
+        q = reshape_tensor(q, self.heads)
+        k = reshape_tensor(k, self.heads)
+        v = reshape_tensor(v, self.heads)
+
+        # attention
+        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        out = weight @ v
+
+        out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
+
+        return self.to_out(out)
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.dim_head = dim_head
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, n1, D)
+            latents (torch.Tensor): latent features
+                shape (b, n2, D)
+        """
+        x = self.norm1(x)
+        latents = self.norm2(latents)
+
+        b, seq_len, _ = latents.shape
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+        q = reshape_tensor(q, self.heads)
+        k = reshape_tensor(k, self.heads)
+        v = reshape_tensor(v, self.heads)
+
+        # attention
+        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        out = weight @ v
+
+        out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
+
+        return self.to_out(out)
+
+
+class IDFormer(nn.Module):
+    """
+    - Perceiver-resampler-like architecture (compared with the previous MLP-like architecture).
+    - The ID embedding (generated by ArcFace) and the query tokens are concatenated as latents.
+    - The latents attend to each other and interact with the ViT features through cross-attention.
+    - The ViT features are multi-scale and are fed into the IDFormer in order; currently, each scale
+      corresponds to two IDFormer layers.
+    """
+    def __init__(
+        self,
+        dim=1024,
+        depth=10,
+        dim_head=64,
+        heads=16,
+        num_id_token=5,
+        num_queries=32,
+        output_dim=2048,
+        ff_mult=4,
+    ):
+        super().__init__()
+
+        self.num_id_token = num_id_token
+        self.dim = dim
+        self.num_queries = num_queries
+        assert depth % 5 == 0
+        self.depth = depth // 5
+        scale = dim ** -0.5
+
+        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)
+        self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )
+
+        for i in range(5):
+            setattr(
+                self,
+                f'mapping_{i}',
+                nn.Sequential(
+                    nn.Linear(1024, 1024),
+                    nn.LayerNorm(1024),
+                    nn.LeakyReLU(),
+                    nn.Linear(1024, 1024),
+                    nn.LayerNorm(1024),
+                    nn.LeakyReLU(),
+                    nn.Linear(1024, dim),
+                ),
+            )
+
+        self.id_embedding_mapping = nn.Sequential(
+            nn.Linear(1280, 1024),
+            nn.LayerNorm(1024),
+            nn.LeakyReLU(),
+            nn.Linear(1024, 1024),
+            nn.LayerNorm(1024),
+            nn.LeakyReLU(),
+            nn.Linear(1024, dim * num_id_token),
+        )
+
+    def forward(self, x, y):
+
+        latents = self.latents.repeat(x.size(0), 1, 1)
+
+        x = self.id_embedding_mapping(x)
+        x = x.reshape(-1, self.num_id_token, self.dim)
+
+        latents = torch.cat((latents, x), dim=1)
+
+        for i in range(5):
+            vit_feature = getattr(self, f'mapping_{i}')(y[i])
+            ctx_feature = torch.cat((x, vit_feature), dim=1)
+            for attn, ff in self.layers[i * self.depth: (i + 1) * self.depth]:
+                latents = attn(ctx_feature, latents) + latents
+                latents = ff(latents) + latents
+
+        latents = latents[:, :self.num_queries]
+        latents = latents @ self.proj_out
+        return latents
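
A minimal shape-check sketch (not part of this commit): it only exercises the modules above with dummy tensors. The batch size, the ViT token count (577), and the FLUX sequence length (4096) are assumptions inferred from the default arguments, not values taken from the ToonMage pipeline.

    import torch
    from toonmage.fluxencoders import IDFormer, PerceiverAttentionCA

    # IDFormer: map an ArcFace-style id embedding plus 5 scales of ViT features to 32 id tokens.
    id_former = IDFormer()                                     # dim=1024, num_queries=32, output_dim=2048
    id_embed = torch.randn(2, 1280)                            # (batch, 1280) id embedding
    vit_feats = [torch.randn(2, 577, 1024) for _ in range(5)]  # 5 multi-scale ViT feature tensors (token count assumed)
    id_tokens = id_former(id_embed, vit_feats)                 # -> (2, 32, 2048)

    # PerceiverAttentionCA: cross-attend the id tokens (kv_dim=2048) into 3072-dim hidden states.
    ca = PerceiverAttentionCA()
    hidden = torch.randn(2, 4096, 3072)                        # placeholder transformer hidden states
    delta = ca(id_tokens, hidden)                              # -> (2, 4096, 3072); how it is merged back is up to the caller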