AlekseyCalvin committed
Commit: d5ec5a6
1 Parent(s): b5deeef

Create app.py

Files changed (1)
app.py +863 -0
app.py ADDED
@@ -0,0 +1,863 @@
import os
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Minimal no-op stand-in so @spaces.GPU / @spaces.GPU(duration=...) also work outside ZeroGPU Spaces.
    class spaces:
        @staticmethod
        def GPU(func=None, **kwargs):
            def decorator(fn):
                def wrapper(*args, **fn_kwargs):
                    return fn(*args, **fn_kwargs)
                return wrapper
            return decorator(func) if callable(func) else decorator

import gradio as gr
import json
import logging
import argparse
import torch
import torchvision
from os import path
from PIL import Image
import numpy as np
import copy
import random
import time
from torchvision import transforms
from dataclasses import dataclass

import math
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoPipelineForImage2Image, FluxTransformer2DModel
import safetensors.torch
from safetensors.torch import load_file
from tqdm import tqdm
from einops import rearrange, repeat
from torch import Tensor, nn
from pipeline import FluxWithCFGPipeline
from transformers import CLIPModel, CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPConfig, T5EncoderModel, T5Tokenizer
import gc
import warnings

model_path = snapshot_download(repo_id="nyanko7/flux-dev-de-distill")
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

device = "cuda" if torch.cuda.is_available() else "cpu"

torch.backends.cuda.matmul.allow_tf32 = True

# Load LoRAs from JSON file
with open('loras.json', 'r') as f:
    loras = json.load(f)

dtype = torch.bfloat16
MAX_SEED = 2**32 - 1  # upper bound for the seed slider and random seed draws

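# The "Encoders" section below defines HFEmbedder, a thin wrapper that exposes either
# text encoder behind a single forward() call: CLIP variants return the pooled
# "pooler_output" vector, while T5 returns the full "last_hidden_state" token sequence.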
# ---------------- Encoders ----------------
class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]

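# Load the OpenFLUX.1 pipeline with the tiny TAEF1 autoencoder for fast latent decoding,
# then swap its stock CLIP-L tokenizer/text encoder for Zer0int's CLIP-GmP-ViT-L-14
# fine-tune (the 77-token "norm" variant selected via `clipmodel` below).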
pipe = FluxWithCFGPipeline.from_pretrained("ostris/OpenFLUX.1", torch_dtype=dtype).to("cuda")
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to("cuda")

pipe.to("cuda")
clipmodel = 'norm'
if clipmodel == "long":
    model_id = "zer0int/LongCLIP-GmP-ViT-L-14"
    config = CLIPConfig.from_pretrained(model_id)
    maxtokens = 77
if clipmodel == "norm":
    model_id = "zer0int/CLIP-GmP-ViT-L-14"
    config = CLIPConfig.from_pretrained(model_id)
    maxtokens = 77
clip_model = CLIPModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, config=config, ignore_mismatched_sizes=True).to("cuda")
clip_processor = CLIPProcessor.from_pretrained(model_id, padding="max_length", max_length=maxtokens, ignore_mismatched_sizes=True, return_tensors="pt", truncation=True)
pipe.tokenizer = clip_processor.tokenizer
pipe.text_encoder = clip_model.text_model
pipe.text_encoder.dtype = torch.bfloat16

pipe.to("cuda")
#clipmodel = 'norm'
#if clipmodel == "long":
#    model_id = "zer0int/LongCLIP-GmP-ViT-L-14"
#    config = CLIPConfig.from_pretrained(model_id)
#    maxtokens = 77
#if clipmodel == "norm":
#    model_id = "zer0int/CLIP-GmP-ViT-L-14"
#    config = CLIPConfig.from_pretrained(model_id)
#    maxtokens = 77
#clip_model = CLIPModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, config=config, ignore_mismatched_sizes=True).to("cuda")
#clip_processor = CLIPProcessor.from_pretrained(model_id, padding="max_length", max_length=maxtokens, ignore_mismatched_sizes=True, return_tensors="pt", truncation=True)

#pipe.tokenizer = clip_processor.tokenizer
#pipe.text_encoder = clip_model.text_model
#pipe.tokenizer_max_length = maxtokens
#pipe.text_encoder.dtype = torch.bfloat16

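# Rotary position embeddings: rope() builds, per position and frequency, the 2x2
# rotation matrix [[cos, -sin], [sin, cos]], and apply_rope() rotates each (even, odd)
# channel pair of q and k by it before scaled dot-product attention is applied.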
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # x = rearrange(x, "B H L D -> B L (H D)")
    x = x.permute(0, 2, 1, 3).reshape(x.size(0), x.size(2), -1)

    return x


def rope(pos, dim, theta):
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)

    # out = torch.einsum("...n,d->...nd", pos, omega)
    out = pos.unsqueeze(-1) * omega.unsqueeze(0)

    cos_out = torch.cos(out)
    sin_out = torch.sin(out)
    out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)

    # out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    b, n, d, _ = out.shape
    out = out.view(b, n, d, 2, 2)

    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2

    # Does not block the CUDA stream, but differs from the official Flux code by about 1e-4:
    # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)

    # Blocks the CUDA stream, but is consistent with the official code:
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        B, L, _ = qkv.shape
        qkv = qkv.view(B, L, 3, self.num_heads, -1)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x

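# AdaLN-style conditioning: Modulation projects the pooled conditioning vector into
# (shift, scale, gate) triplets -- one triplet for the attention path and, in the
# double-stream blocks, a second one for the MLP path -- which modulate the
# layer-normalized activations below.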
@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        # img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        B, L, _ = img_qkv.shape
        H = self.num_heads
        D = img_qkv.shape[-1] // (3 * H)
        img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        B, L, _ = txt_qkv.shape
        txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x


class FluxParams:
    in_channels: int = 64
    vec_in_dim: int = 768
    context_in_dim: int = 4096
    hidden_size: int = 3072
    mlp_ratio: float = 4.0
    num_heads: int = 24
    depth: int = 19
    depth_single_blocks: int = 38
    axes_dim: list = [16, 56, 56]
    theta: int = 10_000
    qkv_bias: bool = True
    guidance_embed: bool = True


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params=FluxParams()):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        # self.guidance_in = (
        #     MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        # )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
        use_guidance_vec=True,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        # if self.params.guidance_embed and use_guidance_vec:
        #     if guidance is None:
        #         raise ValueError("Didn't get guidance strength for guidance distilled model.")
        #     vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1]:, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img


def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }


def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu by linear interpolation between two reference points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()

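# denoise() integrates the flow-matching ODE with plain Euler steps:
# img <- img + (t_prev - t_curr) * pred. With use_cfg_guidance, the batch holds an
# unconditional and a conditional half, and the prediction is recombined as
# uncond + guidance * (cond - uncond) before each update.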
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    use_cfg_guidance=False,
):
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:])):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        if use_cfg_guidance:
            half_x = img[:len(img) // 2]
            img = torch.cat([half_x, half_x], dim=0)
            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            use_guidance_vec=not use_cfg_guidance,
        )

        if use_cfg_guidance:
            uncond, cond = pred.chunk(2, dim=0)
            model_output = uncond + guidance * (cond - uncond)
            pred = torch.cat([model_output, model_output], dim=0)

        img = img + (t_prev - t_curr) * pred

    return img


def unpack(x: Tensor, height: int, width: int) -> Tensor:
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )

@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    guidance: float
    seed: int | None


def get_image(image) -> torch.Tensor | None:
    if image is None:
        return None
    image = Image.fromarray(image).convert("RGB")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: 2.0 * x - 1.0),
    ])
    img: torch.Tensor = transform(image)
    return img[None, ...]


# ---------------- Demo ----------------


class EmptyInitWrapper(torch.overrides.TorchFunctionMode):
    def __init__(self, device=None):
        self.device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if getattr(func, "__module__", None) == "torch.nn.init":
            if "tensor" in kwargs:
                return kwargs["tensor"]
            else:
                return args[0]
        if (
            self.device is not None
            and func in torch.utils._device._device_constructors()
            and kwargs.get("device") is None
        ):
            kwargs["device"] = self.device
        return func(*args, **kwargs)

with EmptyInitWrapper():
    model = Flux().to(dtype=torch.bfloat16, device="cuda")

sd = load_file(f"{model_path}/consolidated_s6700.safetensors")
sd = {k.replace("model.", ""): v for k, v in sd.items()}
result = model.load_state_dict(sd)

class calculateDuration:
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")


def update_selection(evt: gr.SelectData, width, height):
    selected_lora = loras[evt.index]
    new_placeholder = f"Type a prompt for {selected_lora['title']}"
    lora_repo = selected_lora["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
    if "aspect" in selected_lora:
        if selected_lora["aspect"] == "portrait":
            width = 768
            height = 1024
        elif selected_lora["aspect"] == "landscape":
            width = 1024
            height = 768
    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        evt.index,
        width,
        height,
    )


# GPU-bound generation entry point: run under a ZeroGPU allocation and inference mode.
#@torch.cuda.empty_cache()
@spaces.GPU(duration=120)
@torch.inference_mode()
def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, negative_prompt, lora_scale, progress):
    pipe.to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(seed)

    with calculateDuration("Generating image"):
        # Generate image
        image = pipe(
            prompt=f"{prompt} {trigger_word}",
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
    return image


def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, negative_prompt, lora_scale, progress=gr.Progress(track_tqdm=True)):
    if negative_prompt == "":
        negative_prompt = None
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.")

    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]

    # Load LoRA weights
    with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
        if "weights" in selected_lora:
            pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
        else:
            pipe.load_lora_weights(lora_path)

    # Set random seed for reproducibility
    with calculateDuration("Randomizing seed"):
        if randomize_seed:
            seed = random.randint(0, 2**32 - 1)

    image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, negative_prompt, lora_scale, progress)
    pipe.to("cpu")
    pipe.unload_lora_weights()
    return image, seed

run_lora.zerogpu = True

css = '''
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
'''
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/AlekseyCalvin/HSTklimbimOPENfluxLora/resolve/main/acs62iv.png" alt="LoRA">OpenFlux LoRAsoon®</h1>""",
        elem_id="title",
    )
    # Info blob stating what the app is running
    info_blob = gr.HTML(
        """<div id="info_blob"> SOON®'s curated LoRa Gallery & Art Manufactory Space.|Runs on Ostris' OpenFLUX.1 model + fast-gen LoRA & Zer0int's fine-tuned CLIP-GmP-ViT-L-14*! (*'normal' 77 tokens)| Largely stocked w/our trained LoRAs: Historic Color, Silver Age Poets, Sots Art, more!|</div>"""
    )
    # Info blob listing the trigger words matched to each numbered LoRA slot
    info_blob = gr.HTML(
        """<div id="info_blob"> *Auto-planting of prompts with a choice LoRA trigger errors out in this space over flaws yet unclear. In its stead, we pose numbered LoRA-box rows & a matched token cheat-sheet: ungainly & free. So, prephrase your prompts w/: 1-2. HST style autochrome |3. RCA style Communist poster |4. SOTS art |5. HST Austin Osman Spare style |6. Vladimir Mayakovsky |7-8. Marina Tsvetaeva Tsvetaeva_02.CR2 |9. Anna Akhmatova |10. Osip Mandelshtam |11-12. Alexander Blok |13. Blok_02.CR2 |14. LEN Lenin |15. Leon Trotsky |16. Rosa Fluxemburg |17. HST Peterhof photo |18-19. HST |20. HST portrait |21. HST |22. HST 80s Perestroika-era Soviet photo |23-30. HST |31. How2Draw a__ |32. propaganda poster |33. TOK hybrid photo of__ with cartoon of__ |34. 2004 IMG_1099.CR2 photo |35. unexpected photo of |36. flmft |37. 80s yearbook photo |38. TOK portra |39. pficonics |40. retrofuturism |41. wh3r3sw4ld0 |42. amateur photo |43. crisp |44-45. IMG_1099.CR2 |46. FilmFotos |47. ff-collage |48. HST |49-50. AOS |51. cover </div>"""
    )
    selected_index = gr.State(None)
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Select LoRa/Style & type prompt!")
    with gr.Row():
        with gr.Column(scale=3):
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=1, placeholder="List unwanted conditions, open-fluxedly!")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
    with gr.Row():
        with gr.Column(scale=3):
            selected_info = gr.Markdown("")
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Inventory",
                allow_preview=False,
                columns=3,
                elem_id="gallery"
            )

        with gr.Column(scale=4):
            result = gr.Image(label="Generated Image")

    with gr.Row():
        with gr.Accordion("Advanced Settings", open=True):
            with gr.Column():
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=1, value=3)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=6)

                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=768)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=768)

                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)

    gallery.select(
        update_selection,
        inputs=[width, height],
        outputs=[prompt, selected_info, selected_index, width, height]
    )

    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, negative_prompt, lora_scale],
        outputs=[result, seed]
    )

warnings.filterwarnings("ignore", category=FutureWarning)
app.queue(default_concurrency_limit=2).launch(show_error=True)