Spaces: Running on Zero
PommesPeter committed · Commit 5aadc08 · Parent: 24e677f
Upload 8 files
Browse files:
- app.py (+598 -0)
- models/__init__.py (+2 -0)
- models/components.py (+54 -0)
- models/model.py (+908 -0)
- models/model_5b.py (+894 -0)
- requirements.txt (+12 -0)
app.py ADDED @@ -0,0 +1,598 @@
import argparse
import builtins
import json
import multiprocessing as mp
import os, sys
import random
import socket
import traceback

import fairscale.nn.model_parallel.initialize as fs_init
import gradio as gr
import numpy as np
import torch
import torch.distributed as dist
from torchvision.transforms.functional import to_pil_image

import models
from PIL import Image
from lumina_t2i.transport import create_transport, Sampler

description = """
# Lumina Next Text-to-Image

Lumina-Next-T2I is a 2B Next-DiT model with a 2B text encoder.

Demo current model: `Lumina-Next-T2I`

### <span style='color: red;'>Due to the high volume of access, we have temporarily disabled the resolution extrapolation functionality.

### Additionally, we offer alternative links for Lumina-T2X access. Try to visit other demo sites. [[demo1](http://106.14.2.150:10022/)] [[demo2](http://106.14.2.150:10023/)]

"""

examples = [
    ["👽🤖👹👻"],
    ["孤舟蓑笠翁"],
    ["两只黄鹂鸣翠柳"],
    ["大漠孤烟直,长河落日圆"],
    ["秋风起兮白云飞,草木黄落兮雁南归"],
    ["도쿄 타워, 최고 품질의 우키요에, 에도 시대"],
    ["味噌ラーメン, 最高品質の浮世絵、江戸時代。"],
    ["東京タワー、最高品質の浮世絵、江戸時代。"],
    ["Astronaut on Mars During sunset"],
    ["Tour de Tokyo, estampes ukiyo-e de la plus haute qualité, période Edo"],
    ["🐔 playing 🏀"],
    ["☃️ with 🌹 in the ❄️"],
    ["🐶 wearing 😎 flying on 🌈 "],
    ["A small 🍎 and 🍊 with 😁 emoji in the Sahara desert"],
    ["Токийская башня, лучшие укиё-э, период Эдо"],
    ["Tokio-Turm, hochwertigste Ukiyo-e, Edo-Zeit"],
    ["A scared cute rabbit in Happy Tree Friends style and punk vibe."],  # noqa
    ["A humanoid eagle soldier of the First World War."],  # noqa
    ["A cute Christmas mockup on an old wooden industrial desk table with Christmas decorations and bokeh lights in the background."],
    ["A front view of a romantic flower shop in France filled with various blooming flowers including lavenders and roses."],
    ["An old man, portrayed as a retro superhero, stands in the streets of New York City at night"],
    ["many trees are surrounded by a lake in autumn colors, in the style of nature-inspired imagery, havencore, brightly colored, dark white and dark orange, bright primary colors, environmental activism, forestpunk --ar 64:51"],
    ["A fluffy mouse holding a watermelon, in a magical and colorful setting, illustrated in the style of Hayao Miyazaki anime by Studio Ghibli."],
    ["Inka warrior with a war make up, medium shot, natural light, Award winning wildlife photography, hyperrealistic, 8k resolution, --ar 9:16"],
    ["Character of lion in style of saiyan, mafia, gangsta, citylights background, Hyper detailed, hyper realistic, unreal engine ue5, cgi 3d, cinematic shot, 8k"],
    ["In the sky above, a giant, whimsical cloud shaped like the 😊 emoji casts a soft, golden light over the scene"],
    ["Cyberpunk eagle, neon ambiance, abstract black oil, gear mecha, detailed acrylic, grunge, intricate complexity, rendered in unreal engine 5, photorealistic, 8k"],
    ["close-up photo of a beautiful red rose breaking through a cube made of ice, splintered cracked ice surface, frosted colors, blood dripping from rose, melting ice, Valentine’s Day vibes, cinematic, sharp focus, intricate, cinematic, dramatic light"],
    ["3D cartoon Fox Head with Human Body, Wearing Iridescent Holographic Liquid Texture & Translucent Material Sun Protective Shirt, Boss Feel, Nike or Addidas Sun Protective Shirt, WitchPunk, Y2K Style, Green and blue, Blue, Metallic Feel, Strong Reflection, plain background, no background, pure single color background, Digital Fashion, Surreal Futurism, Supreme Kong NFT Artwork Style, disney style, headshot photography for portrait studio shoot, fashion editorial aesthetic, high resolution in the style of HAPE PRIME NFT, NFT 3D IP Feel, Bored Ape Yacht Club NFT project Feel, high detail, fine luster, 3D render, oc render, best quality, 8K, bright, front lighting, Face Shot, fine luster, ultra detailed"],
]


class ModelFailure:
    pass

# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
    prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True
):
    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        text_inputs = tokenizer(
            captions,
            padding=True,
            pad_to_multiple_of=8,
            max_length=256,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        prompt_masks = text_inputs.attention_mask

        prompt_embeds = text_encoder(
            input_ids=text_input_ids.cuda(),
            attention_mask=prompt_masks.cuda(),
            output_hidden_states=True,
        ).hidden_states[-2]

    return prompt_embeds, prompt_masks

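Note that this helper returns the text encoder's penultimate hidden states plus the attention mask; later in this file it is always called with each real caption paired with an empty string, so one batch serves both halves of classifier-free guidance. A minimal usage sketch (assumes `text_encoder` and `tokenizer` built as in `model_main` below):

cap_feats, cap_mask = encode_prompt(
    ["Astronaut on Mars During sunset"] + [""], text_encoder, tokenizer, 0.0
)
# cap_feats: (2, seq_len, hidden_size); row 0 = conditional, row 1 = unconditional
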
@torch.no_grad()
def model_main(args, master_port, rank, request_queue, response_queue, mp_barrier):
    # import here to avoid huggingface Tokenizer parallelism warnings
    from diffusers.models import AutoencoderKL
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # override the default print function since the delay can be large for child process
    original_print = builtins.print

    # Redefine the print function with flush=True by default
    def print(*args, **kwargs):
        kwargs.setdefault("flush", True)
        original_print(*args, **kwargs)

    # Override the built-in print with the new version
    builtins.print = print

    os.environ["MASTER_PORT"] = str(master_port)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(args.num_gpus)

    dist.init_process_group("nccl")
    # set up fairscale environment because some methods of the Lumina model need it,
    # though for single-GPU inference fairscale actually has no effect
    fs_init.initialize_model_parallel(args.num_gpus)
    torch.cuda.set_device(rank)

    train_args = torch.load(os.path.join(args.ckpt, "model_args.pth"))
    if dist.get_rank() == 0:
        print("Loaded model arguments:", json.dumps(train_args.__dict__, indent=2))

    if dist.get_rank() == 0:
        print("Creating lm: Gemma-2B")

    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[
        args.precision
    ]

    text_encoder = (
        AutoModelForCausalLM.from_pretrained(
            "google/gemma-2b", torch_dtype=dtype, device_map="cuda"
        )
        .get_decoder()
        .eval()
    )
    cap_feat_dim = text_encoder.config.hidden_size
    if args.num_gpus > 1:
        raise NotImplementedError("Inference with >1 GPUs not yet supported")

    tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2b", add_bos_token=True, add_eos_token=True
    )
    tokenizer.padding_side = "right"

    if dist.get_rank() == 0:
        print("Creating vae: sdxl-vae")
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sdxl-vae",
        torch_dtype=torch.float32,
    ).cuda()

    if dist.get_rank() == 0:
        print("Creating DiT: Next-DiT")
    # latent_size = train_args.image_size // 8
    model = models.__dict__["DiT_Llama_2B_patch2"](
        qk_norm=train_args.qk_norm,
        cap_feat_dim=cap_feat_dim,
    )
    model.eval().to("cuda", dtype=dtype)

    assert train_args.model_parallel_size == args.num_gpus
    if args.ema:
        print("Loading ema model.")
    ckpt = torch.load(
        os.path.join(
            args.ckpt,
            f"consolidated{'_ema' if args.ema else ''}.{rank:02d}-of-{args.num_gpus:02d}.pth",
        ),
        map_location="cpu",
    )
    model.load_state_dict(ckpt, strict=True)

    mp_barrier.wait()

    with torch.autocast("cuda", dtype):
        while True:
            (
                cap,
                resolution,
                num_sampling_steps,
                cfg_scale,
                solver,
                t_shift,
                seed,
                ntk_scaling,
                proportional_attn,
            ) = request_queue.get()

            print(
                "> params:",
                cap,
                resolution,
                num_sampling_steps,
                cfg_scale,
                solver,
                t_shift,
                seed,
                ntk_scaling,
                proportional_attn,
            )
            try:
                # begin sampler
                transport = create_transport(
                    args.path_type,
                    args.prediction,
                    args.loss_weight,
                    args.train_eps,
                    args.sample_eps,
                )
                sampler = Sampler(transport)
                if args.sampler_mode == "ODE":
                    if args.likelihood:
                        # assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"  # todo
                        sample_fn = sampler.sample_ode_likelihood(
                            sampling_method=solver,
                            num_steps=num_sampling_steps,
                            atol=args.atol,
                            rtol=args.rtol,
                        )
                    else:
                        sample_fn = sampler.sample_ode(
                            sampling_method=solver,
                            num_steps=num_sampling_steps,
                            atol=args.atol,
                            rtol=args.rtol,
                            reverse=args.reverse,
                            time_shifting_factor=t_shift,
                        )
                elif args.sampler_mode == "SDE":
                    sample_fn = sampler.sample_sde(
                        sampling_method=solver,
                        diffusion_form=args.diffusion_form,
                        diffusion_norm=args.diffusion_norm,
                        last_step=args.last_step,
                        last_step_size=args.last_step_size,
                        num_steps=num_sampling_steps,
                    )
                # end sampler

                resolution = resolution.split(" ")[-1]
                w, h = resolution.split("x")
                w, h = int(w), int(h)
                latent_w, latent_h = w // 8, h // 8
                if int(seed) != 0:
                    torch.random.manual_seed(int(seed))
                z = torch.randn([1, 4, latent_h, latent_w], device="cuda").to(dtype)
                z = z.repeat(2, 1, 1, 1)

                with torch.no_grad():
                    cap_feats, cap_mask = encode_prompt(
                        [cap] + [""], text_encoder, tokenizer, 0.0
                    )
                cap_mask = cap_mask.to(cap_feats.device)

                train_res = 1024
                res_cat = (w * h) ** 0.5
                print(f"res_cat: {res_cat}")
                max_seq_len = (res_cat // 16) ** 2 + (res_cat // 16) * 2
                print(f"max_seq_len: {max_seq_len}")

                rope_scaling_factor = 1.0
                ntk_factor = max_seq_len / (train_res // 16) ** 2
                print(f"ntk_factor: {ntk_factor}")

                model_kwargs = dict(
                    cap_feats=cap_feats,
                    cap_mask=cap_mask,
                    cfg_scale=cfg_scale,
                    rope_scaling_factor=rope_scaling_factor,
                    ntk_factor=ntk_factor,
                )

                if dist.get_rank() == 0:
                    print(f"caption: {cap}")
                    print(f"num_sampling_steps: {num_sampling_steps}")
                    print(f"cfg_scale: {cfg_scale}")

                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    print("> [debug] start sample")
                    samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
                samples = samples[:1]

                factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
                print(f"vae factor: {factor}")
                samples = vae.decode(samples / factor).sample
                samples = (samples + 1.0) / 2.0
                samples.clamp_(0.0, 1.0)
                img = to_pil_image(samples[0].float())

                if response_queue is not None:
                    response_queue.put(img)

            except Exception:
                print(traceback.format_exc())
                response_queue.put(ModelFailure())

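The sequence-length arithmetic above drives the NTK-aware RoPE scaling: the token count grows with resolution, and `ntk_factor` is its ratio to the training-resolution token budget. A quick check of the numbers (a sketch; 16 is the VAE's 8x downsample times the patch size of 2):

for side in (1024, 2048):
    tokens = (side // 16) ** 2 + (side // 16) * 2
    print(side, tokens, tokens / (1024 // 16) ** 2)
# 1024 -> 4224 tokens, ntk_factor ~= 1.031; 2048 -> 16640 tokens, ntk_factor ~= 4.063
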
def none_or_str(value):
    if value == "None":
        return None
    return value

def parse_transport_args(parser):
    group = parser.add_argument_group("Transport arguments")
    group.add_argument(
        "--path-type",
        type=str,
        default="Linear",
        choices=["Linear", "GVP", "VP"],
        help="the type of path for transport: 'Linear', 'GVP' (generalized variance-preserving), or 'VP' (variance-preserving).",
    )
    group.add_argument(
        "--prediction",
        type=str,
        default="velocity",
        choices=["velocity", "score", "noise"],
        help="the prediction model for the transport dynamics.",
    )
    group.add_argument(
        "--loss-weight",
        type=none_or_str,
        default=None,
        choices=[None, "velocity", "likelihood"],
        help="the weighting of different components in the loss function: 'velocity' for dynamic modeling, 'likelihood' for statistical consistency, or None for no weighting.",
    )
    group.add_argument(
        "--sample-eps", type=float, help="epsilon cutoff for sampling in the transport model."
    )
    group.add_argument(
        "--train-eps", type=float, help="epsilon cutoff for training, to stabilize the learning process."
    )

def parse_ode_args(parser):
    group = parser.add_argument_group("ODE arguments")
    group.add_argument(
        "--atol",
        type=float,
        default=1e-6,
        help="Absolute tolerance for the ODE solver.",
    )
    group.add_argument(
        "--rtol",
        type=float,
        default=1e-3,
        help="Relative tolerance for the ODE solver.",
    )
    group.add_argument(
        "--reverse", action="store_true", help="run the ODE solver in reverse."
    )
    group.add_argument(
        "--likelihood",
        action="store_true",
        help="Enable calculation of likelihood during the ODE solving process.",
    )

def parse_sde_args(parser):
    group = parser.add_argument_group("SDE arguments")
    group.add_argument(
        "--sampling-method",
        type=str,
        default="Euler",
        choices=["Euler", "Heun"],
        help="the numerical method used for sampling the stochastic differential equation: 'Euler' for simplicity or 'Heun' for improved accuracy.",
    )
    group.add_argument(
        "--diffusion-form",
        type=str,
        default="sigma",
        choices=[
            "constant",
            "SBDM",
            "sigma",
            "linear",
            "decreasing",
            "increasing-decreasing",
        ],
        help="form of diffusion coefficient in the SDE",
    )
    group.add_argument(
        "--diffusion-norm",
        type=float,
        default=1.0,
        help="Normalizes the diffusion coefficient, affecting the scale of the stochastic component.",
    )
    group.add_argument(
        "--last-step",
        type=none_or_str,
        default="Mean",
        choices=[None, "Mean", "Tweedie", "Euler"],
        help="form of last step taken in the SDE",
    )
    group.add_argument(
        "--last-step-size", type=float, default=0.04, help="size of the last step taken"
    )

def find_free_port() -> int:
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port

def main():
    parser = argparse.ArgumentParser()
    mode = "ODE"

    parser.add_argument("--num_gpus", type=int, default=1)
    parser.add_argument("--ckpt", type=str, default="./checkpoints")
    parser.add_argument("--ema", type=bool, default=True)
    parser.add_argument("--precision", default="bf16", choices=["bf16", "fp32"])

    parse_transport_args(parser)
    if mode == "ODE":
        parse_ode_args(parser)
        # Further processing for ODE
    elif mode == "SDE":
        parse_sde_args(parser)
        # Further processing for SDE

    args = parser.parse_known_args()[0]

    if args.num_gpus != 1:
        raise NotImplementedError("Multi-GPU Inference is not yet supported")

    args.sampler_mode = mode

    master_port = find_free_port()

    processes = []
    request_queues = []
    response_queue = mp.Queue()
    mp_barrier = mp.Barrier(args.num_gpus + 1)
    for i in range(args.num_gpus):
        request_queues.append(mp.Queue())
        p = mp.Process(
            target=model_main,
            args=(
                args,
                master_port,
                i,
                request_queues[i],
                response_queue if i == 0 else None,
                mp_barrier,
            ),
        )
        p.start()
        processes.append(p)

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                cap = gr.Textbox(
                    lines=2,
                    label="Caption",
                    interactive=True,
                    value="Miss Mexico portrait of the most beautiful mexican woman, Exquisite detail, 30-megapixel, 4k, 85-mm-lens, sharp-focus, f:8, "
                    "ISO 100, shutter-speed 1:125, diffuse-back-lighting, award-winning photograph, small-catchlight, High-sharpness, facial-symmetry, 8k --q 2 --ar 18:32 --v 5",
                )
                with gr.Row():
                    res_choices = ["1024x1024", "512x2048", "2048x512"] + [
                        "(Extrapolation) 1664x1664",
                        "(Extrapolation) 1024x2048",
                        "(Extrapolation) 2048x1024",
                    ]
                    resolution = gr.Dropdown(
                        value=res_choices[0], choices=res_choices, label="Resolution"
                    )
                with gr.Row():
                    num_sampling_steps = gr.Slider(
                        minimum=1,
                        maximum=70,
                        value=30,
                        interactive=True,
                        label="Sampling steps",
                    )
                    seed = gr.Slider(
                        minimum=0,
                        maximum=int(1e5),
                        value=1,
                        step=1,
                        interactive=True,
                        label="Seed (0 for random)",
                    )
                with gr.Accordion(
                    "Advanced Settings for Resolution Extrapolation", open=False
                ):
                    with gr.Row():
                        solver = gr.Dropdown(
                            value="euler",
                            choices=["euler", "dopri5", "dopri8"],
                            label="solver",
                        )
                        t_shift = gr.Slider(
                            minimum=1,
                            maximum=20,
                            value=6,
                            step=1,
                            interactive=True,
                            label="Time shift",
                        )
                        cfg_scale = gr.Slider(
                            minimum=1.0,
                            maximum=20.0,
                            value=4.0,
                            interactive=True,
                            label="CFG scale",
                        )
                    with gr.Row():
                        ntk_scaling = gr.Checkbox(
                            value=True,
                            interactive=True,
                            label="ntk scaling",
                        )
                        proportional_attn = gr.Checkbox(
                            value=True,
                            interactive=True,
                            label="Proportional attention",
                        )
                with gr.Row():
                    submit_btn = gr.Button("Submit", variant="primary")
                    # reset_btn = gr.ClearButton([
                    #     cap, resolution,
                    #     num_sampling_steps, cfg_scale, solver,
                    #     t_shift, seed,
                    #     ntk_scaling, proportional_attn
                    # ])
            with gr.Column():
                default_img = Image.open("./image.png")
                output_img = gr.Image(
                    label="Generated image",
                    interactive=False,
                    format="png",
                    value=default_img,
                )

        with gr.Row():
            gr.Examples(
                examples,
                [cap],
                label="Examples",
            )

        def on_submit(*args):
            for q in request_queues:
                q.put(args)
            result = response_queue.get()
            if isinstance(result, ModelFailure):
                raise RuntimeError
            return result

        submit_btn.click(
            on_submit,
            [
                cap,
                resolution,
                num_sampling_steps,
                cfg_scale,
                solver,
                t_shift,
                seed,
                ntk_scaling,
                proportional_attn,
            ],
            [output_img],
        )

    mp_barrier.wait()
    demo.queue().launch(share=True, server_name="0.0.0.0")


if __name__ == "__main__":
    os.system("mkdir -p ./checkpoints")
    os.system("huggingface-cli download --resume-download Alpha-VLLM/Lumina-Next-T2I --local-dir ./checkpoints --local-dir-use-symlinks False")
    mp.set_start_method("spawn")
    main()
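The Gradio callback and the worker process communicate through a fixed nine-element tuple: the order of `submit_btn.click` inputs must match the unpacking at the top of `model_main`'s loop exactly, since the tuple is passed positionally over the queue. A sketch of driving a worker directly, bypassing the UI (hypothetical values):

request_queues[0].put(
    ("Astronaut on Mars During sunset", "1024x1024", 30, 4.0, "euler", 6, 1, True, True)
)
result = response_queue.get()  # a PIL.Image on success, a ModelFailure instance on error
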
models/__init__.py ADDED @@ -0,0 +1,2 @@
# from .model import DiT_Llama_5B_patch2
from .model import DiT_Llama_2B_patch2
models/components.py ADDED @@ -0,0 +1,54 @@
import warnings

import torch
import torch.nn as nn

try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    warnings.warn("Cannot import apex RMSNorm, switching to vanilla implementation")

    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            """
            Initialize the RMSNorm normalization layer.

            Args:
                dim (int): The dimension of the input tensor.
                eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

            Attributes:
                eps (float): A small value added to the denominator for numerical stability.
                weight (nn.Parameter): Learnable scaling parameter.

            """
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            """
            Apply the RMSNorm normalization to the input tensor.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The normalized tensor.

            """
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            """
            Forward pass through the RMSNorm layer.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The output tensor after applying RMSNorm.

            """
            output = self._norm(x.float()).type_as(x)
            return output * self.weight
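As a quick numeric check of the fallback formula (the same math apex's FusedRMSNorm fuses): each vector is divided by the root mean square over its last dimension, then scaled by `weight`, which starts at ones. A minimal sketch, run from the repo root:

import torch
from models.components import RMSNorm

x = torch.tensor([[3.0, 4.0]])
norm = RMSNorm(2)
# rms = sqrt((3**2 + 4**2) / 2) = sqrt(12.5) ~= 3.5355
print(norm(x))  # ~ [[0.8485, 1.1314]]
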
models/model.py ADDED @@ -0,0 +1,908 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import functools
import logging
import math
from typing import Optional, Tuple, List

# from apex.normalization import FusedRMSNorm as RMSNorm
from .components import RMSNorm
import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
)
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#############################################################################
#             Embedding Layers for Timesteps and Class Labels               #
#############################################################################

class ParallelTimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            ColumnParallelLinear(
                frequency_embedding_size, hidden_size, bias=True,
                gather_output=False,
                init_method=functools.partial(nn.init.normal_, std=0.02),
            ),
            nn.SiLU(),
            RowParallelLinear(
                hidden_size, hidden_size, bias=True, input_is_parallel=True,
                init_method=functools.partial(nn.init.normal_, std=0.02),
            ),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(
                start=0, end=half, dtype=torch.float32
            ) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([
                embedding, torch.zeros_like(embedding[:, :1])
            ], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
        return t_emb

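A shape sketch of the sinusoidal embedding above: `dim` channels split into a cosine half and a sine half, with frequencies spaced geometrically from 1 down to 1/max_period (`timestep_embedding` is a static method, so no module state is needed):

import torch
from models.model import ParallelTimestepEmbedder  # pulls in this file's heavy deps

t = torch.tensor([0.0, 500.0, 999.0])
emb = ParallelTimestepEmbedder.timestep_embedding(t, 256)
print(emb.shape)   # torch.Size([3, 256])
print(emb[0, :4])  # tensor([1., 1., 1., 1.]): the cosine half is all ones at t = 0
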
class ParallelLabelEmbedder(nn.Module):
    r"""Embeds class labels into vector representations. Also handles label
    dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = int(dropout_prob > 0)
        self.embedding_table = ParallelEmbedding(
            num_classes + use_cfg_embedding, hidden_size,
            init_method=functools.partial(nn.init.normal_, std=0.02),
        )
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(
                labels.shape[0], device=labels.device
            ) < self.dropout_prob
            drop_ids = drop_ids.cuda()
            dist.broadcast(
                drop_ids,
                fs_init.get_model_parallel_src_rank(),
                fs_init.get_model_parallel_group(),
            )
            drop_ids = drop_ids.to(labels.device)
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings

#############################################################################
#                               Core DiT Model                              #
#############################################################################


class Attention(nn.Module):
    """Multi-head attention module."""
    def __init__(self, dim: int, n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, y_dim: int):
        """
        Initialize the Attention module.

        Args:
            dim (int): Number of input dimensions.
            n_heads (int): Number of heads.
            n_kv_heads (Optional[int]): Number of kv heads, if using GQA.

        """
        super().__init__()
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads

        self.wq = ColumnParallelLinear(
            dim, n_heads * self.head_dim, bias=False, gather_output=False,
            init_method=nn.init.xavier_uniform_,
        )
        self.wk = ColumnParallelLinear(
            dim, self.n_kv_heads * self.head_dim, bias=False,
            gather_output=False, init_method=nn.init.xavier_uniform_,
        )
        self.wv = ColumnParallelLinear(
            dim, self.n_kv_heads * self.head_dim, bias=False,
            gather_output=False, init_method=nn.init.xavier_uniform_,
        )
        if y_dim > 0:
            self.wk_y = ColumnParallelLinear(
                y_dim, self.n_kv_heads * self.head_dim, bias=False,
                gather_output=False, init_method=nn.init.xavier_uniform_,
            )
            self.wv_y = ColumnParallelLinear(
                y_dim, self.n_kv_heads * self.head_dim, bias=False,
                gather_output=False, init_method=nn.init.xavier_uniform_,
            )
            self.gate = nn.Parameter(torch.zeros([self.n_local_heads]))

        self.wo = RowParallelLinear(
            n_heads * self.head_dim, dim, bias=False,
            input_is_parallel=True, init_method=nn.init.xavier_uniform_,
        )

        if qk_norm:
            self.q_norm = nn.LayerNorm(self.n_local_heads * self.head_dim)
            self.k_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
            if y_dim > 0:
                self.ky_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
            else:
                self.ky_norm = nn.Identity()
        else:
            self.q_norm = self.k_norm = nn.Identity()
            self.ky_norm = nn.Identity()

        # for proportional attention computation
        self.base_seqlen = None
        self.proportional_attn = False

    @staticmethod
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
        """
        Reshape frequency tensor for broadcasting it with another tensor.

        This function reshapes the frequency tensor to have the same shape as
        the target tensor 'x' for the purpose of broadcasting the frequency
        tensor during element-wise operations.

        Args:
            freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
            x (torch.Tensor): Target tensor for broadcasting compatibility.

        Returns:
            torch.Tensor: Reshaped frequency tensor.

        Raises:
            AssertionError: If the frequency tensor doesn't match the expected
                shape.
            AssertionError: If the target tensor 'x' doesn't have the expected
                number of dimensions.
        """
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1
                 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    @staticmethod
    def apply_rotary_emb(
        x_in: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply rotary embeddings to an input tensor using the given frequency
        tensor.

        This function applies rotary embeddings to the given query or key
        tensor 'x_in' using the provided frequency tensor 'freqs_cis'. The
        input tensor is reshaped as complex numbers, and the frequency tensor
        is reshaped for broadcasting compatibility. The resulting tensor
        contains rotary embeddings and is returned as a real tensor.

        Args:
            x_in (torch.Tensor): Query or key tensor to apply rotary embeddings.
            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
                exponentials.

        Returns:
            torch.Tensor: The input tensor with rotary embeddings applied.
        """
        with torch.cuda.amp.autocast(enabled=False):
            x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
            freqs_cis = freqs_cis.unsqueeze(2)
            x_out = torch.view_as_real(x * freqs_cis).flatten(3)
            return x_out.type_as(x_in)

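`apply_rotary_emb` treats consecutive channel pairs of each head as complex numbers and multiplies them by unit-magnitude complex exponentials, i.e. per-position rotations that preserve pair norms. A minimal sketch; the `freqs_cis` built here is a hypothetical stand-in for the model's `precompute_freqs_cis` output:

import torch
from models.model import Attention  # pulls in this file's heavy deps

B, L, H, D = 1, 4, 2, 8
x = torch.randn(B, L, H, D)
angles = torch.outer(torch.arange(L, dtype=torch.float32), torch.ones(D // 2))
freqs_cis = torch.polar(torch.ones(1, L, D // 2), angles.unsqueeze(0))  # unit complex
out = Attention.apply_rotary_emb(x, freqs_cis=freqs_cis)
print(out.shape)  # torch.Size([1, 4, 2, 8]); channel-pair norms are unchanged
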
    # copied from huggingface modeling_llama.py
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):

        def _get_unpad_data(attention_mask):
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return (
                indices,
                cu_seqlens,
                max_seqlen_in_batch,
            )

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        y: torch.Tensor,
        y_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: Image token sequence of shape (bsz, seqlen, dim).
            x_mask: Validity mask for x, shape (bsz, seqlen).
            freqs_cis: Precomputed rotary frequencies for the sequence.
            y: Caption feature sequence used for cross-attention.
            y_mask: Validity mask for y.

        Returns:
            The attended token sequence, shape (bsz, seqlen, dim).
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        dtype = xq.dtype

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq = Attention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = Attention.apply_rotary_emb(xk, freqs_cis=freqs_cis)

        xq, xk = xq.to(dtype), xk.to(dtype)

        if dtype in [torch.float16, torch.bfloat16]:
            # begin var_len flash attn
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                xq, xk, xv, x_mask, seqlen
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            if self.proportional_attn:
                softmax_scale = math.sqrt(math.log(seqlen, self.base_seqlen) / self.head_dim)
            else:
                softmax_scale = math.sqrt(1 / self.head_dim)

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=0.,
                causal=False,
                softmax_scale=softmax_scale
            )
            output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
            # end var_len_flash_attn

        else:
            output = F.scaled_dot_product_attention(
                xq.permute(0, 2, 1, 3),
                xk.permute(0, 2, 1, 3),
                xv.permute(0, 2, 1, 3),
                attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
            ).permute(0, 2, 1, 3).to(dtype)

        if hasattr(self, "wk_y"):
            # todo better flash_attn support
            yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
            yv = self.wv_y(y).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
            n_rep = self.n_local_heads // self.n_local_kv_heads
            if n_rep >= 1:
                yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            output_y = F.scaled_dot_product_attention(
                xq.permute(0, 2, 1, 3),
                yk.permute(0, 2, 1, 3),
                yv.permute(0, 2, 1, 3),
                y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seqlen, -1)
            ).permute(0, 2, 1, 3)
            output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
            output = output + output_y

        output = output.flatten(-2)

        return self.wo(output)

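On the `proportional_attn` branch: the default scale is 1/sqrt(head_dim); when sampling at sequences longer than the training length `base_seqlen`, the scale is multiplied by sqrt(log base `base_seqlen` of `seqlen`), which sharpens the softmax to roughly offset the entropy growth of attending over more tokens. Evaluating the two branches (head_dim here is a hypothetical value; 4224 tokens corresponds to 1024x1024, see app.py):

import math
head_dim, base_seqlen = 64, 4224
for seqlen in (4224, 16640):            # training length vs. 2048x2048 extrapolation
    scale = math.sqrt(math.log(seqlen, base_seqlen) / head_dim)
    print(seqlen, round(scale, 4))      # 4224 -> 0.125 (= 1/sqrt(64)); 16640 -> ~0.1349
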
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple
                of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden
                dimension. Defaults to None.

        Attributes:
            w1 (ColumnParallelLinear): Linear transformation for the first
                layer.
            w2 (RowParallelLinear): Linear transformation for the second layer.
            w3 (ColumnParallelLinear): Linear transformation for the third
                layer.

        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * (
            (hidden_dim + multiple_of - 1) // multiple_of
        )

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False,
            init_method=nn.init.xavier_uniform_,
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True,
            init_method=nn.init.xavier_uniform_,
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False,
            init_method=nn.init.xavier_uniform_,
        )

    # @torch.compile
    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

    def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))

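This is the Llama-style SwiGLU block: the 2/3 factor keeps the parameter count comparable to a plain 4*dim MLP despite the third projection, and the final line rounds the width up to a hardware-friendly multiple. With the class defaults above (dim=4096, multiple_of=256, and hidden_dim=4*dim as passed by `TransformerBlock` below):

hidden_dim = int(2 * (4 * 4096) / 3)            # 10922
hidden_dim = 256 * ((hidden_dim + 255) // 256)  # rounds up to the next multiple of 256
print(hidden_dim)                               # 11008
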
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int,
                 multiple_of: int, ffn_dim_multiplier: float, norm_eps: float,
                 qk_norm: bool, y_dim: int) -> None:
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            dim (int): Embedding dimension of the input features.
            n_heads (int): Number of attention heads.
            n_kv_heads (Optional[int]): Number of attention heads in key and
                value features (if using GQA), or set to None for the same as
                query.
            multiple_of (int): Rounding unit for the feedforward hidden width.
            ffn_dim_multiplier (float): Custom multiplier for the feedforward
                hidden width.
            norm_eps (float): Epsilon for the RMSNorm layers.

        Attributes:
            n_heads (int): Number of attention heads.
            dim (int): Dimension size of the model.
            head_dim (int): Dimension size of each attention head.
            attention (Attention): Attention module.
            feed_forward (FeedForward): FeedForward module.
            layer_id (int): Identifier for the layer.
            attention_norm, attention_norm1 (RMSNorm): Pre- and post-normalization
                around the attention branch (sandwich norm).
            ffn_norm, ffn_norm1 (RMSNorm): Pre- and post-normalization around the
                feedforward branch.

        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim)
        self.feed_forward = FeedForward(
            dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            ColumnParallelLinear(
                min(dim, 1024), 6 * dim, bias=True, gather_output=True,
                init_method=nn.init.zeros_,
            ),
        )

        self.attention_y_norm = RMSNorm(y_dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        y: torch.Tensor,
        y_mask: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.

        Returns:
            torch.Tensor: Output tensor after applying attention and
                feedforward layers.

        """
        if adaln_input is not None:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
                self.adaLN_modulation(adaln_input).chunk(6, dim=1)

            x = x + self.attention_norm1(gate_msa.unsqueeze(1) * self.attention(
                modulate(self.attention_norm(x), shift_msa, scale_msa),
                x_mask,
                freqs_cis,
                self.attention_y_norm(y),
                y_mask,
            ))
            d = x.shape[-1]
            x = x + self.ffn_norm1(gate_mlp.unsqueeze(1) * self.feed_forward(
                modulate(self.ffn_norm(x), shift_mlp, scale_mlp).view(-1, d),
            ).view(*x.shape))

        else:
            x = x + self.attention_norm1(self.attention(
                self.attention_norm(x), x_mask, freqs_cis, self.attention_y_norm(y), y_mask
            ))
            # for compatibility with torch.compile because the sequence length changes
            B, L, D = x.shape
            x = x.view(B * L, D)
            x = x + self.ffn_norm1(self.feed_forward(self.ffn_norm(x)))
            x = x.view(B, L, D)

        return x

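Note the zero-initialized adaLN projection above: at initialization `gate_msa` and `gate_mlp` are zero, so every block starts out as the identity function and the modulation pathway is learned from scratch, following the adaLN-Zero scheme of the DiT paper.
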
class ParallelFinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6,
        )
        self.linear = ColumnParallelLinear(
            hidden_size, patch_size * patch_size * out_channels, bias=True,
            init_method=nn.init.zeros_, gather_output=True,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            ColumnParallelLinear(
                min(hidden_size, 1024), 2 * hidden_size, bias=True,
                init_method=nn.init.zeros_, gather_output=True,
            ),
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

class DiT_Llama(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        dim: int = 4096,
        n_layers: int = 32,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        learn_sigma: bool = True,
        qk_norm: bool = False,
        cap_feat_dim: int = 5120,
        rope_scaling_factor: float = 1.,
        ntk_factor: float = 1.
    ) -> None:
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size

        self.x_embedder = ColumnParallelLinear(
            in_features=patch_size * patch_size * in_channels,
            out_features=dim,
            bias=True,
            gather_output=True,
            init_method=nn.init.xavier_uniform_,
        )
        nn.init.constant_(self.x_embedder.bias, 0.)

        self.t_embedder = ParallelTimestepEmbedder(min(dim, 1024))
        self.cap_embedder = nn.Sequential(
            nn.LayerNorm(cap_feat_dim),
            ColumnParallelLinear(cap_feat_dim, min(dim, 1024), bias=True, gather_output=True,
                                 init_method=nn.init.zeros_),
        )

        self.layers = nn.ModuleList([
            TransformerBlock(layer_id, dim, n_heads, n_kv_heads, multiple_of,
                             ffn_dim_multiplier, norm_eps, qk_norm, cap_feat_dim)
            for layer_id in range(n_layers)
        ])
        self.final_layer = ParallelFinalLayer(dim, patch_size, self.out_channels)

        assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
        self.dim = dim
        self.n_heads = n_heads
        self.freqs_cis = DiT_Llama.precompute_freqs_cis(
            dim // n_heads, 384, rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
        )
        self.rope_scaling_factor = rope_scaling_factor
        self.ntk_factor = ntk_factor
        # self.eol_token = nn.Parameter(torch.empty(dim))
        self.pad_token = nn.Parameter(torch.empty(dim))
        # nn.init.normal_(self.eol_token, std=0.02)
        nn.init.normal_(self.pad_token, std=0.02)

    def unpatchify(self, x: torch.Tensor, img_size: List[Tuple[int, int]], return_tensor=False) -> List[torch.Tensor]:
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        pH = pW = self.patch_size
        if return_tensor:
            H, W = img_size[0]
            B = x.size(0)
            L = (H // pH) * (W // pW)
            x = x[:, :L].view(B, H // pH, W // pW, pH, pW, self.out_channels)
            x = x.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
            return x
        else:
            imgs = []
            for i in range(x.size(0)):
                H, W = img_size[i]
                L = (H // pH) * (W // pW)
                imgs.append(x[i][:L].view(
                    H // pH, W // pW, pH, pW, self.out_channels
                ).permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2))
            return imgs

    def patchify_and_embed(
        self,
        x: List[torch.Tensor] | torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]:
        self.freqs_cis = self.freqs_cis.to(x[0].device)
        if isinstance(x, torch.Tensor):
            pH = pW = self.patch_size
            B, C, H, W = x.size()
            x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3)
            x = self.x_embedder(x)
            x = x.flatten(1, 2)

            mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
            # leave the first line for text
            return x, mask, [(H, W)] * B, self.freqs_cis[:H//pH, :W//pW].flatten(0, 1).unsqueeze(0)
        else:
            pH = pW = self.patch_size
            x_embed = []
            freqs_cis = []
            img_size = []
            l_effective_seq_len = []

            for img in x:
                C, H, W = img.size()
                item_freqs_cis = self.freqs_cis[:H//pH, :W//pW]
                freqs_cis.append(item_freqs_cis.flatten(0, 1))
                img_size.append((H, W))
                img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 0, 2, 4).flatten(2)
                img = self.x_embedder(img)
                img = img.flatten(0, 1)
                l_effective_seq_len.append(len(img))
                x_embed.append(img)

            max_seq_len = max(l_effective_seq_len)
            mask = torch.zeros(len(x), max_seq_len, dtype=torch.int32, device=x[0].device)
            padded_x_embed = []
            padded_freqs_cis = []
            for i, (item_embed, item_freqs_cis, item_seq_len) in enumerate(zip(
                x_embed, freqs_cis, l_effective_seq_len
            )):
                item_embed = torch.cat([
                    item_embed,
                    self.pad_token.view(1, -1).expand(max_seq_len - item_seq_len, -1),
                ], dim=0)
                item_freqs_cis = torch.cat([
                    item_freqs_cis,
                    item_freqs_cis[-1:].expand(max_seq_len - item_seq_len, -1)
                ], dim=0)
                padded_x_embed.append(item_embed)
                padded_freqs_cis.append(item_freqs_cis)
                mask[i][:item_seq_len] = 1

            x_embed = torch.stack(padded_x_embed, dim=0)
            freqs_cis = torch.stack(padded_freqs_cis, dim=0)
            return x_embed, mask, img_size, freqs_cis

737 |
+
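patchify_and_embed folds every pH x pW latent patch into one token before the x_embedder linear layer, and unpatchify inverts the bookkeeping after the final layer. A shape-only round trip of the same reshapes, without the embedding in between (a sketch with made-up sizes, not part of the file):

import torch

B, C, H, W, p = 1, 4, 8, 8, 2
x = torch.randn(B, C, H, W)
# patchify: (B, C, H, W) -> (B, L, C * p * p) with L = (H/p) * (W/p)
tokens = x.view(B, C, H // p, p, W // p, p).permute(0, 2, 4, 1, 3, 5).flatten(3).flatten(1, 2)
assert tokens.shape == (B, (H // p) * (W // p), C * p * p)
# invert the same layout back to (B, C, H, W)
y = tokens.view(B, H // p, W // p, C, p, p).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)
assert torch.equal(y, x)

Note that the model's own unpatchify assumes the (pH, pW, C) channel order produced by the final linear layer rather than the (C, pH, pW) order above; the sketch only illustrates the index bookkeeping.
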
    def forward(self, x, t, cap_feats, cap_mask):
        """
        Forward pass of DiT.
        t: (N,) tensor of diffusion timesteps
        cap_feats: (N, L, cap_feat_dim) tensor of caption features
        cap_mask: (N, L) tensor marking the valid caption tokens
        """
        x_is_tensor = isinstance(x, torch.Tensor)
        x, mask, img_size, freqs_cis = self.patchify_and_embed(x)
        freqs_cis = freqs_cis.to(x.device)

        # cap_freqs_cis = self.freqs_cis[:1, :cap_feats.shape[1]].to(x.device)

        t = self.t_embedder(t)  # (N, D)
        cap_mask_float = cap_mask.float().unsqueeze(-1)
        cap_feats_pool = (cap_feats * cap_mask_float).sum(dim=1) / cap_mask_float.sum(dim=1)
        cap_feats_pool = cap_feats_pool.to(cap_feats)
        cap_emb = self.cap_embedder(cap_feats_pool)
        adaln_input = t + cap_emb

        cap_mask = cap_mask.bool()
        for layer in self.layers:
            x = layer(
                x, mask, freqs_cis, cap_feats, cap_mask,
                adaln_input=adaln_input
            )

        x = self.final_layer(x, adaln_input)
        x = self.unpatchify(x, img_size, return_tensor=x_is_tensor)
        if self.learn_sigma:
            if x_is_tensor:
                x, _ = x.chunk(2, dim=1)
            else:
                x = [_.chunk(2, dim=0)[0] for _ in x]
        return x

    def forward_with_cfg(self, x, t, cap_feats, cap_mask, cfg_scale, rope_scaling_factor=None, ntk_factor=None, base_seqlen: Optional[int] = None, proportional_attn: bool = False):
        """
        Forward pass of DiT, but also batches the unconditional forward pass
        for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        if rope_scaling_factor is not None or ntk_factor is not None:
            rope_scaling_factor = rope_scaling_factor if rope_scaling_factor is not None else self.rope_scaling_factor
            ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
            if rope_scaling_factor != self.rope_scaling_factor or ntk_factor != self.ntk_factor:
                print(f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}", flush=True)
                self.freqs_cis = DiT_Llama.precompute_freqs_cis(
                    self.dim // self.n_heads, 384,
                    rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
                )
                self.rope_scaling_factor = rope_scaling_factor
                self.ntk_factor = ntk_factor

        if proportional_attn:
            assert base_seqlen is not None
            for layer in self.layers:
                layer.attention.base_seqlen = base_seqlen
                layer.attention.proportional_attn = proportional_attn
        else:
            for layer in self.layers:
                layer.attention.base_seqlen = None
                layer.attention.proportional_attn = proportional_attn

        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, cap_feats, cap_mask)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

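forward_with_cfg expects the batch to hold each latent twice, with the real captions in the first half of cap_feats and the null caption in the second, and then blends the two predictions channel-wise. The guidance arithmetic in isolation (a sketch with made-up numbers, not part of the file):

import torch

cfg_scale = 4.0
cond_eps = torch.tensor([1.0])     # prediction given the caption
uncond_eps = torch.tensor([0.2])   # prediction given the null caption
# guided = uncond + s * (cond - uncond); s = 1 recovers the conditional prediction
guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
print(guided)  # tensor([3.4000])
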
    @staticmethod
    def precompute_freqs_cis(
        dim: int,
        end: int,
        theta: float = 10000.0,
        rope_scaling_factor: float = 1.0,
        ntk_factor: float = 1.0
    ):
        """
        Precompute the frequency tensor for complex exponentials (cis) with
        given dimensions.

        This function calculates a frequency tensor with complex exponentials
        using the given dimension 'dim' and the end index 'end'. The 'theta'
        parameter scales the frequencies. The returned tensor contains complex
        values in complex64 data type.

        Args:
            dim (int): Dimension of the frequency tensor.
            end (int): End index for precomputing frequencies.
            theta (float, optional): Scaling factor for frequency computation.
                Defaults to 10000.0.
            rope_scaling_factor (float, optional): Position-interpolation
                factor; position indices are divided by it. Defaults to 1.0.
            ntk_factor (float, optional): NTK-aware scaling factor applied to
                theta. Defaults to 1.0.

        Returns:
            torch.Tensor: Precomputed frequency tensor with complex
                exponentials.
        """

        theta = theta * ntk_factor

        logger.info(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
        freqs = 1.0 / (theta ** (
            torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim
        ))
        t = torch.arange(end, device=freqs.device, dtype=torch.float)  # type: ignore
        t = t / rope_scaling_factor
        freqs = torch.outer(t, freqs).float()  # type: ignore
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

        freqs_cis_h = freqs_cis.view(end, 1, dim//4, 1).repeat(1, end, 1, 1)
        freqs_cis_w = freqs_cis.view(1, end, dim//4, 1).repeat(end, 1, 1, 1)
        freqs_cis = torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)
        return freqs_cis

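precompute_freqs_cis builds axial 2-D rotary embeddings: the arange(0, dim, 4) step yields dim // 4 frequency bands, half of each head's complex pairs rotate with the row index and the other half with the column index, which is why the constructor asserts head_dim % 4 == 0. A CPU sketch of the shape bookkeeping (the method itself places the bands on CUDA; sizes here are illustrative):

import torch

dim, end = 8, 4                                                     # head_dim, grid side
freqs = 1.0 / (10000 ** (torch.arange(0, dim, 4).float() / dim))    # dim // 4 bands
cis = torch.polar(torch.ones(end, dim // 4), torch.outer(torch.arange(end).float(), freqs))
cis_h = cis.view(end, 1, dim // 4, 1).repeat(1, end, 1, 1)          # varies along rows
cis_w = cis.view(1, end, dim // 4, 1).repeat(end, 1, 1, 1)          # varies along columns
cis_2d = torch.cat([cis_h, cis_w], dim=-1).flatten(2)
assert cis_2d.shape == (end, end, dim // 2)                         # one complex value per channel pair
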
    def parameter_count(self) -> int:
        tensor_parallel_module_list = (
            ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
        )
        total_params = 0

        def _recursive_count_params(module):
            nonlocal total_params
            is_tp_module = isinstance(module, tensor_parallel_module_list)
            for param in module.parameters(recurse=False):
                total_params += param.numel() * (
                    fs_init.get_model_parallel_world_size()
                    if is_tp_module else 1
                )
            for submodule in module.children():
                _recursive_count_params(submodule)

        _recursive_count_params(self)
        return total_params

    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
        return list(self.layers)


#############################################################################
#                                DiT Configs                                #
#############################################################################


def DiT_Llama_600M_patch2(**kwargs):
    return DiT_Llama(
        patch_size=2, dim=1536, n_layers=16, n_heads=32, **kwargs
    )


def DiT_Llama_2B_patch2(**kwargs):
    return DiT_Llama(
        patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs
    )


def DiT_Llama_3B_patch2(**kwargs):
    return DiT_Llama(
        patch_size=2, dim=3072, n_layers=32, n_heads=32, **kwargs
    )


def DiT_Llama_7B_patch2(**kwargs):
    return DiT_Llama(
        patch_size=2, dim=4096, n_layers=32, n_heads=32, **kwargs
    )

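All four configs keep n_heads = 32, so head_dim is dim / 32 and the 2-D RoPE divisibility assert holds for each. A quick check (not part of the file):

# head_dim per config; the constructor asserts head_dim % 4 == 0 for 2-D RoPE
for name, dim in {"600M": 1536, "2B": 2304, "3B": 3072, "7B": 4096}.items():
    head_dim = dim // 32
    print(name, head_dim, head_dim % 4 == 0)   # 48, 72, 96, 128 -> all True
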
models/model_5b.py
ADDED
@@ -0,0 +1,894 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import functools
import math
from typing import Optional, Tuple, List

from .components import RMSNorm
import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
)
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#############################################################################
#              Embedding Layers for Timesteps and Class Labels              #
#############################################################################

class ParallelTimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            ColumnParallelLinear(
                frequency_embedding_size, hidden_size, bias=True,
                gather_output=False,
                init_method=functools.partial(nn.init.normal_, std=0.02),
            ),
            nn.SiLU(),
            RowParallelLinear(
                hidden_size, hidden_size, bias=True, input_is_parallel=True,
                init_method=functools.partial(nn.init.normal_, std=0.02),
            ),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(
                start=0, end=half, dtype=torch.float32
            ) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([
                embedding, torch.zeros_like(embedding[:, :1])
            ], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
        return t_emb

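timestep_embedding is the GLIDE/DDPM sinusoidal encoding: frequencies fall geometrically from 1 toward 1/max_period across half the width, with cosines in the first half and sines in the second, plus one zero column when dim is odd. A standalone check mirroring the static method (a sketch, not part of the file):

import math
import torch

def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad a zero column so the width matches dim
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

emb = timestep_embedding(torch.tensor([0, 250, 999]), 7)
assert emb.shape == (3, 7)
assert torch.allclose(emb[0, :3], torch.ones(3))  # cos(0) = 1 across the cosine half
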
class ParallelLabelEmbedder(nn.Module):
    r"""Embeds class labels into vector representations. Also handles label
    dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = int(dropout_prob > 0)
        self.embedding_table = ParallelEmbedding(
            num_classes + use_cfg_embedding, hidden_size,
            init_method=functools.partial(nn.init.normal_, std=0.02),
        )
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(
                labels.shape[0], device=labels.device
            ) < self.dropout_prob
            drop_ids = drop_ids.cuda()
            dist.broadcast(
                drop_ids,
                fs_init.get_model_parallel_src_rank(),
                fs_init.get_model_parallel_group(),
            )
            drop_ids = drop_ids.to(labels.device)
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


#############################################################################
#                              Core DiT Model                               #
#############################################################################


class Attention(nn.Module):
    """Multi-head attention module."""
    def __init__(self, dim: int, n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, y_dim: int):
        """
        Initialize the Attention module.

        Args:
            dim (int): Number of input dimensions.
            n_heads (int): Number of heads.
            n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
            qk_norm (bool): Whether to apply LayerNorm to queries and keys.
            y_dim (int): Dimension of the caption features attended to by the
                gated cross-attention branch (0 disables the branch).

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.

        """
        super().__init__()
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads

        self.wq = ColumnParallelLinear(
            dim, n_heads * self.head_dim, bias=False, gather_output=False,
            init_method=nn.init.xavier_uniform_,
        )
        self.wk = ColumnParallelLinear(
            dim, self.n_kv_heads * self.head_dim, bias=False,
            gather_output=False, init_method=nn.init.xavier_uniform_,
        )
        self.wv = ColumnParallelLinear(
            dim, self.n_kv_heads * self.head_dim, bias=False,
            gather_output=False, init_method=nn.init.xavier_uniform_,
        )
        if y_dim > 0:
            self.wk_y = ColumnParallelLinear(
                y_dim, self.n_kv_heads * self.head_dim, bias=False,
                gather_output=False, init_method=nn.init.xavier_uniform_,
            )
            self.wv_y = ColumnParallelLinear(
                y_dim, self.n_kv_heads * self.head_dim, bias=False,
                gather_output=False, init_method=nn.init.xavier_uniform_,
            )
            self.gate = nn.Parameter(torch.zeros([self.n_local_heads]))

        self.wo = RowParallelLinear(
            n_heads * self.head_dim, dim, bias=False,
            input_is_parallel=True, init_method=nn.init.xavier_uniform_,
        )

        if qk_norm:
            self.q_norm = nn.LayerNorm(self.n_local_heads * self.head_dim)
            self.k_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
            if y_dim > 0:
                self.ky_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
            else:
                self.ky_norm = nn.Identity()
        else:
            self.q_norm = self.k_norm = nn.Identity()
            self.ky_norm = nn.Identity()

        # for proportional attention computation
        self.base_seqlen = None
        self.proportional_attn = False

    @staticmethod
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
        """
        Reshape frequency tensor for broadcasting it with another tensor.

        This function reshapes the frequency tensor to have the same shape as
        the target tensor 'x' for the purpose of broadcasting the frequency
        tensor during element-wise operations.

        Args:
            freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
            x (torch.Tensor): Target tensor for broadcasting compatibility.

        Returns:
            torch.Tensor: Reshaped frequency tensor.

        Raises:
            AssertionError: If the frequency tensor doesn't match the expected
                shape.
            AssertionError: If the target tensor 'x' doesn't have the expected
                number of dimensions.
        """
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1
                 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    @staticmethod
    def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to input tensors using the given frequency
        tensor.

        This function applies rotary embeddings to the given query 'xq' and
        key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
        input tensors are reshaped as complex numbers, and the frequency tensor
        is reshaped for broadcasting compatibility. The resulting tensors
        contain rotary embeddings and are returned as real tensors.

        Args:
            xq (torch.Tensor): Query tensor to apply rotary embeddings.
            xk (torch.Tensor): Key tensor to apply rotary embeddings.
            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
                exponentials.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
                and key tensor with rotary embeddings.
        """
        with torch.cuda.amp.autocast(enabled=False):
            xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
            freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_)
            xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
            return xq_out.type_as(xq), xk_out.type_as(xk)

    # copied from huggingface modeling_llama.py
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):

        def _get_unpad_data(attention_mask):
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return (
                indices,
                cu_seqlens,
                max_seqlen_in_batch,
            )

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def forward(
        self,
        x: torch.Tensor, x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        y: torch.Tensor, y_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            x_mask (torch.Tensor): Padding mask for the image tokens.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            y (torch.Tensor): Caption features for cross attention.
            y_mask (torch.Tensor): Padding mask for the caption tokens.

        Returns:
            torch.Tensor: Output tensor after attention.

        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        dtype = xq.dtype

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = Attention.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        xq, xk = xq.to(dtype), xk.to(dtype)

        if dtype in [torch.float16, torch.bfloat16]:
            # begin var_len flash attn
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                xq, xk, xv, x_mask, seqlen
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            if self.proportional_attn:
                softmax_scale = math.sqrt(math.log(seqlen, self.base_seqlen) / self.head_dim)
            else:
                softmax_scale = math.sqrt(1 / self.head_dim)
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=0.,
                causal=False,
                softmax_scale=softmax_scale
            )
            output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
            # end var_len_flash_attn

        else:
            output = F.scaled_dot_product_attention(
                xq.permute(0, 2, 1, 3),
                xk.permute(0, 2, 1, 3),
                xv.permute(0, 2, 1, 3),
                attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
            ).permute(0, 2, 1, 3).to(dtype)

        if hasattr(self, "wk_y"):
            yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
            yv = self.wv_y(y).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
            n_rep = self.n_local_heads // self.n_local_kv_heads
            if n_rep >= 1:
                yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            output_y = F.scaled_dot_product_attention(
                xq.permute(0, 2, 1, 3),
                yk.permute(0, 2, 1, 3),
                yv.permute(0, 2, 1, 3),
                y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seqlen, -1)
            ).permute(0, 2, 1, 3)
            output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
            output = output + output_y

        output = output.flatten(-2)

        return self.wo(output)

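When proportional_attn is set, the softmax scale grows with the logarithm of the inference sequence length relative to the training length, sqrt(log_base_seqlen(seqlen) / head_dim), and reduces to the usual 1/sqrt(head_dim) when seqlen == base_seqlen. Numerically (illustrative numbers, not part of the file):

import math

head_dim, base_seqlen = 72, 1024
for seqlen in (1024, 4096):
    scale = math.sqrt(math.log(seqlen, base_seqlen) / head_dim)
    print(seqlen, round(scale, 4))
# 1024 -> 0.1179 (equals 1/sqrt(72)); 4096 -> 0.1291 (sharper softmax at longer lengths)
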
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple
                of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden
                dimension. Defaults to None.

        Attributes:
            w1 (ColumnParallelLinear): Linear transformation for the first
                layer.
            w2 (RowParallelLinear): Linear transformation for the second layer.
            w3 (ColumnParallelLinear): Linear transformation for the third
                layer.

        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * (
            (hidden_dim + multiple_of - 1) // multiple_of
        )

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False,
            init_method=nn.init.xavier_uniform_,
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True,
            init_method=nn.init.xavier_uniform_,
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False,
            init_method=nn.init.xavier_uniform_,
        )

    # @torch.compile
    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

    def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))

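The hidden width follows the Llama SwiGLU recipe: take 2/3 of the nominal 4 * dim, apply the optional multiplier, then round up to a multiple of multiple_of. For dim = 3072 with the default multiple_of = 256 and no multiplier, the same arithmetic in isolation (a sketch, not part of the file):

dim, multiple_of = 3072, 256
hidden = int(2 * (4 * dim) / 3)                                      # SwiGLU keeps 2/3 of 4*dim
hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)   # round up
print(hidden)  # 8192
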
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int,
                 multiple_of: int, ffn_dim_multiplier: float, norm_eps: float,
                 qk_norm: bool, y_dim: int) -> None:
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            dim (int): Embedding dimension of the input features.
            n_heads (int): Number of attention heads.
            n_kv_heads (Optional[int]): Number of attention heads in key and
                value features (if using GQA), or set to None for the same as
                query.
            multiple_of (int): Value to ensure the FFN hidden dimension is a
                multiple of this value.
            ffn_dim_multiplier (float): Custom multiplier for the FFN hidden
                dimension.
            norm_eps (float): Epsilon used by the RMSNorm layers.

        Attributes:
            n_heads (int): Number of attention heads.
            dim (int): Dimension size of the model.
            head_dim (int): Dimension size of each attention head.
            attention (Attention): Attention module.
            feed_forward (FeedForward): FeedForward module.
            layer_id (int): Identifier for the layer.
            attention_norm (RMSNorm): Layer normalization for attention output.
            ffn_norm (RMSNorm): Layer normalization for feedforward output.

        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim)
        self.feed_forward = FeedForward(
            dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            ColumnParallelLinear(
                min(dim, 1024), 6 * dim, bias=True, gather_output=True,
                init_method=nn.init.zeros_,
            ),
        )

        self.attention_y_norm = RMSNorm(y_dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        y: torch.Tensor,
        y_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            x_mask (torch.Tensor): Padding mask for the image tokens.
            y (torch.Tensor): Caption features for cross attention.
            y_mask (torch.Tensor): Padding mask for the caption tokens.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            adaln_input (torch.Tensor, optional): Conditioning vector for adaLN
                modulation. Defaults to None.

        Returns:
            torch.Tensor: Output tensor after applying attention and
                feedforward layers.

        """
        if adaln_input is not None:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
                self.adaLN_modulation(adaln_input).chunk(6, dim=1)

            x = x + gate_msa.unsqueeze(1) * self.attention(
                modulate(self.attention_norm(x), shift_msa, scale_msa),
                x_mask,
                freqs_cis,
                self.attention_y_norm(y), y_mask,
            )
            x = x + gate_mlp.unsqueeze(1) * self.feed_forward(
                modulate(self.ffn_norm(x), shift_mlp, scale_mlp),
            )

        else:
            x = x + self.attention(
                self.attention_norm(x), x_mask, freqs_cis, self.attention_y_norm(y), y_mask,
            )
            x = x + self.feed_forward(self.ffn_norm(x))

        return x


class ParallelFinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6,
        )
        self.linear = ColumnParallelLinear(
            hidden_size, patch_size * patch_size * out_channels, bias=True,
            init_method=nn.init.zeros_, gather_output=True,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            ColumnParallelLinear(
                min(hidden_size, 1024), 2 * hidden_size, bias=True,
                init_method=nn.init.zeros_, gather_output=True,
            ),
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT_Llama(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        dim: int = 4096,
        n_layers: int = 32,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        learn_sigma: bool = True,
        qk_norm: bool = False,
        cap_feat_dim: int = 5120,
        rope_scaling_factor: float = 1.,
        ntk_factor: float = 1.
    ) -> None:
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size

        self.x_embedder = ColumnParallelLinear(
            in_features=patch_size * patch_size * in_channels,
            out_features=dim,
            bias=True,
            gather_output=True,
            init_method=nn.init.xavier_uniform_,
        )
        nn.init.constant_(self.x_embedder.bias, 0.)

        self.t_embedder = ParallelTimestepEmbedder(min(dim, 1024))
        self.cap_embedder = nn.Sequential(
            nn.LayerNorm(cap_feat_dim),
            ColumnParallelLinear(
                cap_feat_dim, min(dim, 1024), bias=True, gather_output=True,
                init_method=nn.init.zeros_
            ),
        )

        self.layers = nn.ModuleList([
            TransformerBlock(layer_id, dim, n_heads, n_kv_heads, multiple_of,
                             ffn_dim_multiplier, norm_eps, qk_norm, cap_feat_dim)
            for layer_id in range(n_layers)
        ])
        self.final_layer = ParallelFinalLayer(dim, patch_size, self.out_channels)

        self.freqs_cis = DiT_Llama.precompute_freqs_cis(
            dim // n_heads, 40000, rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
        )
        self.dim = dim
        self.n_heads = n_heads
        self.rope_scaling_factor = rope_scaling_factor
        self.ntk_factor = ntk_factor
        self.eol_token = nn.Parameter(torch.empty(dim))
        self.pad_token = nn.Parameter(torch.empty(dim))
        nn.init.normal_(self.eol_token, std=0.02)
        nn.init.normal_(self.pad_token, std=0.02)

    def unpatchify(self, x: torch.Tensor, img_size: List[Tuple[int, int]], return_tensor=False) -> List[torch.Tensor]:
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        pH = pW = self.patch_size
        if return_tensor:
            H, W = img_size[0]
            B = x.size(0)
            L = (H // pH) * (W // pW + 1)  # one additional for eol
            x = x[:, :L].view(B, H // pH, W // pW + 1, pH, pW, self.out_channels)
            x = x[:, :, :-1]
            x = x.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
            return x
        else:
            imgs = []
            for i in range(x.size(0)):
                H, W = img_size[i]
                L = (H // pH) * (W // pW + 1)
                imgs.append(x[i][:L].view(
                    H // pH, W // pW + 1, pH, pW, self.out_channels
                )[:, :-1, :, :, :].permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2))
            return imgs

    def patchify_and_embed(
        self,
        x: List[torch.Tensor] | torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]]]:
        if isinstance(x, torch.Tensor):
            pH = pW = self.patch_size
            B, C, H, W = x.size()
            x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3)
            x = self.x_embedder(x)
            x = torch.cat([
                x,
                self.eol_token.view(1, 1, 1, -1).expand(B, H // pH, 1, -1),
            ], dim=2)
            x = x.flatten(1, 2)

            mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
            return x, mask, [(H, W)] * B
        else:
            pH = pW = self.patch_size
            x_embed = []
            img_size = []
            l_effective_seq_len = []

            for img in x:
                C, H, W = img.size()
                img_size.append((H, W))
                img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 0, 2, 4).flatten(2)
                img = self.x_embedder(img)
                img = torch.cat([
                    img,
                    self.eol_token.view(1, 1, -1).expand(H // pH, 1, -1),
                ], dim=1)
                img = img.flatten(0, 1)
                l_effective_seq_len.append(len(img))
                x_embed.append(img)

            max_seq_len = max(l_effective_seq_len)
            mask = torch.zeros(len(x), max_seq_len, dtype=torch.int32, device=x[0].device)
            padded_x_embed = []
            for i, (item_embed, item_seq_len) in enumerate(zip(x_embed, l_effective_seq_len)):
                item_embed = torch.cat([
                    item_embed,
                    self.pad_token.view(1, -1).expand(max_seq_len - item_seq_len, -1),
                ], dim=0)
                padded_x_embed.append(item_embed)
                mask[i][:item_seq_len] = 1

            x_embed = torch.stack(padded_x_embed, dim=0)
            return x_embed, mask, img_size

    def forward(self, x, t, cap_feats, cap_mask):
        """
        Forward pass of DiT.
        t: (N,) tensor of diffusion timesteps
        cap_feats: (N, L, cap_feat_dim) tensor of caption features
        cap_mask: (N, L) tensor marking the valid caption tokens
        """
        x_is_tensor = isinstance(x, torch.Tensor)
        x, mask, img_size = self.patchify_and_embed(x)
        self.freqs_cis = self.freqs_cis.to(x.device)

        t = self.t_embedder(t)  # (N, D)
        cap_mask_float = cap_mask.float().unsqueeze(-1)
        cap_feats_pool = (cap_feats * cap_mask_float).sum(dim=1) / cap_mask_float.sum(dim=1)
        cap_feats_pool = cap_feats_pool.to(cap_feats)
        cap_emb = self.cap_embedder(cap_feats_pool)
        adaln_input = t + cap_emb

        cap_mask = cap_mask.bool()
        for layer in self.layers:
            x = layer(
                x, mask, cap_feats, cap_mask, self.freqs_cis[:x.size(1)],
                adaln_input=adaln_input
            )

        x = self.final_layer(x, adaln_input)
        x = self.unpatchify(x, img_size, return_tensor=x_is_tensor)
        if self.learn_sigma:
            if x_is_tensor:
                x, _ = x.chunk(2, dim=1)
            else:
                x = [_.chunk(2, dim=0)[0] for _ in x]
        return x

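Unlike models/model.py, which encodes spatial layout with 2-D RoPE, this variant uses 1-D positions and appends a learned eol_token after every patch row, so a single image contributes (H/p) * (W/p + 1) tokens and unpatchify slices the extra column back off. The token-count bookkeeping in isolation (illustrative sizes, not part of the file):

H, W, p = 64, 48, 2
rows, cols = H // p, W // p
L = rows * (cols + 1)    # one extra end-of-line token per patch row
print(rows, cols, L)     # 32 24 800
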
    def forward_with_cfg(
        self,
        x,
        t,
        cap_feats,
        cap_mask,
        cfg_scale,
        rope_scaling_factor=None,
        ntk_factor=None,
        base_seqlen: Optional[int] = None,
        proportional_attn: bool = False
    ):
        """
        Forward pass of DiT, but also batches the unconditional forward pass
        for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb

        if rope_scaling_factor is not None or ntk_factor is not None:
            rope_scaling_factor = rope_scaling_factor if rope_scaling_factor is not None else self.rope_scaling_factor
            ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
            if rope_scaling_factor != self.rope_scaling_factor or ntk_factor != self.ntk_factor:
                print(f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}", flush=True)
                self.freqs_cis = DiT_Llama.precompute_freqs_cis(
                    self.dim // self.n_heads, 40000,
                    rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
                )
                self.rope_scaling_factor = rope_scaling_factor
                self.ntk_factor = ntk_factor

        if proportional_attn:
            assert base_seqlen is not None
            for layer in self.layers:
                layer.attention.base_seqlen = base_seqlen
                layer.attention.proportional_attn = proportional_attn
        else:
            for layer in self.layers:
                layer.attention.base_seqlen = None
                layer.attention.proportional_attn = proportional_attn

        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self(combined, t, cap_feats, cap_mask)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    @staticmethod
    def precompute_freqs_cis(
        dim: int,
        end: int,
        theta: float = 10000.0,
        rope_scaling_factor: float = 1.0,
        ntk_factor: float = 1.0
    ):
        """
        Precompute the frequency tensor for complex exponentials (cis) with
        given dimensions.

        This function calculates a frequency tensor with complex exponentials
        using the given dimension 'dim' and the end index 'end'. The 'theta'
        parameter scales the frequencies. The returned tensor contains complex
        values in complex64 data type.

        Args:
            dim (int): Dimension of the frequency tensor.
            end (int): End index for precomputing frequencies.
            theta (float, optional): Scaling factor for frequency computation.
                Defaults to 10000.0.
            rope_scaling_factor (float, optional): Position-interpolation
                factor; position indices are divided by it. Defaults to 1.0.
            ntk_factor (float, optional): NTK-aware scaling factor applied to
                theta. Defaults to 1.0.

        Returns:
            torch.Tensor: Precomputed frequency tensor with complex
                exponentials.
        """

        theta = theta * ntk_factor

        print(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
        freqs = 1.0 / (theta ** (
            torch.arange(0, dim, 2)[: (dim // 2)].float().cuda() / dim
        ))
        t = torch.arange(end, device=freqs.device, dtype=torch.float)  # type: ignore
        t = t / rope_scaling_factor
        freqs = torch.outer(t, freqs).float()  # type: ignore
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
        return freqs_cis

    def parameter_count(self) -> int:
        tensor_parallel_module_list = (
            ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
        )
        total_params = 0

        def _recursive_count_params(module):
            nonlocal total_params
            is_tp_module = isinstance(module, tensor_parallel_module_list)
            for param in module.parameters(recurse=False):
                total_params += param.numel() * (
                    fs_init.get_model_parallel_world_size()
                    if is_tp_module else 1
                )
            for submodule in module.children():
                _recursive_count_params(submodule)

        _recursive_count_params(self)
        return total_params

    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
        return list(self.layers)


#############################################################################
#                                DiT Configs                                #
#############################################################################

def DiT_Llama_2B_patch2(**kwargs):
    return DiT_Llama(
        patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs
    )


def DiT_Llama_5B_patch2(**kwargs):
    return DiT_Llama(
        patch_size=2, dim=3072, n_layers=32, n_heads=32, **kwargs
    )

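This file keeps 1-D RoPE over the flattened token sequence (40000 precomputed positions, one frequency band per channel pair) and exposes two extrapolation knobs: ntk_factor multiplies theta, lowering every frequency band, while rope_scaling_factor divides the position index. Their effect on the bands in isolation (a CPU sketch with an illustrative head_dim; the method itself computes on CUDA):

import torch

dim, theta, ntk_factor = 128, 10000.0, 2.0
base = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))                  # dim // 2 bands
scaled = 1.0 / ((theta * ntk_factor) ** (torch.arange(0, dim, 2).float() / dim))
assert torch.all(scaled <= base)   # larger theta -> slower rotation per position
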
requirements.txt
ADDED
@@ -0,0 +1,12 @@
transformers
diffusers
huggingface_hub
gradio
torch
# torch==2.2.2+cu121
fairscale
numpy
pillow
torchdiffeq
click
git+https://github.com/Alpha-VLLM/Lumina-T2X