heziiiii committed on
Commit
88c74ea
1 Parent(s): c26472b

Upload inference.py

Files changed (1)
  inference.py +445 -0
inference.py ADDED
@@ -0,0 +1,445 @@
import random
import time
from pathlib import Path

import numpy as np
import torch

# For reproducibility
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

from diffusers import schedulers
from diffusers.models import AutoencoderKL
from loguru import logger
from transformers import BertModel, BertTokenizer
from transformers.modeling_utils import logger as tf_logger

from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT, TRT_MAX_WIDTH, TRT_MAX_HEIGHT, TRT_MAX_BATCH_SIZE
from .diffusion.pipeline import StableDiffusionPipeline
from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
from .modules.text_encoder import MT5Embedder
from .utils.tools import set_seeds
from peft import LoraConfig


class Resolution:
    def __init__(self, width, height):
        self.width = width
        self.height = height

    def __str__(self):
        return f'{self.height}x{self.width}'


class ResolutionGroup:
    def __init__(self):
        self.data = [
            Resolution(1024, 1024),  # 1:1
            Resolution(1280, 1280),  # 1:1
            Resolution(1024, 768),   # 4:3
            Resolution(1152, 864),   # 4:3
            Resolution(1280, 960),   # 4:3
            Resolution(768, 1024),   # 3:4
            Resolution(864, 1152),   # 3:4
            Resolution(960, 1280),   # 3:4
            Resolution(1280, 768),   # 16:9
            Resolution(768, 1280),   # 9:16
        ]
        self.supported_sizes = set([(r.width, r.height) for r in self.data])

    def is_valid(self, width, height):
        return (width, height) in self.supported_sizes


STANDARD_RATIO = np.array([
    1.0,         # 1:1
    4.0 / 3.0,   # 4:3
    3.0 / 4.0,   # 3:4
    16.0 / 9.0,  # 16:9
    9.0 / 16.0,  # 9:16
])
STANDARD_SHAPE = [
    [(1024, 1024), (1280, 1280)],  # 1:1
    [(1280, 960)],                 # 4:3
    [(960, 1280)],                 # 3:4
    [(1280, 768)],                 # 16:9
    [(768, 1280)],                 # 9:16
]
STANDARD_AREA = [
    np.array([w * h for w, h in shapes])
    for shapes in STANDARD_SHAPE
]


def get_standard_shape(target_width, target_height):
    """
    Map image size to standard size.
    """
    target_ratio = target_width / target_height
    closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
    closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
    width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
    return width, height


def _to_tuple(val):
    if isinstance(val, (list, tuple)):
        if len(val) == 1:
            val = [val[0], val[0]]
        elif len(val) == 2:
            val = tuple(val)
        else:
            raise ValueError(f"Invalid value: {val}")
    elif isinstance(val, (int, float)):
        val = (val, val)
    else:
        raise ValueError(f"Invalid value: {val}")
    return val


def get_pipeline(args, vae, text_encoder, tokenizer, model, device, rank,
                 embedder_t5, infer_mode, sampler=None):
    """
    Get the scheduler and pipeline for sampling. Both the sampler and the
    pipeline are based on diffusers, with some modifications.

    Returns
    -------
    pipeline: StableDiffusionPipeline
    sampler_name: str
    """
    sampler = sampler or args.sampler

    # Load sampler from factory
    kwargs = SAMPLER_FACTORY[sampler]['kwargs']
    scheduler = SAMPLER_FACTORY[sampler]['scheduler']

    # Update sampler according to the arguments
    kwargs['beta_schedule'] = args.noise_schedule
    kwargs['beta_start'] = args.beta_start
    kwargs['beta_end'] = args.beta_end
    kwargs['prediction_type'] = args.predict_type

    # Build scheduler according to the sampler.
    scheduler_class = getattr(schedulers, scheduler)
    scheduler = scheduler_class(**kwargs)

    # Set timesteps for inference steps.
    scheduler.set_timesteps(args.infer_steps, device)

    # Only enable progress bar for rank 0
    progress_bar_config = {} if rank == 0 else {'disable': True}

    pipeline = StableDiffusionPipeline(vae=vae,
                                       text_encoder=text_encoder,
                                       tokenizer=tokenizer,
                                       unet=model,
                                       scheduler=scheduler,
                                       feature_extractor=None,
                                       safety_checker=None,
                                       requires_safety_checker=False,
                                       progress_bar_config=progress_bar_config,
                                       embedder_t5=embedder_t5,
                                       infer_mode=infer_mode,
                                       )

    pipeline = pipeline.to(device)

    return pipeline, sampler


class End2End(object):
    def __init__(self, args, models_root_path):
        self.args = args

        # Check arguments
        t2i_root_path = Path(models_root_path) / "t2i"
        self.root = t2i_root_path
        logger.info(f"Got text-to-image model root path: {t2i_root_path}")

        # Set device and disable gradient
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.set_grad_enabled(False)
        # Disable BertModel logging checkpoint info
        tf_logger.setLevel('ERROR')

        # ========================================================================
        logger.info(f"Loading CLIP Text Encoder...")
        text_encoder_path = self.root / "clip_text_encoder"
        self.clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
        logger.info(f"Loading CLIP Text Encoder finished")

        # ========================================================================
        logger.info(f"Loading CLIP Tokenizer...")
        tokenizer_path = self.root / "tokenizer"
        self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
        logger.info(f"Loading CLIP Tokenizer finished")

        # ========================================================================
        logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
        t5_text_encoder_path = self.root / 'mt5'
        embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
        self.embedder_t5 = embedder_t5
        logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")

        # ========================================================================
        logger.info(f"Loading VAE...")
        vae_path = self.root / "sdxl-vae-fp16-fix"
        self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
        logger.info(f"Loading VAE finished")

194
+ # Create model structure and load the checkpoint
195
+ logger.info(f"Building HunYuan-DiT model...")
196
+ model_config = HUNYUAN_DIT_CONFIG[self.args.model]
197
+ self.patch_size = model_config['patch_size']
198
+ self.head_size = model_config['hidden_size'] // model_config['num_heads']
199
+ self.resolutions, self.freqs_cis_img = self.standard_shapes() # Used for TensorRT models
200
+ self.image_size = _to_tuple(self.args.image_size)
201
+ latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)
202
+
203
+ self.infer_mode = self.args.infer_mode
204
+ if self.infer_mode in ['fa', 'torch']:
205
+
            # # for trained pt
            # model_path = Path("/home1/qbs/my_program1/HunyuanDiT/log_EXP/024-dit_g2_full_1024p/checkpoints/0100000.pt/mp_rank_00_model_states.pt")
            # if not model_path.exists():
            #     raise ValueError(f"model_path does not exist: {model_path}")
            # # Build model structure
            # self.model = HunYuanDiT(self.args,
            #                         input_size=latent_size,
            #                         **model_config,
            #                         log_fn=logger.info,
            #                         ).half().to(self.device)    # Force to use fp16
            # # Load model checkpoint
            # logger.info(f"Loading torch model {model_path}...")
            # state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
            # self.model.load_state_dict(state_dict["module"])

            # for ema trained pt
            model_path = Path("/home1/qbs/my_program1/HunyuanDiT/log_EXP/027-dit_g2_full_1024p/checkpoints/latest.pt/mp_rank_00_model_states.pt")
            if not model_path.exists():
                raise ValueError(f"model_path does not exist: {model_path}")
            # Build model structure
            self.model = HunYuanDiT(self.args,
                                    input_size=latent_size,
                                    **model_config,
                                    log_fn=logger.info,
                                    ).half().to(self.device)    # Force to use fp16
            # Load model checkpoint
            logger.info(f"Loading torch model {model_path}...")
            state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
            self.model.load_state_dict(state_dict["ema"])

            # # original
            # model_dir = self.root / "model"
            # model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt"
            # if not model_path.exists():
            #     raise ValueError(f"model_path does not exist: {model_path}")
            # # Build model structure
            # self.model = HunYuanDiT(self.args,
            #                         input_size=latent_size,
            #                         **model_config,
            #                         log_fn=logger.info,
            #                         ).half().to(self.device)    # Force to use fp16
            # # Load model checkpoint
            # logger.info(f"Loading torch model {model_path}...")
            # state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
            # self.model.load_state_dict(state_dict)

            lora_ckpt = args.lora_ckpt
            if lora_ckpt is not None and lora_ckpt != "":
                logger.info(f"Loading Lora checkpoint {lora_ckpt}...")

                self.model.load_adapter(lora_ckpt)
                self.model.merge_and_unload()

            self.model.eval()
            logger.info(f"Loading torch model finished")
        elif self.infer_mode == 'trt':
            from .modules.trt.hcf_model import TRTModel

            trt_dir = self.root / "model_trt"
            engine_dir = trt_dir / "engine"
            plugin_path = trt_dir / "fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so"
            model_name = "model_onnx"

            logger.info(f"Loading TensorRT model {engine_dir}/{model_name}...")
            self.model = TRTModel(model_name=model_name,
                                  engine_dir=str(engine_dir),
                                  image_height=TRT_MAX_HEIGHT,
                                  image_width=TRT_MAX_WIDTH,
                                  text_maxlen=args.text_len,
                                  embedding_dim=args.text_states_dim,
                                  plugin_path=str(plugin_path),
                                  max_batch_size=TRT_MAX_BATCH_SIZE,
                                  )
            logger.info(f"Loading TensorRT model finished")
        else:
            raise ValueError(f"Unknown infer_mode: {self.infer_mode}")

        # ========================================================================
        # Build inference pipeline. We use a customized StableDiffusionPipeline.
        logger.info(f"Loading inference pipeline...")
        self.pipeline, self.sampler = self.load_sampler()
        logger.info(f'Loading pipeline finished')

        # ========================================================================
        self.default_negative_prompt = NEGATIVE_PROMPT
        logger.info("==================================================")
        logger.info(f"                Model is ready.                  ")
        logger.info("==================================================")

    def load_sampler(self, sampler=None):
        pipeline, sampler = get_pipeline(self.args,
                                         self.vae,
                                         self.clip_text_encoder,
                                         self.tokenizer,
                                         self.model,
                                         device=self.device,
                                         rank=0,
                                         embedder_t5=self.embedder_t5,
                                         infer_mode=self.infer_mode,
                                         sampler=sampler,
                                         )
        return pipeline, sampler

    def calc_rope(self, height, width):
        # Rotary position embedding is computed on the latent patch grid
        # (image size / 8 for the VAE latent, then / patch_size).
        th = height // 8 // self.patch_size
        tw = width // 8 // self.patch_size
        base_size = 512 // 8 // self.patch_size
        start, stop = get_fill_resize_and_crop((th, tw), base_size)
        sub_args = [start, stop, (th, tw)]
        rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
        return rope

    def standard_shapes(self):
        resolutions = ResolutionGroup()
        freqs_cis_img = {}
        for reso in resolutions.data:
            freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
        return resolutions, freqs_cis_img

    def predict(self,
                user_prompt,
                height=1024,
                width=1024,
                seed=None,
                enhanced_prompt=None,
                negative_prompt=None,
                infer_steps=100,
                guidance_scale=6,
                batch_size=1,
                src_size_cond=(1024, 1024),
                sampler=None,
                ):
        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if seed is None:
            seed = random.randint(0, 1_000_000)
        if not isinstance(seed, int):
            raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
        generator = set_seeds(seed, device=self.device)
        # ========================================================================
        # Arguments: target_width, target_height
        # ========================================================================
        if width <= 0 or height <= 0:
            raise ValueError(f"`height` and `width` must be positive integers, got height={height}, width={width}")
        logger.info(f"Input (height, width) = ({height}, {width})")
        if self.infer_mode in ['fa', 'torch']:
            # We must force height and width to align to 16 and to be an integer.
            target_height = int((height // 16) * 16)
            target_width = int((width // 16) * 16)
            logger.info(f"Align to 16: (height, width) = ({target_height}, {target_width})")
        elif self.infer_mode == 'trt':
            target_width, target_height = get_standard_shape(width, height)
            logger.info(f"Align to standard shape: (height, width) = ({target_height}, {target_width})")
        else:
            raise ValueError(f"Unknown infer_mode: {self.infer_mode}")

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(user_prompt, str):
            raise TypeError(f"`user_prompt` must be a string, but got {type(user_prompt)}")
        user_prompt = user_prompt.strip()
        prompt = user_prompt

        if enhanced_prompt is not None:
            if not isinstance(enhanced_prompt, str):
                raise TypeError(f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}")
            enhanced_prompt = enhanced_prompt.strip()
            prompt = enhanced_prompt

        # negative prompt
        if negative_prompt is None or negative_prompt == '':
            negative_prompt = self.default_negative_prompt
        if not isinstance(negative_prompt, str):
            raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}")

        # ========================================================================
        # Arguments: style. (A fixed argument. Do not change it.)
        # ========================================================================
        style = torch.as_tensor([0, 0] * batch_size, device=self.device)

        # ========================================================================
        # Inner arguments: image_meta_size (Please refer to SDXL.)
        # ========================================================================
        if isinstance(src_size_cond, int):
            src_size_cond = [src_size_cond, src_size_cond]
        if not isinstance(src_size_cond, (list, tuple)):
            raise TypeError(f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}")
        if len(src_size_cond) != 2:
            raise ValueError(f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}")
        size_cond = list(src_size_cond) + [target_width, target_height, 0, 0]
        image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self.device)

        # ========================================================================
        start_time = time.time()
        logger.debug(f"""
                       prompt: {user_prompt}
              enhanced prompt: {enhanced_prompt}
                         seed: {seed}
              (height, width): {(target_height, target_width)}
              negative_prompt: {negative_prompt}
                   batch_size: {batch_size}
               guidance_scale: {guidance_scale}
                  infer_steps: {infer_steps}
              image_meta_size: {size_cond}
        """)
        reso = f'{target_height}x{target_width}'
        if reso in self.freqs_cis_img:
            freqs_cis_img = self.freqs_cis_img[reso]
        else:
            freqs_cis_img = self.calc_rope(target_height, target_width)

        if sampler is not None and sampler != self.sampler:
            self.pipeline, self.sampler = self.load_sampler(sampler)

        samples = self.pipeline(
            height=target_height,
            width=target_width,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=batch_size,
            guidance_scale=guidance_scale,
            num_inference_steps=infer_steps,
            image_meta_size=image_meta_size,
            style=style,
            return_dict=False,
            generator=generator,
            freqs_cis_img=freqs_cis_img,
            use_fp16=self.args.use_fp16,
            learn_sigma=self.args.learn_sigma,
        )[0]
        gen_time = time.time() - start_time
        logger.debug(f"Success, time: {gen_time}")

        return {
            'images': samples,
            'seed': seed,
        }
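
A minimal usage sketch follows for context; it is not part of the committed file. It assumes this module sits in a HunyuanDiT-style `hydit` package with a `get_args` helper in `hydit.config`, and the `ckpts` checkpoint root and example prompt are placeholders, not values taken from this commit.

# Hypothetical driver script (assumptions: `hydit` package layout and a `ckpts` model root).
from pathlib import Path

from hydit.config import get_args      # assumed argument parser from the surrounding repository
from hydit.inference import End2End

if __name__ == "__main__":
    args = get_args()                   # supplies sampler, infer_mode, image_size, use_fp16, etc.
    gen = End2End(args, models_root_path=Path("ckpts"))   # "ckpts" is a placeholder root path
    results = gen.predict(
        "a watercolor painting of a fox in a bamboo forest",
        height=1024,
        width=1024,
        seed=42,
        infer_steps=50,
        guidance_scale=6,
    )
    for idx, image in enumerate(results["images"]):
        image.save(f"sample_{idx}.png")  # pipeline outputs are expected to be PIL images
    print(f"seed used: {results['seed']}")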