xinlai commited on
Commit
674d663
1 Parent(s): 72b53c6
.project-root ADDED
File without changes
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: SEED Story
3
- emoji: 📚
4
  colorFrom: blue
5
- colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 4.39.0
8
  app_file: app.py
 
1
  ---
2
+ title: SEED Story George
3
+ emoji: 🌍
4
  colorFrom: blue
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.39.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,750 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import datetime
4
+ import json
5
+ from typing import Optional
6
+ import transformers
7
+ from dataclasses import dataclass, field
8
+ import io
9
+ import spaces
10
+ import base64
11
+ from PIL import Image
12
+ import gradio as gr
13
+ import time
14
+ import hashlib
15
+
16
+ from utils import build_logger
17
+ from conversation import conv_seed_llama2
18
+
19
+ import hydra
20
+ import pyrootutils
21
+ import torch
22
+ import re
23
+ import time
24
+ from omegaconf import OmegaConf
25
+ from flask import Flask
26
+ import json
27
+ from typing import Optional
28
+ import cv2
29
+ from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler, StableDiffusionImg2ImgPipeline
30
+
31
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
32
+
33
+ from src.data.any_res import process_anyres_image
34
+
35
+ BOI_TOKEN = '<img>'
36
+ BOP_TOKEN = '<patch>'
37
+ EOI_TOKEN = '</img>'
38
+ EOP_TOKEN = '</patch>'
39
+ IMG_TOKEN = '<img_{:05d}>'
40
+
41
+ IMG_FLAG = '<image>'
42
+ num_img_in_tokens = 64
43
+ num_img_out_tokens = 64
44
+
45
+ resolution_grids = ['1x1', '1x2', '1x3', '1x4', '1x5', '1x6', '1x10', '2x1', '3x1', '4x1', '5x1', '6x1', '10x1', '2x2',
46
+ '2x3', '3x2', '2x4', '4x2']
47
+ base_resolution = 448
48
+
49
+ app = Flask(__name__)
50
+
51
+
52
+ def decode_image(encoded_image: str) -> Image:
53
+ decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
54
+ buffer = io.BytesIO(decoded_bytes)
55
+ image = Image.open(buffer)
56
+ return image
57
+
58
+
59
+ def encode_image(image: Image.Image, format: str = 'PNG') -> str:
60
+ with io.BytesIO() as buffer:
61
+ image.save(buffer, format=format)
62
+ encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
63
+ return encoded_image
64
+
65
+
66
+ @dataclass
67
+ class Arguments:
68
+ image_transform: Optional[str] = field(default='configs/processer/qwen_448_transform.yaml',
69
+ metadata={"help": "config path of image transform"})
70
+ tokenizer: Optional[str] = field(default='configs/tokenizer/clm_llama_tokenizer.yaml',
71
+ metadata={"help": "config path of tokenizer used to initialize tokenizer"})
72
+ llm: Optional[str] = field(default='configs/clm_models/llama2chat7b_lora.yaml', metadata={"help": "config path of llm"})
73
+ visual_encoder: Optional[str] = field(default='configs/visual_tokenzier/qwen_vitg_448.yaml',
74
+ metadata={"help": "config path of visual encoder"})
75
+ sd_adapter: Optional[str] = field(
76
+ default='configs/detokenizer/detokenizer_sdxl_qwen_vit_adapted.yaml',
77
+ metadata={"help": "config path of sd adapter"})
78
+ agent: Optional[str] = field(default='configs/clm_models/agent_7b_sft.yaml',
79
+ metadata={"help": "config path of agent model"})
80
+ diffusion_path: Optional[str] = field(default='stabilityai/stable-diffusion-xl-base-1.0',
81
+ metadata={"help": "diffusion model path"})
82
+ port: Optional[str] = field(default=80, metadata={"help": "network port"})
83
+ llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
84
+ vit_sd_device: Optional[str] = field(default='cuda:0', metadata={"help": "sd and vit device"})
85
+ dtype: Optional[str] = field(default='fp16', metadata={"help": "mix percision"})
86
+
87
+
88
+ parser = transformers.HfArgumentParser(Arguments)
89
+ args, = parser.parse_args_into_dataclasses()
90
+
91
+
92
+ class LLMService:
93
+
94
+ def __init__(self, args) -> None:
95
+
96
+ self.llm_device = args.llm_device
97
+ self.vit_sd_device = args.vit_sd_device
98
+
99
+ dtype = args.dtype
100
+ if dtype == 'fp16':
101
+ self.dtype = torch.float16
102
+ elif dtype == 'bf16':
103
+ self.dtype = torch.bfloat16
104
+ else:
105
+ raise ValueError
106
+
107
+ image_transform_cfg = OmegaConf.load(args.image_transform)
108
+ self.image_transform = hydra.utils.instantiate(image_transform_cfg)
109
+
110
+ tokenizer_cfg = OmegaConf.load(args.tokenizer)
111
+ self.tokenizer = hydra.utils.instantiate(tokenizer_cfg)
112
+
113
+ visual_encoder_cfg = OmegaConf.load(args.visual_encoder)
114
+ self.visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
115
+ self.visual_encoder.eval().to(self.vit_sd_device, dtype=self.dtype)
116
+ print('Init visual encoder done')
117
+
118
+ llm_cfg = OmegaConf.load(args.llm)
119
+ llm = hydra.utils.instantiate(llm_cfg, torch_dtype=self.dtype)
120
+ print('Init llm done.')
121
+
122
+ agent_cfg = OmegaConf.load(args.agent)
123
+ self.agent = hydra.utils.instantiate(agent_cfg, llm=llm)
124
+
125
+ self.agent.eval().to(self.llm_device, dtype=self.dtype)
126
+ print('Init agent mdoel Done')
127
+
128
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.diffusion_path, subfolder="scheduler")
129
+
130
+ vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(self.vit_sd_device,
131
+ dtype=self.dtype)
132
+
133
+ unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(self.vit_sd_device,
134
+ dtype=self.dtype)
135
+
136
+ sd_adapter_cfg = OmegaConf.load(args.sd_adapter)
137
+
138
+ self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(self.vit_sd_device,
139
+ dtype=self.dtype)
140
+
141
+ # self.sd_adapter.init_pipe(vae=vae,
142
+ # scheduler=noise_scheduler,
143
+ # visual_encoder=self.visual_encoder.cpu(),
144
+ # image_transform=self.image_transform,
145
+ # discrete_model=None,
146
+ # dtype=self.dtype,
147
+ # device="cpu")
148
+
149
+ self.sd_adapter.init_pipe(vae=vae,
150
+ scheduler=noise_scheduler,
151
+ visual_encoder=self.visual_encoder,
152
+ image_transform=self.image_transform,
153
+ discrete_model=None,
154
+ dtype=self.dtype,
155
+ device=self.vit_sd_device)
156
+
157
+ print('Init sd adapter pipe done.')
158
+
159
+ self.visual_encoder.to(self.vit_sd_device, dtype=self.dtype)
160
+
161
+ model_id_or_path = "stablediffusionapi/realistic-vision-v51"
162
+ self.vae_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, safety_checker=None,
163
+ torch_dtype=torch.float16)
164
+ # self.vae_pipe = self.vae_pipe.to(self.vit_sd_device)
165
+
166
+ self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
167
+ self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
168
+
169
+
170
+ service = LLMService(args)
171
+
172
+
173
+ @spaces.GPU
174
+ def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox, force_polish):
175
+ with torch.no_grad():
176
+ text_list = text_list.split(IMG_FLAG)
177
+ top_p = 0.5
178
+ assert len(text_list) == len(image_list) + 1
179
+
180
+ image_tokens = BOI_TOKEN + ''.join(
181
+ [IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
182
+
183
+ input_images = []
184
+ if len(image_list) > 0:
185
+ image_tensor_list = []
186
+ embeds_cmp_mask = []
187
+ embeds_gen_mask = []
188
+
189
+ if service.multi_resolution:
190
+ patch_pos = []
191
+ image_patch_length = []
192
+ image_size_list = []
193
+
194
+ for idx, image_item in enumerate(image_list):
195
+ if isinstance(image_item, str):
196
+ image = decode_image(image_item)
197
+ print('after decode image size:', image.size)
198
+ input_images.append(image)
199
+
200
+ # if service.multi_resolution:
201
+ # image_size_list.append(image.size)
202
+ # print('image size:', image.size)
203
+ # image_tensor, patch_pos_tensor = process_anyres_image(image, service.image_transform,
204
+ # service.grid_pinpoints,
205
+ # service.base_resolution)
206
+ # image_tensor_list.append(image_tensor)
207
+ # patch_pos.append(patch_pos_tensor)
208
+ # image_patch_length.append(image_tensor.shape[0])
209
+ # print('image_patch_length', image_patch_length)
210
+ # embeds_cmp_mask.extend([True] * image_tensor.shape[0])
211
+ # embeds_gen_mask.extend([False] * image_tensor.shape[0])
212
+ #
213
+ # else:
214
+ image_tensor = service.image_transform(image)
215
+ image_tensor_list.append(image_tensor)
216
+ embeds_cmp_mask.append(True)
217
+ embeds_gen_mask.append(False)
218
+ else:
219
+ raise ValueError
220
+
221
+ if service.multi_resolution:
222
+ pixel_values = torch.cat(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
223
+ patch_position = torch.cat(patch_pos, dim=0)
224
+
225
+ image_tokens_list = []
226
+ for patch_length in image_patch_length:
227
+ image_tokens = ''
228
+ for _ in range(patch_length - 1):
229
+ image_tokens += BOP_TOKEN + ''.join(
230
+ IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
231
+ image_tokens += BOI_TOKEN + ''.join(
232
+ IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN
233
+ image_tokens_list.append(image_tokens)
234
+ else:
235
+ pixel_values = torch.stack(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
236
+
237
+ image_embeds = service.visual_encoder(pixel_values)
238
+ image_embeds = image_embeds.to(service.llm_device)
239
+
240
+ embeds_cmp_mask = torch.tensor(embeds_cmp_mask, dtype=torch.bool).to(service.llm_device)
241
+ embeds_gen_mask = torch.tensor(embeds_gen_mask, dtype=torch.bool).to(service.llm_device)
242
+
243
+ else:
244
+ image_embeds = None
245
+ patch_position = 0
246
+ embeds_cmp_mask = None
247
+ embeds_gen_mask = None
248
+
249
+ input_text = image_tokens.join(text_list)
250
+
251
+ print('input_text:', input_text)
252
+ input_ids = service.tokenizer.encode(input_text, add_special_tokens=False)
253
+ input_ids = [service.tokenizer.bos_token_id] + input_ids
254
+
255
+ input_ids = torch.tensor(input_ids).to(service.llm_device, dtype=torch.long)
256
+ ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
257
+ ids_gen_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
258
+
259
+ boi_indices = torch.where(input_ids == service.boi_token_id)[0].tolist()
260
+ eoi_indices = torch.where(input_ids == service.eoi_token_id)[0].tolist()
261
+
262
+ for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
263
+ ids_cmp_mask[boi_idx + 1:eoi_idx] = True
264
+
265
+ input_ids = input_ids.unsqueeze(0)
266
+ ids_cmp_mask = ids_cmp_mask.unsqueeze(0)
267
+ ids_gen_mask = ids_gen_mask.unsqueeze(0)
268
+
269
+ error_msg = []
270
+
271
+ output = service.agent.generate(
272
+ tokenizer=service.tokenizer,
273
+ input_ids=input_ids,
274
+ image_embeds=image_embeds,
275
+ embeds_cmp_mask=embeds_cmp_mask,
276
+ ids_cmp_mask=ids_cmp_mask,
277
+ num_img_gen_tokens=num_img_out_tokens,
278
+ max_new_tokens=max_new_tokens,
279
+ dtype=service.dtype,
280
+ device=service.llm_device,
281
+ top_p=top_p,
282
+ )
283
+
284
+ gen_imgs_base64_list = []
285
+ generated_text = output['text']
286
+ generated_text = generated_text.replace(EOI_TOKEN, IMG_FLAG).replace(service.tokenizer.eos_token, '')
287
+
288
+ torch.cuda.empty_cache()
289
+
290
+ if output['has_img_output']:
291
+ # print('loading visual encoder and llm to CPU, and sd to GPU')
292
+ # a = time.time()
293
+ # service.agent = service.agent.cpu()
294
+ # service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
295
+ # print("Loading finished: ", time.time() - a)
296
+
297
+ img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)
298
+
299
+ for img_idx in range(output['num_gen_imgs']):
300
+ img_feat = img_gen_feat[img_idx:img_idx + 1]
301
+ generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
302
+
303
+ if force_polish:
304
+ # service.sd_adapter = service.sd_adapter.cpu()
305
+ # service.vae_pipe = service.vae_pipe.to(service.vit_sd_device, dtype=service.dtype)
306
+
307
+ torch.cuda.empty_cache()
308
+
309
+ service.vae_pipe = service.vae_pipe.to(service.vit_sd_device)
310
+
311
+ init_image = generated_image.resize((1024, 1024))
312
+ prompt = ""
313
+ images = service.vae_pipe(prompt=prompt, image=init_image,
314
+ num_inference_steps=50, guidance_scale=8.0, strength=0.38).images
315
+ generated_image = images[0]
316
+
317
+ image_base64 = encode_image(generated_image)
318
+ gen_imgs_base64_list.append(image_base64)
319
+
320
+ # service.vae_pipe = service.vae_pipe.to("cpu")
321
+ # service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
322
+
323
+ torch.cuda.empty_cache()
324
+
325
+ # print('loading visual encoder and llm to GPU, and sd to CPU')
326
+ # a = time.time()
327
+ # service.sd_adapter = service.sd_adapter.cpu()
328
+ # service.visual_encoder = service.visual_encoder.to(service.vit_sd_device, dtype=service.dtype)
329
+ # service.agent = service.agent.to(service.vit_sd_device, dtype=service.dtype)
330
+ # print("Loading finished: ", time.time() - a)
331
+
332
+ if args.has_bbox:
333
+ bboxes = extract_box(generated_text)
334
+ if bboxes is not None and len(input_images) > 0:
335
+ image_viz = visualize_bbox(input_images[-1], bboxes)
336
+ image_base64 = encode_image(image_viz)
337
+ gen_imgs_base64_list.append(image_base64)
338
+ if '<box_start>' in generated_text:
339
+ generated_text = re.sub(r'\[\[ <box_start>.*?<box_end>.*?\]\]', 'the green bounding box',
340
+ generated_text)
341
+ else:
342
+ generated_text = re.sub(r'<loc-\d+> <loc-\d+> <loc-\d+> <loc-\d+> <box_end> \]\]',
343
+ 'the green bounding box', generated_text)
344
+ generated_text += IMG_FLAG
345
+ print(input_text + generated_text)
346
+ return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg}
347
+
348
+
349
+ def http_bot(dialog_state, input_state, max_new_tokens, max_turns, force_image_gen, force_bbox, force_polish,
350
+ request: gr.Request):
351
+ print('input_state:', input_state)
352
+
353
+ if len(dialog_state.messages) == 0 or dialog_state.messages[-1]['role'] != dialog_state.roles[0] or len(
354
+ dialog_state.messages[-1]['message']['text'].strip(' ?.;!/')) == 0:
355
+ return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
356
+
357
+ if len(dialog_state.messages) > max_turns * 2:
358
+ output_state = init_input_state()
359
+ output_state['text'] = 'Error: History exceeds maximum rounds, please clear history and restart.'
360
+ dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
361
+ input_state = init_input_state()
362
+ return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 3 + (enable_btn,)
363
+
364
+ prompt = dialog_state.get_prompt()
365
+ text = prompt['text']
366
+ max_new_tokens = int(max_new_tokens)
367
+ images = prompt['images']
368
+ force_boi = force_image_gen
369
+ force_bbox = force_bbox
370
+
371
+ results = generate(text, images, max_new_tokens, force_boi, force_bbox, force_polish)
372
+ print('response: ', {'text': results['text'], 'error_msg': results['error_msg']})
373
+
374
+ output_state = init_input_state()
375
+ image_dir = get_conv_image_dir()
376
+ output_state['text'] = results['text']
377
+
378
+ for image_base64 in results['images']:
379
+ if image_base64 == '':
380
+ image_path = ''
381
+ else:
382
+ image = decode_image(image_base64)
383
+ image = image.convert('RGB')
384
+ image_path = get_image_name(image=image, image_dir=image_dir)
385
+ if not os.path.exists(image_path):
386
+ image.save(image_path)
387
+ output_state['images'].append(image_path)
388
+
389
+ dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
390
+
391
+ vote_last_response(dialog_state, 'common', request)
392
+ input_state = init_input_state()
393
+ chatbot = update_error_msg(dialog_state.to_gradio_chatbot(), results['error_msg'])
394
+ return (dialog_state, input_state, chatbot) + (enable_btn,) * 4
395
+
396
+
397
+ IMG_FLAG = '<image>'
398
+ LOGDIR = 'log'
399
+
400
+ logger = build_logger("gradio_seed_x", LOGDIR)
401
+ headers = {"User-Agent": "SEED-X Client"}
402
+
403
+ no_change_btn = gr.Button()
404
+ enable_btn = gr.Button(interactive=True)
405
+ disable_btn = gr.Button(interactive=False)
406
+
407
+ conv_seed_llama = conv_seed_llama2
408
+
409
+
410
+ def get_conv_log_filename():
411
+ t = datetime.datetime.now()
412
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
413
+ return name
414
+
415
+
416
+ def get_conv_image_dir():
417
+ name = os.path.join(LOGDIR, 'images')
418
+ os.makedirs(name, exist_ok=True)
419
+ return name
420
+
421
+
422
+ def get_image_name(image, image_dir=None):
423
+ buffer = io.BytesIO()
424
+ image.save(buffer, format='PNG')
425
+ image_bytes = buffer.getvalue()
426
+ md5 = hashlib.md5(image_bytes).hexdigest()
427
+
428
+ if image_dir is not None:
429
+ image_name = os.path.join(image_dir, md5 + '.png')
430
+ else:
431
+ image_name = md5 + '.png'
432
+
433
+ return image_name
434
+
435
+
436
+ def resize_image_square(image, target_size=448):
437
+ resized_image = image.resize((target_size, target_size))
438
+ return resized_image
439
+
440
+
441
+ def resize_image(image, max_size=512):
442
+ width, height = image.size
443
+ aspect_ratio = float(width) / float(height)
444
+
445
+ if width > height:
446
+ new_width = max_size
447
+ new_height = int(new_width / aspect_ratio)
448
+ else:
449
+ new_height = max_size
450
+ new_width = int(new_height * aspect_ratio)
451
+
452
+ resized_image = image.resize((new_width, new_height))
453
+ return resized_image
454
+
455
+
456
+ def center_crop_image(image, max_aspect_ratio=1.5):
457
+ width, height = image.size
458
+ aspect_ratio = max(width, height) / min(width, height)
459
+
460
+ if aspect_ratio >= max_aspect_ratio:
461
+ if width > height:
462
+ new_width = int(height * max_aspect_ratio)
463
+ left = (width - new_width) // 2
464
+ right = (width + new_width) // 2
465
+ top = 0
466
+ bottom = height
467
+ else:
468
+ new_height = int(width * max_aspect_ratio)
469
+ left = 0
470
+ right = width
471
+ top = (height - new_height) // 2
472
+ bottom = (height + new_height) // 2
473
+
474
+ cropped_image = image.crop((left, top, right, bottom))
475
+ return cropped_image
476
+ else:
477
+ return image
478
+
479
+
480
+ def vote_last_response(state, vote_type, request: gr.Request):
481
+ with open(get_conv_log_filename(), "a") as fout:
482
+ data = {
483
+ "tstamp": round(time.time(), 4),
484
+ "type": vote_type,
485
+ "state": state.dict(),
486
+ "ip": request.client.host,
487
+ }
488
+ fout.write(json.dumps(data) + "\n")
489
+
490
+
491
+ def upvote_last_response(state, request: gr.Request):
492
+ logger.info(f"upvote. ip: {request.client.host}")
493
+ vote_last_response(state, "upvote", request)
494
+ return (disable_btn,) * 2
495
+
496
+
497
+ def downvote_last_response(state, request: gr.Request):
498
+ logger.info(f"downvote. ip: {request.client.host}")
499
+ vote_last_response(state, "downvote", request)
500
+ return (disable_btn,) * 2
501
+
502
+
503
+ def regenerate(dialog_state, request: gr.Request):
504
+ logger.info(f"regenerate. ip: {request.client.host}")
505
+ if dialog_state.messages[-1]['role'] == dialog_state.roles[1]:
506
+ dialog_state.messages.pop()
507
+ return (
508
+ dialog_state,
509
+ dialog_state.to_gradio_chatbot(),
510
+ ) + (disable_btn,) * 4
511
+
512
+
513
+ def clear_history(request: gr.Request):
514
+ logger.info(f"clear_history. ip: {request.client.host}")
515
+ dialog_state = conv_seed_llama.copy()
516
+ input_state = init_input_state()
517
+ return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
518
+
519
+
520
+ def init_input_state():
521
+ return {'images': [], 'text': ''}
522
+
523
+
524
+ def add_text(dialog_state, input_state, text, request: gr.Request):
525
+ logger.info(f"add_text. ip: {request.client.host}.")
526
+ if text is None or len(text) == 0:
527
+ return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
528
+ input_state['text'] += text
529
+
530
+ if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
531
+ dialog_state.messages[-1]['message'] = input_state
532
+ else:
533
+ dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
534
+ print('add_text: ', dialog_state.to_gradio_chatbot())
535
+
536
+ return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
537
+
538
+
539
+ def is_blank(image):
540
+ image_array = np.array(image)
541
+ unique_colors = np.unique(image_array)
542
+ print('unique_colors', len(unique_colors))
543
+ return len(unique_colors) == 1
544
+
545
+
546
+ def add_image(dialog_state, input_state, image, request: gr.Request):
547
+ logger.info(f"add_image. ip: {request.client.host}.")
548
+ if image is None:
549
+ return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
550
+
551
+ image = image.convert('RGB')
552
+
553
+ print('image size:', image.size)
554
+
555
+ image = center_crop_image(image, max_aspect_ratio=10)
556
+
557
+ image_dir = get_conv_image_dir()
558
+ image_path = get_image_name(image=image, image_dir=image_dir)
559
+ if not os.path.exists(image_path):
560
+ image.save(image_path)
561
+ input_state['images'].append(image_path)
562
+ input_state['text'] += IMG_FLAG
563
+
564
+ if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
565
+ dialog_state.messages[-1]['message'] = input_state
566
+ else:
567
+ dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
568
+
569
+ print('add_image:', dialog_state)
570
+
571
+ return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
572
+
573
+
574
+ def update_error_msg(chatbot, error_msg):
575
+ if len(error_msg) > 0:
576
+ info = '\n-------------\nSome errors occurred during response, please clear history and restart.\n' + '\n'.join(
577
+ error_msg)
578
+ chatbot[-1][-1] = chatbot[-1][-1] + info
579
+
580
+ return chatbot
581
+
582
+
583
+ def load_demo(request: gr.Request):
584
+ logger.info(f"load_demo. ip: {request.client.host}")
585
+ dialog_state = conv_seed_llama.copy()
586
+ input_state = init_input_state()
587
+ return dialog_state, input_state
588
+
589
+
590
+ title = ("""
591
+ # SEED-X-I
592
+ [[Paper]](https://arxiv.org/abs/2404.14396) [[Code]](https://github.com/AILab-CVC/SEED-X) [[Faster Demo]](https://arc.tencent.com/en/ai-demos/multimodal)
593
+
594
+ Demo of a general instruction-tuned model SEED-X-I (17B) from the foundation model SEED-X.
595
+ SEED-X-I can follow multimodal instruction (including images with **dynamic resolutions**) and make responses with **images, texts and bounding boxes** in multi-turn conversation.
596
+
597
+ SEED-X-I **does not support image manipulation**. If you want to experience **SEED-X-Edit** for high-precision image editing, please refer to [[Inference Code]](https://github.com/AILab-CVC/SEED-X).
598
+
599
+ If you want to experience the normal model inference speed, you can use [[Faster Demo]](https://arc.tencent.com/en/ai-demos/multimodal) or run [[Inference Code]](https://github.com/AILab-CVC/SEED-X) locally.
600
+
601
+ ## Tips:
602
+ * Check out the conversation examples (at the bottom) for inspiration.
603
+ * You can adjust "Max History Rounds" to try a conversation with up to **three rounds due to insufficient GPU memory**. For more turns, you can download our checkpoints from GitHub and deploy them locally for inference.
604
+ * Our demo supports a mix of images and texts as input. You can freely upload an image or enter text, and then click on "Add Image/Text". You can repeat the former step multiple times, and click on "Submit" for model inference at last.
605
+ * You can click "Force Image Generation" to compel the model to produce images when necessary. For example, our model might struggle to generate images when there is an excessive amount of text-only context.
606
+ * You can click "Force Bounding Box" to compel the model to produce bounding box for object detection.
607
+ * You can click "Force Polishing Generated Image" to compel the model to polish the generated image with image post-processing.
608
+
609
+ * SEED-X was trained with English-only data. It may process with other languages due to the inherent capabilities from LLaMA, but might not stable.
610
+ """)
611
+
612
+ css = """
613
+ img {
614
+ font-family: 'Helvetica';
615
+ font-weight: 300;
616
+ line-height: 2;
617
+ text-align: center;
618
+
619
+ width: auto;
620
+ height: auto;
621
+ display: block;
622
+ position: relative;
623
+ }
624
+ img:before {
625
+ content: " ";
626
+ display: block;
627
+ position: absolute;
628
+ top: -10px;
629
+ left: 0;
630
+ height: calc(100% + 10px);
631
+ width: 100%;
632
+ background-color: rgb(230, 230, 230);
633
+ border: 2px dotted rgb(200, 200, 200);
634
+ border-radius: 5px;
635
+ }
636
+ img:after {
637
+ content: " ";
638
+ display: block;
639
+ font-size: 16px;
640
+ font-style: normal;
641
+ font-family: FontAwesome;
642
+ color: rgb(100, 100, 100);
643
+
644
+ position: absolute;
645
+ top: 5px;
646
+ left: 0;
647
+ width: 100%;
648
+ text-align: center;
649
+ }
650
+ """
651
+
652
+ if __name__ == '__main__':
653
+ examples_mix = [
654
+ ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/bank.png?raw=true',
655
+ 'Can I conntect with an advisor on Sunday?'],
656
+ ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/ground.png?raw=true',
657
+ 'Is there anything in the image that can protect me from catching the flu virus when I go out? Show me the location.'],
658
+ ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/arrow.jpg?raw=true',
659
+ 'What is the object pointed by the red arrow?'],
660
+ ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/shanghai.png?raw=true',
661
+ 'Where was this image taken? Explain your answer.'],
662
+ ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/GPT4.png?raw=true',
663
+ 'How long does it take to make GPT-4 safer?'],
664
+ ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/twitter.png?raw=true',
665
+ 'Please provide a comprehensive description of this image.'],
666
+ ]
667
+ examples_text = [
668
+ ['I want to build a two story cabin in the woods, with many commanding windows. Can you show me a picture?'],
669
+ ['Use your imagination to design a concept image for Artificial General Intelligence (AGI). Show me an image.'],
670
+ [
671
+ 'Can you design an illustration for “The Three-Body Problem” to depict a scene from the novel? Show me a picture.'],
672
+ [
673
+ 'My four year old son loves toy trains. Can you design a fancy birthday cake for him? Please generate a picture.'],
674
+ [
675
+ 'Generate an image of a portrait of young nordic girl, age 25, freckled skin, neck tatoo, blue eyes 35mm lens, photography, ultra details.'],
676
+ ['Generate an impressionist painting of an astronaut in a jungle.']
677
+ ]
678
+ with gr.Blocks(css=css) as demo:
679
+ gr.Markdown(title)
680
+ dialog_state = gr.State()
681
+ input_state = gr.State()
682
+ with gr.Row():
683
+ with gr.Column(scale=3):
684
+ with gr.Row():
685
+ image = gr.Image(type='pil', label='input_image')
686
+ with gr.Row():
687
+ text = gr.Textbox(lines=5,
688
+ show_label=False,
689
+ label='input_text',
690
+ elem_id='textbox',
691
+ placeholder="Enter text and image, and press submit,", container=False)
692
+ with gr.Row():
693
+ add_image_btn = gr.Button("Add Image")
694
+ add_text_btn = gr.Button("Add Text")
695
+
696
+ submit_btn = gr.Button("Submit")
697
+
698
+ with gr.Row():
699
+ max_new_tokens = gr.Slider(minimum=64,
700
+ maximum=1024,
701
+ value=768,
702
+ step=64,
703
+ interactive=True,
704
+ label="Max Output Tokens")
705
+ max_turns = gr.Slider(minimum=1, maximum=3, value=3, step=1, interactive=True,
706
+ label="Max History Rounds")
707
+ force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
708
+ force_bbox = gr.Radio(choices=[True, False], value=False, label='Force Bounding Box')
709
+ force_polish = gr.Radio(choices=[True, False], value=True, label='Force Polishing Generated Image')
710
+
711
+ with gr.Column(scale=7):
712
+ chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-X-I", height=700)
713
+ with gr.Row():
714
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
715
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
716
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
717
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
718
+
719
+ with gr.Row():
720
+ with gr.Column(scale=0.7):
721
+ gr.Examples(examples=examples_mix, label='Input examples', inputs=[image, text], cache_examples=False)
722
+ with gr.Column(scale=0.3):
723
+ gr.Examples(examples=examples_text, label='Input examples', inputs=[text], cache_examples=False)
724
+
725
+ # Register listeners
726
+ btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn]
727
+ upvote_btn.click(upvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
728
+ downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
729
+
730
+ regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
731
+ http_bot, [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox, force_polish],
732
+ [dialog_state, input_state, chatbot] + btn_list)
733
+ add_image_btn.click(add_image, [dialog_state, input_state, image],
734
+ [dialog_state, input_state, image, chatbot] + btn_list)
735
+
736
+ add_text_btn.click(add_text, [dialog_state, input_state, text],
737
+ [dialog_state, input_state, text, chatbot] + btn_list)
738
+
739
+ submit_btn.click(
740
+ add_image, [dialog_state, input_state, image], [dialog_state, input_state, image, chatbot] + btn_list).then(
741
+ add_text, [dialog_state, input_state, text],
742
+ [dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
743
+ http_bot,
744
+ [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox, force_polish],
745
+ [dialog_state, input_state, chatbot] + btn_list)
746
+ clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)
747
+
748
+ demo.load(load_demo, None, [dialog_state, input_state])
749
+
750
+ demo.launch(debug=True)
conversation.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+ import io
6
+ import base64
7
+ import os
8
+ from PIL import Image
9
+ import copy
10
+
11
+ IMG_FLAG = '<image>'
12
+
13
+
14
+ class SeparatorStyle(Enum):
15
+ """Different separator style."""
16
+ SINGLE = auto()
17
+ TWO = auto()
18
+ MPT = auto()
19
+ PLAIN = auto()
20
+ LLAMA_2 = auto()
21
+
22
+
23
+ def decode_image(encoded_image: str) -> Image:
24
+ decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
25
+ buffer = io.BytesIO(decoded_bytes)
26
+ image = Image.open(buffer)
27
+ return image
28
+
29
+
30
+ def encode_image(image: Image.Image, format: str = 'PNG') -> str:
31
+ with io.BytesIO() as buffer:
32
+ image.save(buffer, format=format)
33
+ encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
34
+ return encoded_image
35
+
36
+
37
+ @dataclasses.dataclass
38
+ class Conversation:
39
+ """A class that keeps all conversation history."""
40
+ system: str
41
+ roles: List[str]
42
+ messages: List[dict] # multi-turn -> user & assistant -> {'images': [PIL.Image,], 'text': str}
43
+ offset: int
44
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
45
+ sep: str = "###"
46
+ sep2: str = None
47
+ version: str = "Unknown"
48
+
49
+ skip_next: bool = False
50
+
51
+ def get_prompt(self):
52
+ messages = copy.deepcopy(self.messages)
53
+ if self.sep_style == SeparatorStyle.SINGLE:
54
+ if self.system is None or self.system == '':
55
+ text = ''
56
+ else:
57
+ text = self.system + self.sep
58
+ images = []
59
+ for message in messages:
60
+ text += message['role'] + ": " + message['message']['text'] + self.sep
61
+ for image_path in message['message']['images']:
62
+ image = Image.open(image_path).resize((256, 256))
63
+ image_base64 = encode_image(image)
64
+ images.append(image_base64)
65
+
66
+ text += self.roles[1] + ":"
67
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
68
+ b_token = "[INST] "
69
+ e_token = " [/INST]"
70
+ if self.system is None or self.system == '':
71
+ text = ''
72
+ else:
73
+ text = f"<<SYS>>\n{self.system}\n<</SYS>>\n\n"
74
+ images = []
75
+ for idx, message in enumerate(messages):
76
+ # text += message['role'] + ": " + message['message']['text'] + self.sep
77
+ if idx % 2 == 0:
78
+ text += b_token + message['message']['text'] + e_token + self.sep
79
+ else:
80
+ text += message['message']['text'] + self.sep
81
+
82
+ for image_path in message['message']['images']:
83
+ image = Image.open(image_path)
84
+ image_base64 = encode_image(image)
85
+ images.append(image_base64)
86
+ else:
87
+ raise NotImplementedError
88
+
89
+ return {'text': text, 'images': images}
90
+
91
+ # def update_image_ids(self, images_ids):
92
+ # image_count = 0
93
+ # for message in self.messages:
94
+ # for idx in range(len(message['message']['images_ids'])):
95
+ # if message['message']["images_ids"][idx] is None:
96
+ # message['message']["images_ids"][idx] = images_ids[image_count]
97
+ # image_count += 1
98
+
99
+ # assert len(images_ids) == image_count, print(len(images_ids), image_count)
100
+
101
+ def append_message(self, role, message):
102
+ self.messages.append([role, message])
103
+
104
+ def to_gradio_chatbot(self):
105
+ dialog = []
106
+ for i, single_turn in enumerate(self.messages[self.offset:]):
107
+ single_turn = single_turn['message']
108
+ text_list = single_turn['text'].split(IMG_FLAG)
109
+ assert len(text_list) == len(single_turn['images']) + 1, print(text_list, len(single_turn['images']))
110
+ message = ''
111
+ for image_idx in range(len(single_turn['images'])):
112
+ image_path = single_turn['images'][image_idx]
113
+ image = Image.open(image_path)
114
+ image_base64 = encode_image(image)
115
+ image_str = f'<img src="data:image/png;base64,{image_base64}" alt="user upload image" />'
116
+ message += text_list[image_idx] + image_str
117
+
118
+ # image_path = single_turn['images'][image_idx]
119
+ # if image_path == '':
120
+ # message += text_list[image_idx] + '<corrupt_image>'
121
+ # else:
122
+ # message += text_list[image_idx] + f'![](file={image_path})'
123
+ message += text_list[-1]
124
+
125
+ if i % 2 == 0:
126
+ dialog.append([message, None])
127
+ else:
128
+ dialog[-1][-1] = message
129
+
130
+ return dialog
131
+
132
+ def copy(self):
133
+ return Conversation(system=self.system,
134
+ roles=self.roles,
135
+ messages=copy.deepcopy(self.messages),
136
+ offset=self.offset,
137
+ sep_style=self.sep_style,
138
+ sep=self.sep,
139
+ sep2=self.sep2,
140
+ version=self.version)
141
+
142
+ def dict(self):
143
+ messages = copy.deepcopy(self.messages)
144
+ for message in messages:
145
+ for i in range(len(message['message']['images'])):
146
+ message['message']['images'][i] = os.path.basename(message['message']['images'][i])
147
+ return {
148
+ "system": self.system,
149
+ "roles": self.roles,
150
+ "messages": messages,
151
+ "offset": self.offset,
152
+ "sep": self.sep,
153
+ "sep2": self.sep2,
154
+ }
155
+
156
+
157
+ conv_seed_vicuna = Conversation(
158
+ system="",
159
+ roles=("USER", "ASSISTANT"),
160
+ version="v2",
161
+ messages=[],
162
+ offset=0,
163
+ sep_style=SeparatorStyle.SINGLE,
164
+ sep='\n',
165
+ )
166
+
167
+ conv_seed_vicuna_system = Conversation(
168
+ system="A chat between a curious user and an artificial intelligence assistant. ",
169
+ roles=("USER", "ASSISTANT"),
170
+ version="v2",
171
+ messages=[],
172
+ offset=0,
173
+ sep_style=SeparatorStyle.SINGLE,
174
+ sep='\n',
175
+ )
176
+
177
+ conv_seed_llama2 = Conversation(
178
+ system="",
179
+ roles=("[INST]", "[/INST]"),
180
+ version="v2",
181
+ messages=[],
182
+ offset=0,
183
+ sep_style=SeparatorStyle.LLAMA_2,
184
+ sep='\n',
185
+ )
src/data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .datapipes import TarArchiveLoader, JsonlParserIterDataPipe
src/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (226 Bytes). View file
 
src/data/__pycache__/datapipes.cpython-38.pyc ADDED
Binary file (2.86 kB). View file
 
src/data/dataloader_utils.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import time
9
+ import random
10
+ import torch
11
+ # from lavis.datasets.data_utils import move_to_cuda
12
+ from torch.utils.data import DataLoader
13
+
14
+
15
+ class MultiIterLoader:
16
+ """
17
+ A simple wrapper for iterating over multiple iterators.
18
+
19
+ Args:
20
+ loaders (List[Loader]): List of Iterator loaders.
21
+ ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
22
+ """
23
+
24
+ def __init__(self, loaders, ratios=None):
25
+ # assert all loaders has __next__ method
26
+ for loader in loaders:
27
+ assert hasattr(loader, "__next__"), "Loader {} has no __next__ method.".format(loader)
28
+
29
+ if ratios is None:
30
+ ratios = [1.0] * len(loaders)
31
+ else:
32
+ assert len(ratios) == len(loaders)
33
+ ratios = [float(ratio) / sum(ratios) for ratio in ratios]
34
+
35
+ self.loaders = loaders
36
+ self.ratios = ratios
37
+
38
+ def __next__(self):
39
+ # random sample from each loader by ratio
40
+ loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
41
+ return next(self.loaders[loader_idx])
42
+
43
+ def __iter__(self):
44
+ return self
45
+
46
+
47
+ class PrefetchLoader(object):
48
+ """
49
+ Modified from https://github.com/ChenRocks/UNITER.
50
+
51
+ overlap compute and cuda data transfer
52
+ (copied and then modified from nvidia apex)
53
+ """
54
+
55
+ def __init__(self, loader):
56
+ self.loader = loader
57
+ self.stream = torch.cuda.Stream()
58
+
59
+ def __iter__(self):
60
+ loader_it = iter(self.loader)
61
+ self.preload(loader_it)
62
+ batch = self.next(loader_it)
63
+ while batch is not None:
64
+ is_tuple = isinstance(batch, tuple)
65
+ if is_tuple:
66
+ task, batch = batch
67
+
68
+ if is_tuple:
69
+ yield task, batch
70
+ else:
71
+ yield batch
72
+ batch = self.next(loader_it)
73
+
74
+ def __len__(self):
75
+ return len(self.loader)
76
+
77
+ def preload(self, it):
78
+ try:
79
+ self.batch = next(it)
80
+ except StopIteration:
81
+ self.batch = None
82
+ return
83
+ # if record_stream() doesn't work, another option is to make sure
84
+ # device inputs are created on the main stream.
85
+ # self.next_input_gpu = torch.empty_like(self.next_input,
86
+ # device='cuda')
87
+ # self.next_target_gpu = torch.empty_like(self.next_target,
88
+ # device='cuda')
89
+ # Need to make sure the memory allocated for next_* is not still in use
90
+ # by the main stream at the time we start copying to next_*:
91
+ # self.stream.wait_stream(torch.cuda.current_stream())
92
+ # with torch.cuda.stream(self.stream):
93
+ # self.batch = move_to_cuda(self.batch)
94
+ # more code for the alternative if record_stream() doesn't work:
95
+ # copy_ will record the use of the pinned source tensor in this
96
+ # side stream.
97
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
98
+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
99
+ # self.next_input = self.next_input_gpu
100
+ # self.next_target = self.next_target_gpu
101
+
102
+ def next(self, it):
103
+ torch.cuda.current_stream().wait_stream(self.stream)
104
+ batch = self.batch
105
+ if batch is not None:
106
+ record_cuda_stream(batch)
107
+ self.preload(it)
108
+ return batch
109
+
110
+ def __getattr__(self, name):
111
+ method = self.loader.__getattribute__(name)
112
+ return method
113
+
114
+
115
+ def record_cuda_stream(batch):
116
+ if isinstance(batch, torch.Tensor):
117
+ batch.record_stream(torch.cuda.current_stream())
118
+ elif isinstance(batch, list) or isinstance(batch, tuple):
119
+ for t in batch:
120
+ record_cuda_stream(t)
121
+ elif isinstance(batch, dict):
122
+ for t in batch.values():
123
+ record_cuda_stream(t)
124
+ else:
125
+ pass
126
+
127
+
128
+ class IterLoader:
129
+ """
130
+ A wrapper to convert DataLoader as an infinite iterator.
131
+
132
+ Modified from:
133
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
134
+ """
135
+
136
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
137
+ self._dataloader = dataloader
138
+ self.iter_loader = iter(self._dataloader)
139
+ self._use_distributed = use_distributed
140
+ self._epoch = 0
141
+
142
+ @property
143
+ def epoch(self) -> int:
144
+ return self._epoch
145
+
146
+ def __next__(self):
147
+ try:
148
+ data = next(self.iter_loader)
149
+ except StopIteration:
150
+ self._epoch += 1
151
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
152
+ self._dataloader.sampler.set_epoch(self._epoch)
153
+ time.sleep(2) # Prevent possible deadlock during epoch transition
154
+ self.iter_loader = iter(self._dataloader)
155
+ data = next(self.iter_loader)
156
+
157
+ return data
158
+
159
+ def __iter__(self):
160
+ return self
161
+
162
+ def __len__(self):
163
+ return len(self._dataloader)
src/data/datapipes.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchdata.datapipes as dp
2
+ import os
3
+ import tarfile
4
+ from torchdata.datapipes.iter import TarArchiveLoader
5
+ from typing import cast, IO, Iterable, Iterator, Optional, Tuple, Dict
6
+ from torchdata.datapipes import functional_datapipe
7
+ from io import BufferedIOBase
8
+ from torchdata.datapipes.utils import StreamWrapper
9
+ from torchdata.datapipes.utils.common import validate_pathname_binary_tuple
10
+ import warnings
11
+ from torchdata.datapipes.iter import IterDataPipe
12
+ import json
13
+
14
+
15
+ @functional_datapipe("load_from_tar_wo_exception")
16
+ class TarArchiveLoaderWoException(TarArchiveLoader):
17
+
18
+ def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
19
+ for data in self.datapipe:
20
+ validate_pathname_binary_tuple(data)
21
+ pathname, data_stream = data
22
+ try:
23
+ if isinstance(data_stream, StreamWrapper) and isinstance(data_stream.file_obj, tarfile.TarFile):
24
+ tar = data_stream.file_obj
25
+ else:
26
+ reading_mode = (self.mode if hasattr(data_stream, "seekable") and data_stream.seekable() else
27
+ self.mode.replace(":", "|"))
28
+ # typing.cast is used here to silence mypy's type checker
29
+ tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=reading_mode)
30
+ for tarinfo in tar:
31
+ if not tarinfo.isfile():
32
+ continue
33
+ extracted_fobj = tar.extractfile(tarinfo)
34
+ if extracted_fobj is None:
35
+ warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}")
36
+ raise tarfile.ExtractError
37
+ inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
38
+ yield inner_pathname, StreamWrapper(extracted_fobj, data_stream,
39
+ name=inner_pathname) # type: ignore[misc]
40
+ except Exception as e:
41
+ warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!")
42
+ # raise e
43
+ finally:
44
+ if isinstance(data_stream, StreamWrapper):
45
+ data_stream.autoclose()
46
+
47
+
48
+ @functional_datapipe("parse_jsonl_files")
49
+ class JsonlParserIterDataPipe(IterDataPipe[Tuple[str, Dict]]):
50
+
51
+ def __init__(self, source_datapipe: IterDataPipe[Tuple[str, IO]], **kwargs) -> None:
52
+ self.source_datapipe: IterDataPipe[Tuple[str, IO]] = source_datapipe
53
+ self.kwargs = kwargs
54
+
55
+ def __iter__(self) -> Iterator[Tuple[str, Dict]]:
56
+ for file_name, stream in self.source_datapipe:
57
+ for idx, line in enumerate(stream):
58
+ if line.strip() != '':
59
+ try:
60
+ yield f'{file_name}_line{idx}', json.loads(line)
61
+ except Exception as e:
62
+ warnings.warn(f"Error occured when parsing string to json due to: {e} abort!")
src/data/story_telling.py ADDED
@@ -0,0 +1,634 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchdata.datapipes as dp
2
+ import json
3
+ from PIL import Image
4
+ import functools
5
+ import numpy as np
6
+ import torch
7
+ import pickle
8
+ import os
9
+ import cv2
10
+ import random
11
+ from torchvision import transforms
12
+ from braceexpand import braceexpand
13
+ import hydra
14
+ from random import choice
15
+ import tarfile
16
+ from torchdata.datapipes.iter import TarArchiveLoader
17
+ from typing import cast, IO, Iterable, Iterator, Optional, Tuple, Dict
18
+ from torchdata.datapipes import functional_datapipe
19
+ from io import BufferedIOBase
20
+ from torchdata.datapipes.utils import StreamWrapper
21
+ from torchdata.datapipes.utils.common import validate_pathname_binary_tuple
22
+ import warnings
23
+ from torchdata.datapipes.iter import IterDataPipe
24
+
25
+ import pyrootutils
26
+
27
+ pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
28
+
29
+ BOI_TOKEN = '<img>'
30
+ EOI_TOKEN = '</img>'
31
+ IMG_TOKEN = '<img_{:05d}>'
32
+
33
+ gen_prompt = [
34
+ "Please show me a picture of ",
35
+ "Please design an image of ",
36
+ "Please produce a photo of ",
37
+ "Please generate an image of ",
38
+ "Please draw a painting of ",
39
+ "I'd like to see a drawing of ",
40
+ "I'd love to see an illustration of ",
41
+ "I'd like to view an image of ",
42
+ "I want to see a picture of ",
43
+ "I would like to see a photo of ",
44
+ "Show me a photo of ",
45
+ "Generate a picture of ",
46
+ "Show me a photograph of ",
47
+ "Generate an image of ",
48
+ "Generate an image: ",
49
+ "Generate a picture: ",
50
+ "Generate a painting: ",
51
+ "Generate a photograph: ",
52
+ "Show me a photograph: ",
53
+ "Draw a picture: ",
54
+ "Draw a painting: ",
55
+ "Draw an image: ",
56
+ "Can you make an image of ",
57
+ "Can you draw a painting of ",
58
+ "Can you produce a picture of ",
59
+ "Can you generate a photo of ",
60
+ "Can you depict a picture of ",
61
+ "Can you show me an illustration of ",
62
+ ]
63
+
64
+ gen_prompt_response = [
65
+ "Here is a picture.",
66
+ "I have designed an image.",
67
+ "Here is a photo.",
68
+ "I have generated an image.",
69
+ "Here's a painting.",
70
+ "Here's a drawing.",
71
+ "Enjoy this illustration.",
72
+ "Take a look at this image.",
73
+ "Here is a picture.",
74
+ "I have created a photo.",
75
+ "Enjoy this photo.",
76
+ "I have generated a picture.",
77
+ "Here is a photograph.",
78
+ "Here's an image.",
79
+ "Certainly, here's an image.",
80
+ "Absolutely, here is a painting.",
81
+ "Sure, here is a picture.",
82
+ "Of course, here is a photo.",
83
+ "Certainly, please enjoy this picture.",
84
+ "Sure, please enjoy this illustration.",
85
+ "",
86
+ ]
87
+
88
+ jdb_filter_vocab = ['watermark', 'watermark,', 'chaos 100', 'chaos 100,']
89
+
90
+
91
+ def filter_data_with_image_ids(item):
92
+ if ('images' not in item):
93
+ # print(item['__key__'])
94
+ # print('filtered because no images')
95
+ return False
96
+ elif 'input_ids' not in item:
97
+ return False
98
+ else:
99
+ return True
100
+
101
+
102
+ def calculate_new_dimensions(height, width, target_size):
103
+ if height < width:
104
+ new_height = target_size
105
+ new_width = int(width * (target_size / height))
106
+ else:
107
+ new_width = target_size
108
+ new_height = int(height * (target_size / width))
109
+ return new_height, new_width
110
+
111
+
112
+ def unwarp_data(item):
113
+ unwarpped = {}
114
+ for key, value in item.items():
115
+ if isinstance(value, dict):
116
+ unwarpped.update(value)
117
+ elif value is not None:
118
+ unwarpped[key] = value
119
+ if 'metadata' not in unwarpped:
120
+ unwarpped['metadata'] = '{}'
121
+ # if '__key__' in unwarpped:
122
+ # unwarpped['__key__'] = unwarpped['__key__'].split('/')[-1]
123
+ return unwarpped
124
+
125
+
126
+ # def filter_data_with_similarity(item, similarity_thr=0.2, min_resolution=180, min_aspect_ratio=0.666):
127
+ def filter_data_with_similarity(item, similarity_thr=0.2, assure_text=True):
128
+ if ('images' not in item):
129
+ # print(item['__key__'])
130
+ # print('filtered because no images')
131
+ return False
132
+ elif (not item.get('filter_flag', True)):
133
+ # print(item['__key__'])
134
+ # print('filtered because filter flag.')
135
+ return False
136
+ elif assure_text and ('text' not in item):
137
+ # print(item['__key__'])
138
+ # print('filtered because assure_text')
139
+ return False
140
+ else:
141
+ metadata = json.loads(item['metadata'])
142
+
143
+ if 'all_similarities' in metadata:
144
+ similarity = max(metadata['all_similarities'])
145
+ elif 'similarity' in metadata:
146
+ similarity = metadata['similarity']
147
+ elif 'score' in metadata:
148
+ similarity = metadata['score']
149
+ elif 'SCORE' in metadata:
150
+ similarity = metadata['SCORE']
151
+ else:
152
+ similarity = None
153
+
154
+ if similarity is not None:
155
+ if similarity < similarity_thr:
156
+ # print(item['__key__'])
157
+ # print('filtered because similarity')
158
+ return False
159
+
160
+ return True
161
+
162
+
163
+ def single_turn_edit_collate(batch):
164
+ results = {}
165
+ keys = batch[0].keys()
166
+
167
+ for key in keys:
168
+ cur = [batch[i][key] for i in range(len(batch)) if batch[i][key] is not None]
169
+ if len(cur) == 0:
170
+ results[key] = None
171
+ elif isinstance(cur[0], torch.Tensor):
172
+ if key in ['embeds_gen_mask', 'embeds_cmp_mask', 'images']:
173
+ results[key] = torch.cat(cur, dim=0)
174
+ else:
175
+ results[key] = torch.stack(cur, dim=0)
176
+ else:
177
+ results[key] = cur
178
+
179
+ return results
180
+
181
+
182
+ def decode_t2i_data(item,
183
+ image_dir,
184
+ tokenizer,
185
+ image_transform=None,
186
+ sd_image_transform=None,
187
+ max_length=128,
188
+ min_resolution=400,
189
+ instruction_prompt='[INST] {instruction} [/INST]\n',
190
+ turn_sep='\n',
191
+ system_message='',
192
+ min_aspect_ratio=0.666,
193
+ num_img_in_tokens=64,
194
+ num_img_out_tokens=64):
195
+ key, value = item
196
+
197
+ if 'image' not in value or 'caption' not in value:
198
+ return {}
199
+
200
+ image_path = os.path.join(image_dir, value["image"])
201
+
202
+ try:
203
+ image = Image.open(image_path).convert('RGB')
204
+
205
+ width, height = image.size
206
+
207
+ aspect_ratio = height / width
208
+ if height < min_resolution or width < min_resolution:
209
+ print(f'filtered because resolution: ({width},{height})')
210
+ return {}
211
+ if aspect_ratio < min_aspect_ratio or aspect_ratio > 1 / min_aspect_ratio:
212
+ print(f'filtered because aspect ratio: ({width},{height})')
213
+ return {}
214
+ ### SD related
215
+
216
+ image_data = {}
217
+
218
+ if sd_image_transform is not None:
219
+ # image_data['original_sizes'] = torch.tensor([height, width])
220
+ sd_image_tensor = sd_image_transform(image)
221
+ target_size = sd_image_tensor.shape[-2]
222
+ target_width, target_height = calculate_new_dimensions(height=height, width=width, target_size=target_size)
223
+ y1 = max(0, int(round((target_height - target_size) / 2.0)))
224
+ x1 = max(0, int(round((target_width - target_size) / 2.0)))
225
+ # image_data['crop_top_lefts'] = torch.tensor([y1, x1])
226
+ image_data['time_ids'] = torch.tensor([height, width, y1, x1, target_size, target_size])
227
+
228
+ image_data['sd_images'] = sd_image_tensor
229
+
230
+ if image_transform is not None:
231
+ image = image_transform(image)
232
+
233
+ except Exception as e:
234
+ print('Error while decode image: ', e)
235
+ return {}
236
+
237
+ input_ids = []
238
+ labels = []
239
+ input_text = ''
240
+
241
+ if system_message != '':
242
+ if not system_message.endswith('\n'):
243
+ system_message += '\n'
244
+ input_text += system_message
245
+ item_ids = tokenizer.encode(system_message, add_special_tokens=False)
246
+ item_labels = [-100] * len(item_ids)
247
+ input_ids.extend(item_ids)
248
+ labels.extend(item_labels)
249
+
250
+ caption = value["caption"]
251
+
252
+ image_cmp_tokens = BOI_TOKEN + ''.join(
253
+ [IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
254
+
255
+ image_gen_tokens = BOI_TOKEN + ''.join(
256
+ [IMG_TOKEN.format(int(item)) for item in range(num_img_out_tokens)]) + EOI_TOKEN
257
+
258
+ instruction = instruction_prompt.format_map({'instruction': caption})
259
+
260
+ response = image_gen_tokens
261
+ images = torch.stack([image], dim=0)
262
+ # print(instruction)
263
+
264
+ item_ids = tokenizer.encode(instruction, add_special_tokens=False)
265
+ item_labels = [-100] * len(item_ids)
266
+ input_text += instruction
267
+ input_ids.extend(item_ids)
268
+ labels.extend(item_labels)
269
+
270
+ item_ids = tokenizer.encode(response, add_special_tokens=False)
271
+ item_labels = item_ids
272
+ input_text += response
273
+ input_ids.extend(item_ids)
274
+ labels.extend(item_labels)
275
+
276
+ input_ids = [tokenizer.bos_token_id] + input_ids + [tokenizer.eos_token_id]
277
+ attention_mask = [1] * len(input_ids)
278
+ labels = [-100] + labels + [tokenizer.eos_token_id]
279
+
280
+ boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
281
+ eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
282
+ ids_cmp_mask = [False] * len(input_ids)
283
+ ids_gen_mask = [False] * len(input_ids)
284
+
285
+ embeds_cmp_mask = [False]
286
+ embeds_gen_mask = [True]
287
+
288
+ # print(len(input_ids))
289
+ if len(input_ids) >= max_length:
290
+ # input_ids = input_ids[:max_length]
291
+ # attention_mask = attention_mask[:max_length]
292
+ # labels = labels[:max_length]
293
+ # ids_cmp_mask = ids_cmp_mask[:max_length]
294
+ # ids_gen_mask = ids_gen_mask[:max_length]
295
+ # print('An edit sample has been removed because of max length. input_text: ', input_text)
296
+ return {}
297
+ else:
298
+ padding_length = max_length - len(input_ids)
299
+ input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
300
+ attention_mask = attention_mask + [0] * padding_length
301
+ labels = labels + [-100] * padding_length
302
+ ids_cmp_mask = ids_cmp_mask + [False] * padding_length
303
+ ids_gen_mask = ids_gen_mask + [False] * padding_length
304
+
305
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
306
+ attention_mask = torch.tensor(attention_mask, dtype=torch.long)
307
+ labels = torch.tensor(labels, dtype=torch.long)
308
+ ids_cmp_mask = torch.tensor(ids_cmp_mask, dtype=torch.bool)
309
+ ids_gen_mask = torch.tensor(ids_gen_mask, dtype=torch.bool)
310
+ embeds_cmp_mask = torch.tensor(embeds_cmp_mask) if embeds_cmp_mask is not None else None
311
+ embeds_gen_mask = torch.tensor(embeds_gen_mask) if embeds_gen_mask is not None else None
312
+
313
+ boi_idx = torch.where(input_ids == boi_token_id)[0].tolist()
314
+ eoi_idx = torch.where(input_ids == eoi_token_id)[0].tolist()
315
+
316
+ ids_gen_mask[boi_idx[0] + 1:eoi_idx[0]] = True
317
+ labels[boi_idx[0] + 1:eoi_idx[0] + 1] = -100
318
+
319
+ ret = {
320
+ 'input_ids': input_ids,
321
+ 'attention_mask': attention_mask,
322
+ 'labels': labels,
323
+ 'ids_gen_mask': ids_gen_mask,
324
+ 'ids_cmp_mask': ids_cmp_mask,
325
+ 'embeds_gen_mask': embeds_gen_mask,
326
+ 'embeds_cmp_mask': embeds_cmp_mask,
327
+ 'images': images,
328
+ 'text': input_text,
329
+ }
330
+
331
+ ret.update(image_data)
332
+
333
+ return ret
334
+
335
+
336
+ def build_t2i_datapipe(data_dir,
337
+ image_dir,
338
+ tokenizer=None,
339
+ max_length=77,
340
+ batch_size=None,
341
+ min_resolution=180,
342
+ image_transform=None,
343
+ sd_image_transform=None,
344
+ instruction_prompt='[INST] {instruction} [INST]\n',
345
+ turn_sep='\n',
346
+ system_message='',
347
+ min_aspect_ratio=0.666,
348
+ num_img_in_tokens=64,
349
+ num_img_out_tokens=64,
350
+ cycle_count=None):
351
+ decode_partial = functools.partial(decode_t2i_data,
352
+ image_dir=image_dir,
353
+ tokenizer=tokenizer,
354
+ image_transform=image_transform,
355
+ sd_image_transform=sd_image_transform,
356
+ max_length=max_length,
357
+ instruction_prompt=instruction_prompt,
358
+ turn_sep=turn_sep,
359
+ system_message=system_message,
360
+ min_resolution=min_resolution,
361
+ min_aspect_ratio=min_aspect_ratio,
362
+ num_img_in_tokens=num_img_in_tokens,
363
+ num_img_out_tokens=num_img_out_tokens)
364
+
365
+ filter_partial = functools.partial(filter_data_with_image_ids)
366
+
367
+ if isinstance(data_dir, str):
368
+ data_dir = list(braceexpand(data_dir))
369
+
370
+ datapipe = dp.iter.FileLister(root=data_dir, masks='*.jsonl', recursive=True)
371
+ datapipe = datapipe.shuffle()
372
+ datapipe = datapipe.cycle(count=cycle_count)
373
+ datapipe = datapipe.shuffle()
374
+ # datapipe = dp.iter.FileLister(root=data_dir, masks='0000000.tar', recursive=True)
375
+ datapipe = datapipe.sharding_filter()
376
+ # datapipe = datapipe.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
377
+
378
+ datapipe = datapipe.open_files(mode='r')
379
+ datapipe = datapipe.parse_jsonl_files()
380
+ datapipe = datapipe.map(decode_partial)
381
+ datapipe = datapipe.filter(filter_partial)
382
+
383
+ # datapipe = datapipe.shuffle(buffer_size=1024)
384
+ if batch_size is not None:
385
+ datapipe = datapipe.batch(batch_size)
386
+ datapipe = datapipe.collate(single_turn_edit_collate)
387
+ return datapipe
388
+
389
+
390
+ def decode_long_story_data(item,
391
+ image_dir,
392
+ tokenizer,
393
+ story_len,
394
+ image_transform=None,
395
+ sd_image_transform=None,
396
+ max_length=128,
397
+ min_resolution=400,
398
+ instruction_prompt='{instruction}',
399
+ turn_sep='\n',
400
+ system_message='',
401
+ min_aspect_ratio=0.666,
402
+ num_img_in_tokens=64,
403
+ num_img_out_tokens=64, ):
404
+ key, value = item
405
+ if 'images' not in value or 'captions' not in value:
406
+ return {}
407
+
408
+ image_paths = [os.path.join(image_dir, image_path) for image_path in value["images"]]
409
+ # assert len(image_paths) == story_len
410
+ story_len = len(image_paths)
411
+ num_image_given = random.randint(0, story_len - 2)
412
+
413
+ try:
414
+ images = []
415
+ for image_path in image_paths:
416
+ image = Image.open(image_path).convert('RGB')
417
+ images.append(image)
418
+ width, height = image.size
419
+
420
+ aspect_ratio = height / width
421
+ if height < min_resolution or width < min_resolution:
422
+ print(f'filtered because resolution: ({width},{height})')
423
+ return {}
424
+ if aspect_ratio < min_aspect_ratio or aspect_ratio > 1 / min_aspect_ratio:
425
+ print(f'filtered because aspect ratio: ({width},{height})')
426
+ return {}
427
+
428
+ image_data = {}
429
+ sd_image = images[num_image_given + 1]
430
+ if sd_image_transform is not None:
431
+ # image_data['original_sizes'] = torch.tensor([height, width])
432
+ sd_image_tensor = sd_image_transform(sd_image)
433
+ target_size = sd_image_tensor.shape[-2]
434
+ target_width, target_height = calculate_new_dimensions(height=height, width=width, target_size=target_size)
435
+ y1 = max(0, int(round((target_height - target_size) / 2.0)))
436
+ x1 = max(0, int(round((target_width - target_size) / 2.0)))
437
+ # image_data['crop_top_lefts'] = torch.tensor([y1, x1])
438
+ image_data['time_ids'] = torch.tensor([height, width, y1, x1, target_size, target_size])
439
+
440
+ image_data['sd_images'] = sd_image_tensor
441
+
442
+ if image_transform is not None:
443
+ for i in range(len(images)):
444
+ images[i] = image_transform(images[i])
445
+ images = torch.stack(images, dim=0)
446
+
447
+ except Exception as e:
448
+ print('Error while decode image: ', e)
449
+ return {}
450
+
451
+ input_ids = []
452
+ labels = []
453
+ input_text = ''
454
+
455
+ if system_message != '':
456
+ if not system_message.endswith('\n'):
457
+ system_message += '\n'
458
+ input_text += system_message
459
+ item_ids = tokenizer.encode(system_message, add_special_tokens=False)
460
+ item_labels = [-100] * len(item_ids)
461
+ input_ids.extend(item_ids)
462
+ labels.extend(item_labels)
463
+
464
+ captions_all = []
465
+ for i in range(story_len):
466
+ caption = value["captions"][i]
467
+ captions_all.append(caption)
468
+
469
+ image_cmp_tokens = BOI_TOKEN + ''.join(
470
+ [IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
471
+
472
+ image_gen_tokens = BOI_TOKEN + ''.join(
473
+ [IMG_TOKEN.format(int(item)) for item in range(num_img_out_tokens)]) + EOI_TOKEN
474
+
475
+ instruction = instruction_prompt.format_map({'instruction': captions_all[0] + image_cmp_tokens})
476
+ for i in range(num_image_given):
477
+ instruction = instruction + "[INST]" + captions_all[i + 1] + image_cmp_tokens
478
+
479
+ response = "[INST]" + captions_all[num_image_given + 1] + image_gen_tokens
480
+
481
+ images = images[:num_image_given + 2]
482
+ # print(instruction)
483
+
484
+ item_ids = tokenizer.encode(instruction, add_special_tokens=False)
485
+ item_labels = [-100] * len(item_ids)
486
+ input_text += instruction
487
+ input_ids.extend(item_ids)
488
+ labels.extend(item_labels)
489
+
490
+ item_ids = tokenizer.encode(response, add_special_tokens=False)
491
+ item_labels = item_ids
492
+ input_text += response
493
+ input_ids.extend(item_ids)
494
+ labels.extend(item_labels)
495
+
496
+ input_ids = [tokenizer.bos_token_id] + input_ids + [tokenizer.eos_token_id]
497
+ attention_mask = [1] * len(input_ids)
498
+ labels = [-100] + labels + [tokenizer.eos_token_id]
499
+
500
+ boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
501
+ eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
502
+ ids_cmp_mask = [False] * len(input_ids)
503
+ ids_gen_mask = [False] * len(input_ids)
504
+
505
+ embeds_cmp_mask = [True] + [True] * num_image_given + [False]
506
+ embeds_gen_mask = [False] + [False] * num_image_given + [True]
507
+
508
+ # print(len(input_ids))
509
+ if len(input_ids) >= max_length:
510
+ # input_ids = input_ids[:max_length]
511
+ # attention_mask = attention_mask[:max_length]
512
+ # labels = labels[:max_length]
513
+ # ids_cmp_mask = ids_cmp_mask[:max_length]
514
+ # ids_gen_mask = ids_gen_mask[:max_length]
515
+ # print('An edit sample has been removed because of max length. input_text: ', input_text)
516
+ return {}
517
+ else:
518
+ padding_length = max_length - len(input_ids)
519
+ input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
520
+ attention_mask = attention_mask + [0] * padding_length
521
+ labels = labels + [-100] * padding_length
522
+ ids_cmp_mask = ids_cmp_mask + [False] * padding_length
523
+ ids_gen_mask = ids_gen_mask + [False] * padding_length
524
+
525
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
526
+ attention_mask = torch.tensor(attention_mask, dtype=torch.long)
527
+ labels = torch.tensor(labels, dtype=torch.long)
528
+ ids_cmp_mask = torch.tensor(ids_cmp_mask, dtype=torch.bool)
529
+ ids_gen_mask = torch.tensor(ids_gen_mask, dtype=torch.bool)
530
+ embeds_cmp_mask = torch.tensor(embeds_cmp_mask) if embeds_cmp_mask is not None else None
531
+ embeds_gen_mask = torch.tensor(embeds_gen_mask) if embeds_gen_mask is not None else None
532
+
533
+ boi_idx = torch.where(input_ids == boi_token_id)[0].tolist()
534
+ eoi_idx = torch.where(input_ids == eoi_token_id)[0].tolist()
535
+
536
+ ids_cmp_mask[boi_idx[0] + 1:eoi_idx[0]] = True
537
+ for i in range(num_image_given):
538
+ ids_cmp_mask[boi_idx[i + 1] + 1:eoi_idx[i + 1]] = True
539
+
540
+ ids_gen_mask[boi_idx[-1] + 1:eoi_idx[-1]] = True
541
+ labels[boi_idx[-1] + 1:eoi_idx[-1] + 1] = -100
542
+
543
+ ret = {
544
+ 'input_ids': input_ids,
545
+ 'attention_mask': attention_mask,
546
+ 'labels': labels,
547
+ 'ids_gen_mask': ids_gen_mask,
548
+ 'ids_cmp_mask': ids_cmp_mask,
549
+ 'embeds_gen_mask': embeds_gen_mask,
550
+ 'embeds_cmp_mask': embeds_cmp_mask,
551
+ 'images': images,
552
+ 'text': input_text,
553
+ }
554
+
555
+ ret.update(image_data)
556
+
557
+ return ret
558
+
559
+
560
+ def build_long_story_datapipe(data_dir,
561
+ image_dir,
562
+ tokenizer=None,
563
+ story_len=30,
564
+ max_length=77,
565
+ batch_size=None,
566
+ min_resolution=180,
567
+ image_transform=None,
568
+ sd_image_transform=None,
569
+ instruction_prompt='{instruction}',
570
+ turn_sep='\n',
571
+ system_message='',
572
+ min_aspect_ratio=0.666,
573
+ num_img_in_tokens=64,
574
+ num_img_out_tokens=64,
575
+ cycle_count=None):
576
+ decode_partial = functools.partial(decode_long_story_data,
577
+ image_dir=image_dir,
578
+ tokenizer=tokenizer,
579
+ story_len=story_len,
580
+ image_transform=image_transform,
581
+ sd_image_transform=sd_image_transform,
582
+ max_length=max_length,
583
+ instruction_prompt=instruction_prompt,
584
+ turn_sep=turn_sep,
585
+ system_message=system_message,
586
+ min_resolution=min_resolution,
587
+ min_aspect_ratio=min_aspect_ratio,
588
+ num_img_in_tokens=num_img_in_tokens,
589
+ num_img_out_tokens=num_img_out_tokens)
590
+
591
+ filter_partial = functools.partial(filter_data_with_image_ids)
592
+
593
+ if isinstance(data_dir, str):
594
+ data_dir = list(braceexpand(data_dir))
595
+
596
+ datapipe = dp.iter.FileLister(root=data_dir, masks='*.jsonl', recursive=True)
597
+ datapipe = datapipe.shuffle()
598
+ datapipe = datapipe.cycle(count=cycle_count)
599
+ datapipe = datapipe.shuffle()
600
+ # datapipe = dp.iter.FileLister(root=data_dir, masks='0000000.tar', recursive=True)
601
+ datapipe = datapipe.sharding_filter()
602
+ # datapipe = datapipe.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
603
+
604
+ datapipe = datapipe.open_files(mode='r')
605
+ datapipe = datapipe.parse_jsonl_files()
606
+ datapipe = datapipe.map(decode_partial)
607
+ datapipe = datapipe.filter(filter_partial)
608
+
609
+ # datapipe = datapipe.shuffle(buffer_size=1024)
610
+ if batch_size is not None:
611
+ datapipe = datapipe.batch(batch_size)
612
+ datapipe = datapipe.collate(single_turn_edit_collate)
613
+ return datapipe
614
+
615
+
616
+ def build_multi_datapipes(datapipes, tokenizer=None, image_transform=None, sd_image_transform=None,
617
+ sample_weights=None):
618
+ # assert concat_type in ['concat', 'mux_longest', 'sample']
619
+ if sample_weights is None:
620
+ sample_weights = [1] * len(datapipes)
621
+ else:
622
+ assert len(sample_weights) == len(datapipes)
623
+
624
+ datapipes = [
625
+ hydra.utils.instantiate(datapipe, tokenizer=tokenizer, image_transform=image_transform,
626
+ sd_image_transform=sd_image_transform) for datapipe in datapipes
627
+ ]
628
+
629
+ datasets_to_weights_dict = {}
630
+ for dataset, sample_weight in zip(datapipes, sample_weights):
631
+ datasets_to_weights_dict[dataset] = sample_weight
632
+ datapipe = dp.iter.SampleMultiplexer(datasets_to_weights_dict)
633
+
634
+ return datapipe
src/eval/gpt_comparative_eval.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from openai import OpenAI
3
+ import ast
4
+ import time
5
+ import os
6
+ import base64
7
+ # from PIL import Image
8
+ import io
9
+
10
+ client = OpenAI(
11
+ base_url="YOUR_URL",
12
+ api_key="YOUR_KEY",
13
+ )
14
+
15
+ instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by two AI assistants. Your job is to evaluate which assistant's generation is better. Your evaluation should consider the coherence of the generated story images and text. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie."
16
+
17
+ # style
18
+ # instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by two AI assistants. Your job is to evaluate which assistant's generation is better. Your evaluation should consider the style consistency of the story images. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie."
19
+
20
+ # text engaging level
21
+ # instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by two AI assistants. Your job is to evaluate which assistant's generation is better. Your evaluation should consider the engaging level of the story. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie."
22
+
23
+ def api_call(messages):
24
+ try_times = 0
25
+ while try_times < 3:
26
+ try:
27
+ chat_completion = client.chat.completions.create(
28
+ messages=messages,
29
+ model="gpt-4-turbo-2024-04-09", #"gpt-4-0125-preview", #"claude-3-opus-20240229", #"gpt-4-1106-preview",
30
+ max_tokens=4096,
31
+ temperature=0.3,
32
+ # stop=['<wait to execute>']
33
+ )
34
+ success = True
35
+ break
36
+ except Exception as e:
37
+ print(f"Error during API call: {e}")
38
+ time.sleep(15)
39
+ try_times += 1
40
+ success = False
41
+ if success:
42
+ cleaned_string = chat_completion.choices[0].message.content.strip()
43
+ return cleaned_string
44
+ else:
45
+ return None
46
+
47
+
48
+ def encode_image(image_path):
49
+ with open(image_path, "rb") as image_file:
50
+ return base64.b64encode(image_file.read()).decode("utf-8")
51
+
52
+
53
+ def read_json_and_extract_content(filepath):
54
+ """
55
+ Reads a JSON file and extracts sentences and images.
56
+
57
+ Args:
58
+ filepath (str): The path to the JSON file.
59
+
60
+ Returns:
61
+ dict: A dictionary with two keys 'sentences' and 'images', containing the respective content.
62
+ """
63
+ with open(filepath, 'r') as file:
64
+ data = json.load(file)
65
+
66
+ all_content = []
67
+ for line in data:
68
+ extracted_content = {
69
+ "sentences": [],
70
+ "images": []
71
+ }
72
+ # Matching sentences to their corresponding images using their indices
73
+ for ix in line['sentence_ixs']:
74
+ if ix == 0:
75
+ continue
76
+ extracted_content['sentences'].append(line['sentences'][ix].replace('<|beginofimage|>', ''))
77
+ extracted_content['images'].append(line['images'][ix])
78
+ all_content.append(extracted_content)
79
+
80
+ return all_content
81
+
82
+
83
+ def read_seed_content_from_folders(base_path):
84
+ """
85
+ Reads sentences from text.txt and image paths from subfolders named val_x.
86
+
87
+ Args:
88
+ base_path (str): Path to the main folder containing subfolders val_0 to val_179.
89
+
90
+ Returns:
91
+ list of dict: Each dictionary contains 'sentences' and 'images' from each subfolder.
92
+ """
93
+ contents = []
94
+
95
+ # Iterate over each possible subfolder val_0 to val_179
96
+ for i in range(180): # 0 to 179 inclusive
97
+ folder_name = f"val_{i}"
98
+ folder_path = os.path.join(base_path, folder_name)
99
+
100
+ if os.path.exists(folder_path):
101
+ content_dict = {
102
+ "sentences": [],
103
+ "images": []
104
+ }
105
+
106
+ # Read sentences from text.txt
107
+ text_file_path = os.path.join(folder_path, 'text.txt')
108
+ if os.path.isfile(text_file_path):
109
+ with open(text_file_path, 'r') as file:
110
+ content_dict['sentences'] = file.read().splitlines()[:6]
111
+ content_dict['sentences'] = [s.replace('[INST]', '') for s in content_dict['sentences'] ]
112
+
113
+ # Collect paths for the images ori_01 to ori_06
114
+ for j in range(1, 7): # 1 to 6 inclusive
115
+ image_name = f"ori_0{j}.jpg" # Assuming the images are in .jpg format
116
+ image_path = os.path.join(folder_path, image_name)
117
+ if os.path.isfile(image_path):
118
+ content_dict['images'].append(image_path)
119
+
120
+ # Add the content dictionary to the list if it contains any images or sentences
121
+ if content_dict['sentences'] or content_dict['images']:
122
+ contents.append(content_dict)
123
+
124
+ return contents
125
+
126
+
127
+ def evaluate_models(assistant_a, assistant_b, instruction):
128
+ # Encode all images to base64
129
+ images_a_base64 = [encode_image(img_path) for img_path in assistant_a['images'][:5]]
130
+ images_b_base64 = [encode_image(img_path) for img_path in assistant_b['images'][:5]]
131
+
132
+ # Extract the stories from both assistants
133
+ story_a = assistant_a['sentences']
134
+ story_b = assistant_b['sentences']
135
+
136
+ messages = []
137
+ # A
138
+ messages.append(
139
+ {
140
+ "role": "user",
141
+ "content": [
142
+ {
143
+ "type": "text",
144
+ "text": "Story text from Assistant A: {}\n".format(story_a[:5])
145
+ }
146
+ ]
147
+ }
148
+ )
149
+ messages.append(
150
+ {
151
+ "role": "user",
152
+ "content": [
153
+ {
154
+ "type": "text",
155
+ "text": "Images from Assistant A are encoded in base64.\n"
156
+ }
157
+ ]
158
+ }
159
+ )
160
+ for img_a in images_a_base64:
161
+ messages.append({
162
+ "role": "user",
163
+ "content": [
164
+ {
165
+ "type": "image_url",
166
+ "image_url": {"url": f"data:image/jpeg;base64,{img_a}"}
167
+ }
168
+ ]
169
+ })
170
+
171
+ # B
172
+ messages.append(
173
+ {
174
+ "role": "user",
175
+ "content": [
176
+ {
177
+ "type": "text",
178
+ "text": "Story text from Assistant B: {}\n".format(story_b[:5])
179
+ }
180
+ ]
181
+ }
182
+ )
183
+ messages.append(
184
+ {
185
+ "role": "user",
186
+ "content": [
187
+ {
188
+ "type": "text",
189
+ "text": "Images from Assistant B are encoded in base64.\n"
190
+ }
191
+ ]
192
+ }
193
+ )
194
+ for img_b in images_b_base64:
195
+ messages.append({
196
+ "role": "user",
197
+ "content": [
198
+ {
199
+ "type": "image_url",
200
+ "image_url": {"url": f"data:image/jpeg;base64,{img_b}"}
201
+ }
202
+ ]
203
+ })
204
+
205
+ # INST
206
+ messages.append(
207
+ {
208
+ "role": "user",
209
+ "content": [
210
+ {
211
+ "type": "text",
212
+ "text": instruction
213
+ }
214
+ ]
215
+ }
216
+ )
217
+ # Combine stories and encoded images into the evaluation instruction
218
+ result = api_call(messages)
219
+ print(result)
220
+ return result
221
+
222
+ def main():
223
+ # read mm json
224
+ mm_contents = read_json_and_extract_content('/group/40034/shuaisyang/seed_project/StorySalon/llm_eval/mm_eval.json')
225
+ seed_contents = read_seed_content_from_folders('/group/40034/shuaisyang/seed_project/StorySalon/llm_eval/gen_george_len7')
226
+ assert len(mm_contents) == len(seed_contents)
227
+ mm_win = 0
228
+ seed_win = 0
229
+ tie = 0
230
+ error = []
231
+ for i in range(len(mm_contents)):
232
+ # for i in range(2):
233
+ mm = mm_contents[i]
234
+ seed = seed_contents[i]
235
+ judgment = evaluate_models(mm, seed, instruction)
236
+
237
+ if "[[A]]" in judgment:
238
+ mm_win += 1
239
+ elif "[[B]]" in judgment:
240
+ seed_win += 1
241
+ elif "[[C]]" in judgment:
242
+ tie += 1
243
+ else:
244
+ error.append([i, judgment])
245
+
246
+ with open('coherence.txt', 'w') as f:
247
+ f.write("mm:{}\nseed:{}\ntie:{}\nerror:{}".format(mm_win, seed_win, tie, error))
248
+
249
+ main()
src/eval/gpt_score_eval.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from openai import OpenAI
3
+ import ast
4
+ import time
5
+ import os
6
+ import base64
7
+ # from PIL import Image
8
+ import io
9
+ import re
10
+
11
+ client = OpenAI(
12
+ base_url="YOUR_URL",
13
+ api_key="YOUR_KEY",
14
+ )
15
+
16
+ style_instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by an AI assistant. Your job is to give a score out of 10. Your evaluation should consider the style consistency of the story images. Do not allow the length of the responses to influence your evaluation. Be as objective as possible. After providing your explanation, output your final score by strictly following this format: \"[[score]]\", such as \"[[7]]\"."
17
+
18
+ engage_instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by an AI assistant. Your job is to give a score out of 10. Your evaluation should consider the engaging level of the story. Do not allow the length of the responses to influence your evaluation. Be as objective as possible. After providing your explanation, output your final score by strictly following this format: \"[[score]]\", such as \"[[7]]\"."
19
+
20
+ coherence_instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by an AI assistant. Your job is to give a score out of 10. Your evaluation should consider the coherence of the generated story images and text. Do not allow the length of the responses to influence your evaluation. Be as objective as possible. After providing your explanation, output your final score by strictly following this format: \"[[score]]\", such as \"[[7]]\"."
21
+
22
+ def api_call(messages):
23
+ try_times = 0
24
+ while try_times < 3:
25
+ try:
26
+ chat_completion = client.chat.completions.create(
27
+ messages=messages,
28
+ model="gpt-4-turbo-2024-04-09", #"gpt-4-0125-preview", #"claude-3-opus-20240229", #"gpt-4-1106-preview",
29
+ max_tokens=4096,
30
+ temperature=0.3,
31
+ # stop=['<wait to execute>']
32
+ )
33
+ success = True
34
+ break
35
+ except Exception as e:
36
+ print(f"Error during API call: {e}")
37
+ time.sleep(15)
38
+ try_times += 1
39
+ success = False
40
+ if success:
41
+ cleaned_string = chat_completion.choices[0].message.content.strip()
42
+ return cleaned_string
43
+ else:
44
+ return None
45
+
46
+
47
+ def encode_image(image_path):
48
+ with open(image_path, "rb") as image_file:
49
+ return base64.b64encode(image_file.read()).decode("utf-8")
50
+
51
+
52
+ def read_json_and_extract_content(filepath):
53
+ """
54
+ Reads a JSON file and extracts sentences and images.
55
+
56
+ Args:
57
+ filepath (str): The path to the JSON file.
58
+
59
+ Returns:
60
+ dict: A dictionary with two keys 'sentences' and 'images', containing the respective content.
61
+ """
62
+ with open(filepath, 'r') as file:
63
+ data = json.load(file)
64
+
65
+ all_content = []
66
+ for line in data:
67
+ extracted_content = {
68
+ "sentences": [],
69
+ "images": []
70
+ }
71
+ # Matching sentences to their corresponding images using their indices
72
+ for ix in line['sentence_ixs']:
73
+ if ix == 0:
74
+ continue
75
+ extracted_content['sentences'].append(line['sentences'][ix].replace('<|beginofimage|>', ''))
76
+ extracted_content['images'].append(line['images'][ix])
77
+ all_content.append(extracted_content)
78
+
79
+ return all_content
80
+
81
+
82
+ def read_seed_content_from_folders(base_path):
83
+ """
84
+ Reads sentences from text.txt and image paths from subfolders named val_x.
85
+
86
+ Args:
87
+ base_path (str): Path to the main folder containing subfolders val_0 to val_179.
88
+
89
+ Returns:
90
+ list of dict: Each dictionary contains 'sentences' and 'images' from each subfolder.
91
+ """
92
+ contents = []
93
+
94
+ # Iterate over each possible subfolder val_0 to val_179
95
+ for i in range(180): # 0 to 179 inclusive
96
+ folder_name = f"val_{i}"
97
+ folder_path = os.path.join(base_path, folder_name)
98
+
99
+ if os.path.exists(folder_path):
100
+ content_dict = {
101
+ "sentences": [],
102
+ "images": []
103
+ }
104
+
105
+ # Read sentences from text.txt
106
+ text_file_path = os.path.join(folder_path, 'text.txt')
107
+ if os.path.isfile(text_file_path):
108
+ with open(text_file_path, 'r') as file:
109
+ content_dict['sentences'] = file.read().splitlines()[:6]
110
+ content_dict['sentences'] = [s.replace('[INST]', '') for s in content_dict['sentences'] ]
111
+
112
+ # Collect paths for the images ori_01 to ori_06
113
+ for j in range(1, 7): # 1 to 6 inclusive
114
+ image_name = f"ori_0{j}.jpg" # Assuming the images are in .jpg format
115
+ image_path = os.path.join(folder_path, image_name)
116
+ if os.path.isfile(image_path):
117
+ content_dict['images'].append(image_path)
118
+
119
+ # Add the content dictionary to the list if it contains any images or sentences
120
+ if content_dict['sentences'] or content_dict['images']:
121
+ contents.append(content_dict)
122
+
123
+ return contents
124
+
125
+
126
+ def evaluate_models(assistant_a, instruction):
127
+ print(assistant_a, instruction)
128
+ # Encode all images to base64
129
+ images_a_base64 = [encode_image(img_path) for img_path in assistant_a['images'][:5]]
130
+
131
+ # Extract the stories from both assistants
132
+ story_a = assistant_a['sentences']
133
+
134
+ messages = []
135
+ # A
136
+ messages.append(
137
+ {
138
+ "role": "user",
139
+ "content": [
140
+ {
141
+ "type": "text",
142
+ "text": "Story text from Assistant A: {}\n".format(story_a[:5])
143
+ }
144
+ ]
145
+ }
146
+ )
147
+ messages.append(
148
+ {
149
+ "role": "user",
150
+ "content": [
151
+ {
152
+ "type": "text",
153
+ "text": "Images are encoded in base64.\n"
154
+ }
155
+ ]
156
+ }
157
+ )
158
+ for img_a in images_a_base64:
159
+ messages.append({
160
+ "role": "user",
161
+ "content": [
162
+ {
163
+ "type": "image_url",
164
+ "image_url": {"url": f"data:image/jpeg;base64,{img_a}"}
165
+ }
166
+ ]
167
+ })
168
+
169
+ # INST
170
+ messages.append(
171
+ {
172
+ "role": "user",
173
+ "content": [
174
+ {
175
+ "type": "text",
176
+ "text": instruction
177
+ }
178
+ ]
179
+ }
180
+ )
181
+ # Combine stories and encoded images into the evaluation instruction
182
+ result = api_call(messages)
183
+ print(result)
184
+ return result
185
+
186
+ def find_number_in_string(input_string):
187
+ # Regular expression to find [[number]]
188
+ pattern = r'\[\[(\d+)\]\]'
189
+ match = re.search(pattern, input_string)
190
+
191
+ if match:
192
+ return int(match.group(1)) # Return the number as an integer
193
+ else:
194
+ return None # No match found
195
+
196
+
197
+ def main():
198
+ # read mm json
199
+ # mm_contents = read_json_and_extract_content('/group/40034/shuaisyang/seed_project/StorySalon/llm_eval/mm_eval.json')
200
+ seed_contents = read_seed_content_from_folders('/group/40034/shuaisyang/seed_project/StorySalon/llm_eval/gen_george')
201
+ # assert len(mm_contents) == len(seed_contents)
202
+ # mm_win = 0
203
+ seed_win = 0
204
+ # tie = 0
205
+
206
+ error = []
207
+ metrics = ['style', 'engaging', 'coherence']
208
+ for idx, ins in enumerate((style_instruction, engage_instruction, coherence_instruction)):
209
+ total_score = 0
210
+ scores = ''
211
+ for i in range(len(seed_contents)):
212
+ seed = seed_contents[i]
213
+ judgment = evaluate_models(seed, ins)
214
+ number_found = find_number_in_string(judgment)
215
+ scores += str(number_found) + '\n'
216
+ total_score += number_found
217
+
218
+ with open('result_{}.txt'.format(metrics[idx]), 'w') as f:
219
+ f.write("total:{}\navg:{}\nscores:{}".format(total_score, total_score/len(seed_contents), scores))
220
+
221
+
222
+ main()
src/inference/gen_george.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ import hydra
3
+ from omegaconf import OmegaConf
4
+ import torch
5
+ import os
6
+ import re
7
+ import pyrootutils
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ import json
10
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, EulerDiscreteScheduler
11
+
12
+ pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
13
+
14
+ BOI_TOKEN = '<img>'
15
+ EOI_TOKEN = '</img>'
16
+ IMG_TOKEN = '<img_{:05d}>'
17
+
18
+ device = 'cuda:0'
19
+ dtype = torch.float16
20
+ dtype_str = 'fp16'
21
+ num_img_in_tokens = 64
22
+ num_img_out_tokens = 64
23
+ instruction_prompt = '{instruction}'
24
+
25
+ tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer.yaml'
26
+ image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml'
27
+ visual_encoder_cfg_path = 'configs/visual_tokenizer/qwen_vitg_448.yaml'
28
+
29
+ llm_cfg_path = 'configs/clm_models/llama2chat7b_lora.yaml'
30
+ agent_cfg_path = 'configs/clm_models/agent_7b_sft.yaml'
31
+
32
+ adapter_cfg_path = 'configs/detokenizer/detokenizer_sdxl_qwen_vit_adapted.yaml'
33
+ discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml'
34
+
35
+ diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0'
36
+
37
+ save_dir = "output"
38
+
39
+ tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
40
+ tokenizer = hydra.utils.instantiate(tokenizer_cfg)
41
+
42
+ image_transform_cfg = OmegaConf.load(image_transform_cfg_path)
43
+ image_transform = hydra.utils.instantiate(image_transform_cfg)
44
+
45
+ visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path)
46
+ visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
47
+ visual_encoder.eval().to(device, dtype=dtype)
48
+ print('Init visual encoder done')
49
+
50
+ llm_cfg = OmegaConf.load(llm_cfg_path)
51
+ llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype_str)
52
+ print('Init llm done.')
53
+
54
+ agent_model_cfg = OmegaConf.load(agent_cfg_path)
55
+ agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm)
56
+
57
+ agent_model.eval().to(device, dtype=dtype)
58
+ print('Init agent model Done')
59
+
60
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler")
61
+ print('init vae')
62
+ vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device, dtype=dtype)
63
+ print('init unet')
64
+ unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device, dtype=dtype)
65
+
66
+ adapter_cfg = OmegaConf.load(adapter_cfg_path)
67
+ adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device, dtype=dtype).eval()
68
+ print('Init adapter done')
69
+
70
+ discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path)
71
+ discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device).eval()
72
+ print('Init discrete model done')
73
+
74
+ adapter.init_pipe(vae=vae,
75
+ scheduler=noise_scheduler,
76
+ visual_encoder=visual_encoder,
77
+ image_transform=image_transform,
78
+ discrete_model=discrete_model,
79
+ dtype=dtype,
80
+ device=device)
81
+
82
+ print('Init adapter pipe done')
83
+ boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
84
+ eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
85
+
86
+
87
+ def read_jsonl_to_dict(filename):
88
+ data = []
89
+ with open(filename, 'r') as file:
90
+ for line in file:
91
+ # Each line is a valid JSON object
92
+ json_object = json.loads(line)
93
+ data.append(json_object)
94
+ return data
95
+
96
+
97
+ # data
98
+ filename = 'data/json/val.jsonl'
99
+ image_root = 'data/image/george_full'
100
+ data = read_jsonl_to_dict(filename)
101
+ image_paths = [
102
+ os.path.join(image_root, d['images'][0]) for d in data
103
+ ]
104
+ questions = [
105
+ d['captions'][0] for d in data
106
+ ]
107
+
108
+
109
+ # texts = [
110
+ # d['captions'][1:] for d in data
111
+ # ]
112
+
113
+
114
+ def add_subtitle(original_image, text):
115
+ # Calculate the size of the new image
116
+ text_height = 80 # Height of the black bar for the text
117
+ new_image_size = (original_image.width, original_image.height + text_height)
118
+
119
+ # Create a new image with a black background
120
+ new_image = Image.new("RGB", new_image_size, "black")
121
+ # Paste the original image onto the new image
122
+ new_image.paste(original_image, (0, 0))
123
+
124
+ # Prepare the new image for drawing
125
+ draw = ImageDraw.Draw(new_image)
126
+
127
+ # Specify the font size and font path
128
+ font_size = 14 # Adjust font size as needed
129
+ # font = ImageFont.truetype(font_path, font_size)
130
+
131
+ # Manually split the text into two lines
132
+ line1, line2 = text[:len(text) // 2], text[len(text) // 2:]
133
+
134
+ # Update the position for the first line of text to ensure both lines are centered vertically
135
+ text_position_line1 = (10, original_image.height + (text_height - font_size) // 2)
136
+
137
+ # Define the text color
138
+ text_color = "white"
139
+
140
+ # Add the first line of text to the new image
141
+ draw.text(text_position_line1, line1, fill=text_color)
142
+
143
+ # Adjust the position for the second line of text, based on the height of the first line
144
+ text_position_line2 = (10, text_position_line1[1] + font_size)
145
+
146
+ # Add the second line of text to the new image
147
+ draw.text(text_position_line2, line2, fill=text_color)
148
+
149
+ return new_image
150
+
151
+
152
+ for j in range(len(image_paths)):
153
+ image_path = image_paths[j]
154
+ question = questions[j]
155
+ image = Image.open(image_path).convert('RGB')
156
+
157
+ save_folder = '{}/val_{}'.format(save_dir, j)
158
+
159
+ os.makedirs(save_folder, exist_ok=True)
160
+
161
+ init_image = add_subtitle(image, question)
162
+ save_path = os.path.join(save_folder, '000start_image.jpg')
163
+ init_image.save(save_path)
164
+
165
+ agent_model.llm.base_model.model.use_kv_cache_head = False
166
+ image_tensor = image_transform(image).unsqueeze(0).to(device, dtype=dtype)
167
+
168
+ image_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
169
+
170
+ prompt = instruction_prompt.format_map({'instruction': question + image_tokens})
171
+ print(prompt)
172
+ print('*' * 20)
173
+
174
+ input_ids = tokenizer.encode(prompt, add_special_tokens=False)
175
+ input_ids = [tokenizer.bos_token_id] + input_ids
176
+
177
+ boi_idx = input_ids.index(boi_token_id)
178
+ eoi_idx = input_ids.index(eoi_token_id)
179
+
180
+ input_ids = torch.tensor(input_ids).to(device, dtype=torch.long).unsqueeze(0)
181
+
182
+ ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
183
+
184
+ ids_cmp_mask[0, boi_idx + 1:eoi_idx] = True
185
+ embeds_cmp_mask = torch.tensor([True]).to(device, dtype=torch.bool)
186
+
187
+ with torch.no_grad():
188
+ image_embeds = visual_encoder(image_tensor)
189
+ output = agent_model.generate(tokenizer=tokenizer,
190
+ input_ids=input_ids,
191
+ image_embeds=image_embeds,
192
+ embeds_cmp_mask=embeds_cmp_mask,
193
+ ids_cmp_mask=ids_cmp_mask,
194
+ max_new_tokens=500,
195
+ num_img_gen_tokens=num_img_out_tokens)
196
+ text = re.sub(r'\s*<[^>]*>\s*', ' ', output['text']).strip()
197
+
198
+ with open("{}/text.txt".format(save_folder), 'a+') as text_file:
199
+ text_file.write(text + '\n')
200
+ with open("{}/token.txt".format(save_folder), 'a+') as token_file:
201
+ token_file.write("context token: {}\n".format(input_ids.shape))
202
+ print(output['text'])
203
+ print('*' * 20)
204
+
205
+ story_len = 25
206
+ window_size = 8
207
+ text_id = 1
208
+ while output['has_img_output'] and image_embeds.shape[0] < story_len:
209
+ image_embeds_gen = output['img_gen_feat']
210
+ images_gen = adapter.generate(image_embeds=output['img_gen_feat'], num_inference_steps=50)
211
+
212
+ name = '{:02d}.jpg'.format(text_id)
213
+ save_path = os.path.join(save_folder, name)
214
+
215
+ # Open the generated image
216
+ original_image = images_gen[0]
217
+ ori_path = os.path.join(save_folder, 'ori_{:02d}.jpg'.format(text_id))
218
+ original_image.save(ori_path)
219
+
220
+ new_image = add_subtitle(original_image, text)
221
+ # Save the modified image
222
+ new_image.save(save_path)
223
+
224
+ image_embeds = torch.cat((image_embeds, image_embeds_gen), dim=0)
225
+
226
+ # image_embeds = torch.cat((image_embeds, image_embeds_gen), dim=0)
227
+
228
+ if text_id >= story_len - 1:
229
+ break
230
+
231
+ prompt = prompt + text + image_tokens
232
+ text_id += 1
233
+
234
+ input_ids = tokenizer.encode(prompt, add_special_tokens=False)
235
+ while image_embeds.shape[0] > window_size:
236
+ eoi_prompt_idx = prompt.index(EOI_TOKEN)
237
+ prompt = prompt[eoi_prompt_idx + len(EOI_TOKEN) + len('[INST]'):]
238
+ image_embeds = image_embeds[1:]
239
+ input_ids = tokenizer.encode(prompt, add_special_tokens=False)
240
+
241
+ print(prompt)
242
+ print('*' * 20)
243
+
244
+ input_ids = [tokenizer.bos_token_id] + input_ids
245
+
246
+ boi_idx = torch.where(torch.tensor(input_ids) == boi_token_id)[0].tolist()
247
+ eoi_idx = torch.where(torch.tensor(input_ids) == eoi_token_id)[0].tolist()
248
+
249
+ input_ids = torch.tensor(input_ids).to(device, dtype=torch.long).unsqueeze(0)
250
+
251
+ ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
252
+
253
+ for i in range(image_embeds.shape[0]):
254
+ ids_cmp_mask[0, boi_idx[i] + 1:eoi_idx[i]] = True
255
+ embeds_cmp_mask = torch.tensor([True] * image_embeds.shape[0]).to(device, dtype=torch.bool)
256
+
257
+ output = agent_model.generate(tokenizer=tokenizer,
258
+ input_ids=input_ids,
259
+ image_embeds=image_embeds,
260
+ embeds_cmp_mask=embeds_cmp_mask,
261
+ ids_cmp_mask=ids_cmp_mask,
262
+ max_new_tokens=500,
263
+ num_img_gen_tokens=num_img_out_tokens)
264
+ text = re.sub(r'\s*<[^>]*>\s*', ' ', output['text']).strip()
265
+ print(output['text'])
266
+ print('*' * 20)
267
+ with open("{}/text.txt".format(save_folder), 'a+') as text_file:
268
+ text_file.write(text + '\n')
269
+ with open("{}/token.txt".format(save_folder), 'a+') as token_file:
270
+ token_file.write("context token: {}\n".format(input_ids.shape))
src/inference/vis_george_sink.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hydra
2
+ from omegaconf import OmegaConf
3
+ import torch
4
+ import os
5
+ import re
6
+ import pyrootutils
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ import json
9
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, EulerDiscreteScheduler
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ from collections import Counter
13
+ import time
14
+
15
+ pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
16
+
17
+ BOI_TOKEN = '<img>'
18
+ EOI_TOKEN = '</img>'
19
+ IMG_TOKEN = '<img_{:05d}>'
20
+
21
+ device = 'cuda:0'
22
+ dtype = torch.float16
23
+ dtype_str = 'fp16'
24
+ num_img_in_tokens = 64
25
+ num_img_out_tokens = 64
26
+ instruction_prompt = '{instruction}'
27
+
28
+ tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer.yaml'
29
+ image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml'
30
+ visual_encoder_cfg_path = 'configs/visual_tokenizer/qwen_vitg_448.yaml'
31
+
32
+ llm_cfg_path = 'configs/clm_models/llama2chat7b_lora.yaml'
33
+ agent_cfg_path = 'configs/clm_models/agent_7b_sft.yaml'
34
+
35
+ adapter_cfg_path = 'configs/detokenizer/detokenizer_sdxl_qwen_vit_adapted.yaml'
36
+ discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml'
37
+
38
+ diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0'
39
+
40
+ save_dir = "output"
41
+
42
+ cache_mode = 'img_head_tail'
43
+ # init
44
+ tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
45
+ tokenizer = hydra.utils.instantiate(tokenizer_cfg)
46
+
47
+ image_transform_cfg = OmegaConf.load(image_transform_cfg_path)
48
+ image_transform = hydra.utils.instantiate(image_transform_cfg)
49
+
50
+ visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path)
51
+ visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
52
+ visual_encoder.eval().to(device, dtype=dtype)
53
+ print('Init visual encoder done')
54
+
55
+ llm_cfg = OmegaConf.load(llm_cfg_path)
56
+ llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype_str)
57
+ print('Init llm done.')
58
+
59
+ agent_model_cfg = OmegaConf.load(agent_cfg_path)
60
+ agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm)
61
+
62
+ agent_model.eval().to(device, dtype=dtype)
63
+ print('Init agent model Done')
64
+
65
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler")
66
+ print('init vae')
67
+ vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device, dtype=dtype)
68
+ print('init unet')
69
+ unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device, dtype=dtype)
70
+
71
+ adapter_cfg = OmegaConf.load(adapter_cfg_path)
72
+ adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device, dtype=dtype).eval()
73
+ print('Init adapter done')
74
+
75
+ discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path)
76
+ discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device).eval()
77
+ print('Init discrete model done')
78
+
79
+ adapter.init_pipe(vae=vae,
80
+ scheduler=noise_scheduler,
81
+ visual_encoder=visual_encoder,
82
+ image_transform=image_transform,
83
+ discrete_model=discrete_model,
84
+ dtype=dtype,
85
+ device=device)
86
+
87
+ print('Init adapter pipe done')
88
+ boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
89
+ eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
90
+
91
+
92
+ def read_jsonl_to_dict(filename):
93
+ data = []
94
+ with open(filename, 'r') as file:
95
+ for line in file:
96
+ # Each line is a valid JSON object
97
+ json_object = json.loads(line)
98
+ data.append(json_object)
99
+ return data
100
+
101
+
102
+ # data
103
+ filename = 'data/json/val.jsonl'
104
+ image_root = 'data/image/george_full'
105
+ data = read_jsonl_to_dict(filename)
106
+ image_paths = [
107
+ os.path.join(image_root, d['images'][0]) for d in data
108
+ ]
109
+ starting_texts = [
110
+ d['captions'][0] for d in data
111
+ ]
112
+
113
+ texts = [
114
+ d['captions'][1:] for d in data
115
+ ]
116
+
117
+ def add_subtitle(original_image, text):
118
+ # Calculate the size of the new image
119
+ text_height = 80 # Height of the black bar for the text
120
+ new_image_size = (original_image.width, original_image.height + text_height)
121
+
122
+ # Create a new image with a black background
123
+ new_image = Image.new("RGB", new_image_size, "black")
124
+ # Paste the original image onto the new image
125
+ new_image.paste(original_image, (0, 0))
126
+
127
+ # Prepare the new image for drawing
128
+ draw = ImageDraw.Draw(new_image)
129
+
130
+ # Specify the font size and font path
131
+ font_size = 14 # Adjust font size as needed
132
+ # font = ImageFont.truetype(font_path, font_size)
133
+
134
+ # Manually split the text into two lines
135
+ line1, line2 = text[:len(text) // 2], text[len(text) // 2:]
136
+
137
+ # Update the position for the first line of text to ensure both lines are centered vertically
138
+ text_position_line1 = (10, original_image.height + (text_height - font_size) // 2)
139
+
140
+ # Define the text color
141
+ text_color = "white"
142
+
143
+ # Add the first line of text to the new image
144
+ draw.text(text_position_line1, line1, fill=text_color)
145
+
146
+ # Adjust the position for the second line of text, based on the height of the first line
147
+ text_position_line2 = (10, text_position_line1[1] + font_size)
148
+
149
+ # Add the second line of text to the new image
150
+ draw.text(text_position_line2, line2, fill=text_color)
151
+
152
+ return new_image
153
+
154
+
155
+
156
+ for j in range(len(image_paths)):
157
+ image_path = image_paths[j]
158
+ starting_text = starting_texts[j]
159
+ text_seq = texts[j]
160
+ image = Image.open(image_path).convert('RGB')
161
+
162
+ save_folder = '{}/val_{}'.format(save_dir, j)
163
+
164
+ os.makedirs(save_folder, exist_ok=True)
165
+
166
+ init_image = add_subtitle(image, starting_text)
167
+ save_path = os.path.join(save_folder, '000start_image.jpg')
168
+ init_image.save(save_path)
169
+
170
+ sink_kv_cache = []
171
+ agent_model.llm.base_model.model.kv_cache_head = None
172
+ agent_model.llm.base_model.model.past_key_values = None
173
+ agent_model.llm.base_model.model.use_kv_cache_head = False
174
+
175
+ image_tensor = image_transform(image).unsqueeze(0).to(device, dtype=dtype)
176
+
177
+ image_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
178
+
179
+ text = text_seq[0]
180
+ prompt = instruction_prompt.format_map({'instruction': starting_text + image_tokens}) + text
181
+ print(prompt)
182
+ print('*' * 20)
183
+
184
+ input_ids = tokenizer.encode(prompt, add_special_tokens=False)
185
+ input_ids = [tokenizer.bos_token_id] + input_ids
186
+
187
+ boi_idx = input_ids.index(boi_token_id)
188
+ eoi_idx = input_ids.index(eoi_token_id)
189
+
190
+ input_ids = torch.tensor(input_ids).to(device, dtype=torch.long).unsqueeze(0)
191
+
192
+ ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
193
+
194
+ ids_cmp_mask[0, boi_idx + 1:eoi_idx] = True
195
+ embeds_cmp_mask = torch.tensor([True]).to(device, dtype=torch.bool)
196
+
197
+ with torch.no_grad():
198
+ image_embeds = visual_encoder(image_tensor)
199
+ left = 0
200
+ right = input_ids.shape[1]
201
+ output = agent_model.generate(tokenizer=tokenizer,
202
+ input_ids=input_ids,
203
+ image_embeds=image_embeds,
204
+ embeds_cmp_mask=embeds_cmp_mask,
205
+ ids_cmp_mask=ids_cmp_mask,
206
+ max_new_tokens=500,
207
+ num_img_gen_tokens=num_img_out_tokens,
208
+ )
209
+ with open("{}/text.txt".format(save_folder), 'a+') as text_file:
210
+ text_file.write(text + '\n')
211
+ with open("{}/token.txt".format(save_folder), 'a+') as token_file:
212
+ token_file.write("context token: {} boi_idx: {}\n".format(input_ids.shape, boi_idx))
213
+
214
+ story_len = 25
215
+ window_size = 8
216
+ text_id = 1
217
+ while output['has_img_output'] and image_embeds.shape[0] < story_len:
218
+ image_embeds_gen = output['img_gen_feat']
219
+ images_gen = adapter.generate(image_embeds=output['img_gen_feat'], num_inference_steps=50)
220
+
221
+ name = '{:02d}.jpg'.format(text_id)
222
+ save_path = os.path.join(save_folder, name)
223
+
224
+ # Open the generated image
225
+ original_image = images_gen[0]
226
+ ori_path = os.path.join(save_folder, 'ori_{:02d}.jpg'.format(text_id))
227
+ original_image.save(ori_path)
228
+
229
+ new_image = add_subtitle(original_image, text)
230
+ # Save the modified image
231
+ new_image.save(save_path)
232
+
233
+ image_embeds = torch.cat((image_embeds, image_embeds_gen), dim=0)
234
+
235
+ # next gen
236
+ text = text_seq[text_id]
237
+ text_id += 1
238
+
239
+ # image_embeds = torch.cat((image_embeds, image_embeds_gen), dim=0)
240
+ if text_id >= story_len - 1:
241
+ break
242
+
243
+ past_key_values = [[kv[:, :, :input_ids.shape[1], :] for kv in l] for l in output['past_key_values']]
244
+ agent_model.llm.base_model.model.kv_cache_head = input_ids.shape[1]
245
+
246
+ prompt = prompt + image_tokens + text
247
+ next_input_ids = tokenizer.encode(image_tokens + text, add_special_tokens=False)
248
+ next_input_ids = torch.tensor(next_input_ids).to(device, dtype=torch.long).unsqueeze(0)
249
+ input_ids = torch.cat((input_ids, next_input_ids), dim=1)
250
+ left = right
251
+ right = input_ids.shape[1]
252
+
253
+
254
+ while image_embeds.shape[0] > window_size:
255
+
256
+ eoi_prompt_idx = prompt.index(EOI_TOKEN)
257
+ prompt = prompt[eoi_prompt_idx + len(EOI_TOKEN) :]
258
+
259
+ boi_idx = torch.where(input_ids == boi_token_id)[1].tolist()
260
+ eoi_idx = torch.where(input_ids == eoi_token_id)[1].tolist()
261
+
262
+ image_embeds = image_embeds[1:]
263
+ input_ids = input_ids[:, eoi_idx[0]+1:]
264
+
265
+ # slice kv cache
266
+ if cache_mode == 'img_head_tail':
267
+ if len(sink_kv_cache) == 0:
268
+ sink_kv_cache = [
269
+ [
270
+ kv[:, :, :4, :] for kv in l
271
+ ] for l in past_key_values
272
+ ]
273
+ sink_kv_cache = [
274
+ [
275
+ torch.cat(
276
+ (sink_kv_cache[l_idx][kv_idx],
277
+ kv[:, :, boi_idx[0] - 4:boi_idx[0] + 8, :],
278
+ kv[:, :, eoi_idx[0] - 8:eoi_idx[0] + 4, :]),
279
+ dim=2
280
+ ) for kv_idx, kv in enumerate(l)
281
+ ] for l_idx, l in enumerate(past_key_values)
282
+ ]
283
+ past_key_values = [
284
+ [
285
+ torch.cat(
286
+ (sink_kv_cache[l_idx][kv_idx],
287
+ kv[:, :, eoi_idx[0] + sink_kv_cache[0][0].shape[2] + 1:, :]),
288
+ dim=2
289
+ ) for kv_idx, kv in enumerate(l)
290
+ ] for l_idx, l in enumerate(past_key_values)
291
+ ]
292
+ # slice Left right
293
+ agent_model.llm.base_model.model.kv_cache_head -= eoi_idx[0] + 1
294
+ left -= eoi_idx[0] + 1
295
+ right -= eoi_idx[0] + 1
296
+
297
+ print("prompt: {}".format(prompt))
298
+ print('*' * 20)
299
+
300
+ boi_idx = torch.where(input_ids == boi_token_id)[1].tolist()
301
+ eoi_idx = torch.where(input_ids == eoi_token_id)[1].tolist()
302
+
303
+ ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
304
+
305
+ for i in range(image_embeds.shape[0]):
306
+ ids_cmp_mask[0, boi_idx[i] + 1:eoi_idx[i]] = True
307
+ embeds_cmp_mask = torch.tensor([True] * image_embeds.shape[0]).to(device, dtype=torch.bool)
308
+
309
+ output = agent_model.generate(tokenizer=tokenizer,
310
+ input_ids=input_ids,
311
+ image_embeds=image_embeds,
312
+ embeds_cmp_mask=embeds_cmp_mask,
313
+ ids_cmp_mask=ids_cmp_mask,
314
+ max_new_tokens=500,
315
+ num_img_gen_tokens=num_img_out_tokens,
316
+ past_key_values=None)
317
+ with open("{}/text.txt".format(save_folder), 'a+') as text_file:
318
+ text_file.write(text + '\n')
319
+ with open("{}/token.txt".format(save_folder), 'a+') as token_file:
320
+ token_file.write("context token: {} boi_idx: {}\n".format(input_ids.shape, boi_idx))
src/models/__init__.py ADDED
File without changes
src/models/discrete_models.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pyrootutils
4
+ import torch.distributed as dist
5
+ import torch.nn.functional as F
6
+
7
+ pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
8
+ from src.train.dist_utils import concat_all_gather
9
+
10
+
11
+ def cosine_loss(rec, target):
12
+ target = target / target.norm(dim=-1, keepdim=True)
13
+ rec = rec / rec.norm(dim=-1, keepdim=True)
14
+ rec_loss = (1 - (target * rec).sum(-1)).mean()
15
+ return rec_loss
16
+
17
+
18
+ def contrastive_loss(image_feats, text_feats, logit_scale):
19
+ image_feats = image_feats.unsqueeze(1).contiguous()
20
+ image_feats_all = concat_all_gather(image_feats) # [batch_size*num_gpu, num_query_tokens, embed_dim]
21
+ text_feats_all = concat_all_gather(text_feats) # [batch_size*num_gpu, embed_dim]
22
+
23
+ sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feats_all.unsqueeze(-1)).squeeze()
24
+ # [batch_size, batch_size*num_gpu, num_query_tokens]
25
+
26
+ # image-text similarity: aggregate across all query tokens
27
+ # sim_i2t, _ = sim_q2t.max(-1)
28
+ # sim_i2t = sim_q2t.mean(-1)
29
+ sim_i2t = sim_q2t
30
+ sim_i2t = sim_i2t / logit_scale
31
+
32
+ # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]
33
+ sim_t2q = torch.matmul(text_feats.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)).squeeze()
34
+
35
+ # print(image_feats_all.shape, text_feat_all.shape, sim_q2t.shape, sim_t2q.shape)
36
+ # text-image similarity: aggregate across all query tokens
37
+ # sim_t2i, _ = sim_t2q.max(-1)
38
+ # sim_t2i = sim_t2q.mean(-1)
39
+ sim_t2i = sim_t2q
40
+ sim_t2i = sim_t2i / logit_scale # [batch_size, batch_size*num_gpu]
41
+
42
+ rank = dist.get_rank()
43
+ bs = image_feats.size(0)
44
+ targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(image_feats.device)
45
+
46
+ loss_itc = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) +
47
+ F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2
48
+
49
+ i2t_acc = (sim_i2t.argmax(-1) == targets).sum() / len(sim_i2t)
50
+ t2i_acc = (sim_t2i.argmax(-1) == targets).sum() / len(sim_t2i)
51
+
52
+ return loss_itc, i2t_acc, t2i_acc
53
+
54
+
55
+ class DiscreteModleOnlyDistill(nn.Module):
56
+
57
+ def __init__(self,
58
+ qformer,
59
+ quantizer,
60
+ distiller=None,
61
+ loss_type='cosine',
62
+ scale_commit_loss=1.0,
63
+ freeze_qformer=False) -> None:
64
+ super().__init__()
65
+ self.qformer = qformer
66
+ self.quantizer = quantizer
67
+ self.distiller = distiller
68
+ self.loss_type = loss_type
69
+ self.scale_commit_loss = scale_commit_loss
70
+
71
+ self.freeze_qformer = freeze_qformer
72
+
73
+ if freeze_qformer:
74
+ self.qformer.requires_grad_(False)
75
+
76
+ def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
77
+ if self.freeze_qformer:
78
+ with torch.no_grad():
79
+ qforemr_embeds = self.qformer(image_embeds=image_embeds)
80
+ else:
81
+ qforemr_embeds = self.qformer(image_embeds=image_embeds)
82
+
83
+ quantizer_output = self.quantizer(qforemr_embeds)
84
+ recon_embeds = self.distiller(quantizer_output['quant_embeds'])
85
+
86
+ if self.loss_type == 'cosine':
87
+ distill_loss = cosine_loss(recon_embeds, image_embeds)
88
+ else:
89
+ raise NotImplementedError
90
+
91
+ total_loss = distill_loss + self.scale_commit_loss * \
92
+ quantizer_output['commit_loss']
93
+
94
+ return {
95
+ 'total_loss': total_loss,
96
+ 'distill_loss': distill_loss,
97
+ 'commit_loss': quantizer_output['commit_loss'],
98
+ 'indices': quantizer_output['indices']
99
+ }
100
+
101
+ def encode_image_embeds(self, image_embeds):
102
+ qforemr_embeds = self.qformer(image_embeds=image_embeds)
103
+ quantizer_output = self.quantizer(qforemr_embeds)
104
+
105
+ output_embeds = quantizer_output['quant_embeds']
106
+ if self.distiller is not None:
107
+ output_embeds = self.distiller(output_embeds)
108
+ return output_embeds
109
+
110
+ @classmethod
111
+ def from_pretrained(cls, qformer, quantizer, distiller=None, pretrained_model_path=None, **kwargs):
112
+ model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, **kwargs)
113
+ if pretrained_model_path is not None:
114
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
115
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
116
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
117
+ return model
118
+
119
+
120
+ class DiscreteModleIdentity(nn.Module):
121
+
122
+ def __init__(self) -> None:
123
+ super().__init__()
124
+ self.model = nn.Identity()
125
+
126
+ def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
127
+ return
128
+
129
+ def encode_image_embeds(self, image_embeds):
130
+ return self.model(image_embeds)
131
+
132
+
133
+ class DiscreteModleStageOneContrastive(nn.Module):
134
+
135
+ def __init__(self, qformer, quantizer=None, distiller=None, projection_dim=1024,
136
+ image_cls_token_type='last') -> None:
137
+ super().__init__()
138
+ self.qformer = qformer
139
+ self.quantizer = quantizer
140
+ self.distiller = distiller
141
+ self.image_cls_token_type = image_cls_token_type
142
+ self.logit_scale = nn.Parameter(0.07 * torch.ones([]))
143
+ self.image_proj = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
144
+ self.text_proj = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
145
+
146
+ def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
147
+ image_embeds = self.qformer(image_embeds=image_embeds)
148
+ if self.image_cls_token_type == 'last':
149
+ image_embeds = image_embeds[:, -1, :]
150
+ else:
151
+ raise NotImplementedError
152
+
153
+ text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
154
+ text_embeds = text_embeds[:, 0, :]
155
+
156
+ image_embeds = F.normalize(self.image_proj(image_embeds), dim=-1)
157
+ text_embeds = F.normalize(self.text_proj(text_embeds), dim=-1)
158
+
159
+ contrast_loss, i2t_acc, t2i_acc = contrastive_loss(image_feats=image_embeds,
160
+ text_feats=text_embeds,
161
+ logit_scale=self.logit_scale)
162
+
163
+ return {
164
+ 'total_loss': contrast_loss,
165
+ 'i2t_acc': i2t_acc,
166
+ 't2i_acc': t2i_acc,
167
+ }
168
+
169
+ def encode_image_embeds(self, image_embeds):
170
+ image_embeds = self.qformer(image_embeds=image_embeds)
171
+
172
+ return image_embeds
173
+
174
+ @classmethod
175
+ def from_pretrained(cls, qformer, quantizer, distiller=None, pretrained_model_path=None, **kwargs):
176
+ model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, **kwargs)
177
+ if pretrained_model_path is not None:
178
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
179
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
180
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
181
+ return model
182
+
183
+
184
+ class DiscreteModleStageTwoContrastiveDistill(nn.Module):
185
+
186
+ def __init__(self,
187
+ qformer,
188
+ quantizer=None,
189
+ distiller=None,
190
+ contrast_head=None,
191
+ projection_dim=1024,
192
+ distill_loss_type='cosine',
193
+ freeze_qformer=True,
194
+ image_cls_token_type='last',
195
+ scale_commit_loss=1.0,
196
+ scale_contrast_loss=1.0,
197
+ scale_distill_loss=1.0) -> None:
198
+ super().__init__()
199
+ self.qformer = qformer
200
+ self.quantizer = quantizer
201
+ self.distiller = distiller
202
+ self.contrast_head = contrast_head
203
+ self.distill_loss_type = distill_loss_type
204
+ self.image_cls_token_type = image_cls_token_type
205
+ if self.contrast_head is not None:
206
+ self.logit_scale = nn.Parameter(0.07 * torch.ones([]))
207
+ self.image_proj = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
208
+ self.text_proj = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
209
+
210
+ self.freeze_qformer = freeze_qformer
211
+ if freeze_qformer:
212
+ self.qformer.requires_grad_(False)
213
+
214
+ self.scale_commit_loss = scale_commit_loss
215
+ self.scale_contrast_loss = scale_contrast_loss
216
+ self.scale_distill_loss = scale_distill_loss
217
+
218
+ def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
219
+ if self.freeze_qformer:
220
+ with torch.no_grad():
221
+ qforemr_embeds = self.qformer(image_embeds=image_embeds)
222
+ else:
223
+ qforemr_embeds = self.qformer(image_embeds=image_embeds)
224
+
225
+ quantizer_output = self.quantizer(qforemr_embeds)
226
+
227
+ output_state = {}
228
+ output_state['indices'] = quantizer_output['indices']
229
+ output_state['commit_loss'] = quantizer_output['commit_loss']
230
+ output_state['total_loss'] = self.scale_commit_loss * quantizer_output['commit_loss']
231
+ if self.distiller is not None:
232
+ recon_embeds = self.distiller(quantizer_output['quant_embeds'])
233
+
234
+ if self.distill_loss_type == 'cosine':
235
+ distill_loss = cosine_loss(recon_embeds, image_embeds)
236
+ else:
237
+ raise NotImplementedError
238
+
239
+ output_state['distill_loss'] = distill_loss
240
+ output_state['total_loss'] += self.scale_distill_loss * distill_loss
241
+
242
+ if self.contrast_head is not None:
243
+ text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
244
+ text_embeds = text_embeds[:, 0, :]
245
+
246
+ image_embeds = self.contrast_head(quantizer_output['quant_embeds'])
247
+ if self.image_cls_token_type == 'last':
248
+ image_embeds = image_embeds[:, -1, :]
249
+ else:
250
+ raise NotImplementedError
251
+
252
+ image_embeds = F.normalize(self.image_proj(image_embeds), dim=-1)
253
+ text_embeds = F.normalize(self.text_proj(text_embeds), dim=-1)
254
+
255
+ contrast_loss, i2t_acc, t2i_acc = contrastive_loss(image_feats=image_embeds,
256
+ text_feats=text_embeds,
257
+ logit_scale=self.logit_scale)
258
+ output_state['contrast_loss'] = contrast_loss
259
+ output_state['total_loss'] += self.scale_contrast_loss * contrast_loss
260
+ output_state['i2t_acc'] = i2t_acc
261
+ output_state['t2i_acc'] = t2i_acc
262
+
263
+ return output_state
264
+
265
+ def encode_image_embeds(self, image_embeds):
266
+ pass
267
+
268
+ @classmethod
269
+ def from_pretrained(cls, qformer, quantizer, distiller=None, contrast_head=None, pretrained_model_path=None,
270
+ **kwargs):
271
+ model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
272
+ if pretrained_model_path is not None:
273
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
274
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
275
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
276
+ return model
277
+
278
+
279
+ class DiscreteModleDistillWithDoubleContrastive(nn.Module):
280
+
281
+ def __init__(
282
+ self,
283
+ qformer,
284
+ quantizer=None,
285
+ distiller=None,
286
+ contrast_head=None,
287
+ projection_dim=1024,
288
+ distill_loss_type='cosine',
289
+ share_contrast_head=True, # share contrastive head with distiller
290
+ quantize_cls_token=False,
291
+ rec_qformer=False,
292
+ has_contrast=False,
293
+ freeze_qformer=False,
294
+ scale_commit_loss=1.0,
295
+ scale_contrast_loss=1.0,
296
+ scale_distill_loss=1.0) -> None:
297
+ super().__init__()
298
+ self.qformer = qformer
299
+ self.quantizer = quantizer
300
+ self.distiller = distiller
301
+ self.contrast_head = contrast_head
302
+ self.distill_loss_type = distill_loss_type
303
+ self.quantize_cls_token = quantize_cls_token
304
+
305
+ self.rec_qformer = rec_qformer
306
+ self.has_contrast = has_contrast
307
+
308
+ if freeze_qformer:
309
+ self.qformer.requires_grad_(False)
310
+ else:
311
+ self.logit_scale_qformer = nn.Parameter(0.07 * torch.ones([]))
312
+ self.image_proj_qformer = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
313
+ self.text_proj_qformer = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
314
+ self.cls_norm_qformer = nn.LayerNorm(qformer.perceiver.config.projection_dim)
315
+
316
+ if self.contrast_head is not None:
317
+ self.logit_scale_head = nn.Parameter(0.07 * torch.ones([]))
318
+ self.image_proj_head = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
319
+ self.text_proj_head = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
320
+ self.cls_norm_head = nn.LayerNorm(contrast_head.perceiver.config.projection_dim)
321
+
322
+ if share_contrast_head and distiller is not None:
323
+ self.logit_scale_head = nn.Parameter(0.07 * torch.ones([]))
324
+ self.image_proj_head = nn.Linear(distiller.perceiver.config.projection_dim, projection_dim, bias=False)
325
+ self.text_proj_head = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
326
+ self.cls_norm_head = nn.LayerNorm(distiller.perceiver.config.projection_dim)
327
+
328
+ self.scale_commit_loss = scale_commit_loss
329
+ self.scale_contrast_loss = scale_contrast_loss
330
+ self.scale_distill_loss = scale_distill_loss
331
+ self.share_contrast_head = share_contrast_head
332
+ self.freeze_qformer = freeze_qformer
333
+ assert int(self.share_contrast_head) + int(contrast_head is not None) <= 1
334
+
335
+ def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
336
+
337
+ if self.freeze_qformer:
338
+ with torch.no_grad():
339
+ qforemr_embeds = self.qformer(image_embeds=image_embeds)
340
+ else:
341
+ qforemr_embeds = self.qformer(image_embeds=image_embeds)
342
+ qforemr_cls_embeds = qforemr_embeds[:, -1, :]
343
+
344
+ if not self.quantize_cls_token:
345
+ qforemr_embeds = qforemr_embeds[:, :-1, :]
346
+
347
+ if self.has_contrast:
348
+ text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
349
+ text_cls_embeds = text_embeds[:, 0, :]
350
+
351
+ output_state = {}
352
+ output_state['total_loss'] = 0.0
353
+
354
+ if not self.freeze_qformer and self.has_contrast:
355
+ qforemr_cls_embeds = self.cls_norm_qformer(qforemr_cls_embeds)
356
+ qformer_image_embeds = F.normalize(self.image_proj_qformer(qforemr_cls_embeds), dim=-1)
357
+ qformer_text_embeds = F.normalize(self.text_proj_qformer(text_cls_embeds), dim=-1)
358
+
359
+ qformer_contrast_loss, \
360
+ qformer_i2t_acc, \
361
+ qformer_t2i_acc = contrastive_loss(image_feats=qformer_image_embeds,
362
+ text_feats=qformer_text_embeds,
363
+ logit_scale=self.logit_scale_qformer)
364
+ output_state['qformer_contrast_loss'] = qformer_contrast_loss
365
+ output_state['total_loss'] += self.scale_contrast_loss * qformer_contrast_loss
366
+ output_state['qformer_i2t_acc'] = qformer_i2t_acc
367
+ output_state['qformer_t2i_acc'] = qformer_t2i_acc
368
+
369
+ if self.quantizer is not None and self.distiller is not None:
370
+ quantizer_output = self.quantizer(qforemr_embeds)
371
+
372
+ recon_embeds = self.distiller(quantizer_output['quant_embeds'])
373
+ if self.share_contrast_head:
374
+ contrast_head_cls_embeds = recon_embeds[:, -1, :]
375
+ contrast_head_cls_embeds = self.cls_norm_head(contrast_head_cls_embeds)
376
+ recon_embeds = recon_embeds[:, :-1, :]
377
+ if self.contrast_head is not None:
378
+ contrast_head_embeds = self.contrast_head(quantizer_output['quant_embeds'])
379
+ contrast_head_cls_embeds = contrast_head_embeds[:, -1, :]
380
+ contrast_head_cls_embeds = self.cls_norm_head(contrast_head_cls_embeds)
381
+
382
+ output_state['indices'] = quantizer_output['indices']
383
+ output_state['commit_loss'] = quantizer_output['commit_loss']
384
+ output_state['total_loss'] += self.scale_commit_loss * quantizer_output['commit_loss']
385
+
386
+ if self.rec_qformer:
387
+ target_embeds = qforemr_embeds
388
+ else:
389
+ target_embeds = image_embeds
390
+
391
+ if self.distill_loss_type == 'cosine':
392
+ distill_loss = cosine_loss(recon_embeds, target_embeds)
393
+ else:
394
+ raise NotImplementedError
395
+
396
+ output_state['distill_loss'] = distill_loss
397
+ output_state['total_loss'] += self.scale_distill_loss * distill_loss
398
+
399
+ if self.contrast_head is not None or self.share_contrast_head:
400
+ head_image_embeds = F.normalize(self.image_proj_head(contrast_head_cls_embeds), dim=-1)
401
+ head_text_embeds = F.normalize(self.text_proj_head(text_cls_embeds), dim=-1)
402
+
403
+ head_contrast_loss, head_i2t_acc, head_t2i_acc = contrastive_loss(image_feats=head_image_embeds,
404
+ text_feats=head_text_embeds,
405
+ logit_scale=self.logit_scale_head)
406
+ output_state['head_contrast_loss'] = head_contrast_loss
407
+ output_state['total_loss'] += self.scale_contrast_loss * head_contrast_loss
408
+ output_state['head_i2t_acc'] = head_i2t_acc
409
+ output_state['head_t2i_acc'] = head_t2i_acc
410
+
411
+ return output_state
412
+
413
+ def encode_image_embeds(self, image_embeds):
414
+ qforemr_embeds = self.qformer(image_embeds=image_embeds)
415
+ return qforemr_embeds
416
+
417
+ @classmethod
418
+ def from_pretrained(cls, qformer, quantizer=None, distiller=None, contrast_head=None, pretrained_model_path=None,
419
+ **kwargs):
420
+ model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
421
+ if pretrained_model_path is not None:
422
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
423
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
424
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
425
+ return model
426
+
427
+ @classmethod
428
+ def from_pretrained_stage1_yuying(cls,
429
+ qformer,
430
+ quantizer=None,
431
+ distiller=None,
432
+ contrast_head=None,
433
+ pretrained_model_path=None,
434
+ **kwargs):
435
+ model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
436
+ if pretrained_model_path is not None:
437
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
438
+ ckpt = ckpt['model']
439
+
440
+ new_ckpt = {}
441
+ new_ckpt['qformer.embed_module.query'] = ckpt['query_tokens'].squeeze(0)
442
+ new_ckpt['qformer.norm.weight'] = ckpt['ln_vision.weight']
443
+ new_ckpt['qformer.norm.bias'] = ckpt['ln_vision.bias']
444
+
445
+ for key in ckpt.keys():
446
+ if key.startswith('Qformer'):
447
+ new_key = key.replace('Qformer', 'qformer.perceiver')
448
+ new_ckpt[new_key] = ckpt[key]
449
+ del ckpt
450
+ missing, unexpected = model.load_state_dict(new_ckpt, strict=False)
451
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
452
+ print(missing)
453
+ print(unexpected)
454
+ return model
src/models/qwen_visual.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+ import math
8
+ import requests
9
+ from io import BytesIO
10
+ from functools import partial
11
+ from PIL import Image
12
+ from typing import Callable, Optional, Sequence, Tuple, List
13
+ import numpy as np
14
+
15
+ import torch
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from torch.nn.init import trunc_normal_
19
+ from torchvision import transforms
20
+ from torchvision.transforms import InterpolationMode
21
+
22
+
23
+ def get_abs_pos(abs_pos, tgt_size):
24
+ # abs_pos: L, C
25
+ # tgt_size: M
26
+ # return: M, C
27
+ src_size = int(math.sqrt(abs_pos.size(0)))
28
+ tgt_size = int(math.sqrt(tgt_size))
29
+ dtype = abs_pos.dtype
30
+
31
+ if src_size != tgt_size:
32
+ return F.interpolate(
33
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
34
+ size=(tgt_size, tgt_size),
35
+ mode="bicubic",
36
+ align_corners=False,
37
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
38
+ else:
39
+ return abs_pos
40
+
41
+
42
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
43
+
44
+
45
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
46
+ """
47
+ grid_size: int of the grid height and width
48
+ return:
49
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
50
+ """
51
+ grid_h = np.arange(grid_size, dtype=np.float32)
52
+ grid_w = np.arange(grid_size, dtype=np.float32)
53
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
54
+ grid = np.stack(grid, axis=0)
55
+
56
+ grid = grid.reshape([2, 1, grid_size, grid_size])
57
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
58
+ if cls_token:
59
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
60
+ return pos_embed
61
+
62
+
63
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
64
+ assert embed_dim % 2 == 0
65
+
66
+ # use half of dimensions to encode grid_h
67
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
68
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
69
+
70
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
71
+ return emb
72
+
73
+
74
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
75
+ """
76
+ embed_dim: output dimension for each position
77
+ pos: a list of positions to be encoded: size (M,)
78
+ out: (M, D)
79
+ """
80
+ assert embed_dim % 2 == 0
81
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
82
+ omega /= embed_dim / 2.
83
+ omega = 1. / 10000 ** omega # (D/2,)
84
+
85
+ pos = pos.reshape(-1) # (M,)
86
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
87
+
88
+ emb_sin = np.sin(out) # (M, D/2)
89
+ emb_cos = np.cos(out) # (M, D/2)
90
+
91
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
92
+ return emb
93
+
94
+
95
+ class Resampler(nn.Module):
96
+ """
97
+ A 2D perceiver-resampler network with one cross attention layers by
98
+ (grid_size**2) learnable queries and 2d sincos pos_emb
99
+ Outputs:
100
+ A tensor with the shape of (grid_size**2, embed_dim)
101
+ """
102
+
103
+ def __init__(self, grid_size, embed_dim, num_heads, kv_dim=None, norm_layer=nn.LayerNorm):
104
+ super().__init__()
105
+ self.num_queries = grid_size ** 2
106
+ self.embed_dim = embed_dim
107
+ self.num_heads = num_heads
108
+
109
+ self.pos_embed = nn.Parameter(torch.from_numpy(get_2d_sincos_pos_embed(embed_dim,
110
+ grid_size)).float()).requires_grad_(
111
+ False)
112
+
113
+ self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
114
+ trunc_normal_(self.query, std=.02)
115
+
116
+ if kv_dim is not None and kv_dim != embed_dim:
117
+ self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
118
+ self.out_dim = kv_dim
119
+ else:
120
+ self.kv_proj = nn.Identity()
121
+ self.out_dim = embed_dim
122
+
123
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
124
+ self.ln_q = norm_layer(embed_dim)
125
+ self.ln_kv = norm_layer(embed_dim)
126
+
127
+ self.apply(self._init_weights)
128
+
129
+ def _init_weights(self, m):
130
+ if isinstance(m, nn.Linear):
131
+ trunc_normal_(m.weight, std=.02)
132
+ if isinstance(m, nn.Linear) and m.bias is not None:
133
+ nn.init.constant_(m.bias, 0)
134
+ elif isinstance(m, nn.LayerNorm):
135
+ nn.init.constant_(m.bias, 0)
136
+ nn.init.constant_(m.weight, 1.0)
137
+
138
+ def forward(self, x, attn_mask=None):
139
+
140
+ pos_embed = get_abs_pos(self.pos_embed, x.size(1))
141
+
142
+ x = self.kv_proj(x)
143
+ x = self.ln_kv(x).permute(1, 0, 2)
144
+
145
+ N = x.shape[1]
146
+ q = self.ln_q(self.query)
147
+ out = \
148
+ self.attn(self._repeat(q, N) + self.pos_embed.unsqueeze(1), x + pos_embed.unsqueeze(1), x, attn_mask=attn_mask)[
149
+ 0]
150
+ return out.permute(1, 0, 2)
151
+
152
+ def _repeat(self, query, N: int):
153
+ return query.unsqueeze(1).repeat(1, N, 1)
154
+
155
+
156
+ class VisualAttention(nn.Module):
157
+ """self-attention layer class.
158
+
159
+ Self-attention layer takes input with size [s, b, h]
160
+ and returns output of the same size.
161
+ """
162
+
163
+ def __init__(self, embed_dim, num_heads, bias=True, kdim=None, vdim=None):
164
+ super(VisualAttention, self).__init__()
165
+ self.embed_dim = embed_dim
166
+ self.kdim = kdim if kdim is not None else embed_dim
167
+ self.vdim = vdim if vdim is not None else embed_dim
168
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
169
+
170
+ self.num_heads = num_heads
171
+
172
+ # Per attention head and per partition values.
173
+ assert embed_dim % num_heads == 0
174
+ self.hidden_size_per_attention_head = embed_dim // num_heads
175
+ self.num_attention_heads_per_partition = num_heads
176
+ self.hidden_size_per_partition = embed_dim
177
+
178
+ # Strided linear layer.
179
+ assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently'
180
+ self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
181
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
182
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
183
+
184
+ def forward(self, query, key, value, attn_mask=None):
185
+ # query/key/value: [sq, b, h]
186
+ sq, b, _ = query.size()
187
+
188
+ assert query is key, 'Only Support Self-Attention Currently'
189
+ sk = sq
190
+ mixed_x_layer = self.in_proj(query)
191
+
192
+ # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
193
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
194
+ (self.num_attention_heads_per_partition,
195
+ 3 * self.hidden_size_per_attention_head)
196
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
197
+
198
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
199
+ query_layer, key_layer, value_layer = mixed_x_layer.split(self.hidden_size_per_attention_head, dim=-1)
200
+
201
+ # [sq, b, np, hn] -> [sq, b * np, hn]
202
+ query_layer = query_layer.view(sq, b * self.num_attention_heads_per_partition,
203
+ self.hidden_size_per_attention_head).transpose(0, 1)
204
+ # [sk, b, np, hn] -> [sk, b * np, hn]
205
+ key_layer = key_layer.view(sk, b * self.num_attention_heads_per_partition,
206
+ self.hidden_size_per_attention_head).transpose(0, 1)
207
+
208
+ q_scaled = query_layer / self.norm_factor
209
+ if attn_mask is not None:
210
+ attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1))
211
+ else:
212
+ attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
213
+ attention_probs = attention_probs.softmax(dim=-1)
214
+
215
+ value_layer = value_layer.view(sk, b * self.num_attention_heads_per_partition,
216
+ self.hidden_size_per_attention_head).transpose(0, 1)
217
+
218
+ # matmul: [b * np, sq, hn]
219
+ context_layer = torch.bmm(attention_probs, value_layer)
220
+
221
+ # change view [b, np, sq, hn]
222
+ context_layer = context_layer.view(b, self.num_attention_heads_per_partition, sq,
223
+ self.hidden_size_per_attention_head)
224
+
225
+ # [b, np, sq, hn] --> [sq, b, np, hn]
226
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
227
+
228
+ # [sq, b, np, hn] --> [sq, b, hp]
229
+ new_context_layer_shape = context_layer.size()[:-2] + \
230
+ (self.hidden_size_per_partition,)
231
+ context_layer = context_layer.view(*new_context_layer_shape)
232
+
233
+ output = self.out_proj(context_layer)
234
+
235
+ return output
236
+
237
+
238
+ class VisualAttentionBlock(nn.Module):
239
+
240
+ def __init__(
241
+ self,
242
+ d_model: int,
243
+ n_head: int,
244
+ mlp_ratio: float = 4.0,
245
+ act_layer: Callable = nn.GELU,
246
+ norm_layer: Callable = nn.LayerNorm,
247
+ is_cross_attention: bool = False,
248
+ ):
249
+ super().__init__()
250
+
251
+ self.ln_1 = norm_layer(d_model)
252
+ if is_cross_attention:
253
+ self.ln_1_kv = norm_layer(d_model)
254
+
255
+ self.ln_2 = norm_layer(d_model)
256
+ mlp_width = int(d_model * mlp_ratio)
257
+ self.attn = VisualAttention(d_model, n_head)
258
+ self.mlp = nn.Sequential(
259
+ OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()),
260
+ ("c_proj", nn.Linear(mlp_width, d_model))]))
261
+
262
+ def attention(
263
+ self,
264
+ q_x: torch.Tensor,
265
+ k_x: Optional[torch.Tensor] = None,
266
+ v_x: Optional[torch.Tensor] = None,
267
+ attn_mask: Optional[torch.Tensor] = None,
268
+ ):
269
+ k_x = k_x if k_x is not None else q_x
270
+ v_x = v_x if v_x is not None else q_x
271
+
272
+ attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
273
+ return self.attn(q_x, k_x, v_x, attn_mask=attn_mask)
274
+
275
+ def forward(
276
+ self,
277
+ q_x: torch.Tensor,
278
+ k_x: Optional[torch.Tensor] = None,
279
+ v_x: Optional[torch.Tensor] = None,
280
+ attn_mask: Optional[torch.Tensor] = None,
281
+ ):
282
+ k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
283
+ v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
284
+
285
+ x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
286
+ x = x + self.mlp(self.ln_2(x))
287
+ return x
288
+
289
+
290
+ class TransformerBlock(nn.Module):
291
+
292
+ def __init__(
293
+ self,
294
+ width: int,
295
+ layers: int,
296
+ heads: int,
297
+ mlp_ratio: float = 4.0,
298
+ act_layer: Callable = nn.GELU,
299
+ norm_layer: Callable = nn.LayerNorm,
300
+ ):
301
+ super().__init__()
302
+ self.width = width
303
+ self.layers = layers
304
+
305
+ self.resblocks = nn.ModuleList(
306
+ [VisualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) for _ in
307
+ range(layers)])
308
+
309
+ def get_cast_dtype(self) -> torch.dtype:
310
+ return self.resblocks[0].mlp.c_fc.weight.dtype
311
+
312
+ def get_cast_device(self) -> torch.device:
313
+ return self.resblocks[0].mlp.c_fc.weight.device
314
+
315
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
316
+ for r in self.resblocks:
317
+ x = r(x, attn_mask=attn_mask)
318
+ return x
319
+
320
+
321
+ class VisionTransformerWithAttnPool(nn.Module):
322
+
323
+ def __init__(self,
324
+ image_size: int,
325
+ patch_size: int,
326
+ width: int,
327
+ layers: int,
328
+ heads: int,
329
+ mlp_ratio: float,
330
+ n_queries: int = 256,
331
+ output_dim: int = 512,
332
+ **kwargs):
333
+ super().__init__()
334
+ image_height, image_width = self.image_size = (image_size, image_size)
335
+ patch_height, patch_width = self.patch_size = (patch_size, patch_size)
336
+ self.grid_size = (image_height // patch_height, image_width // patch_width)
337
+ self.output_dim = output_dim
338
+
339
+ mean = (0.48145466, 0.4578275, 0.40821073)
340
+ std = (0.26862954, 0.26130258, 0.27577711)
341
+ self.image_transform = transforms.Compose([
342
+ transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
343
+ transforms.ToTensor(),
344
+ transforms.Normalize(mean=mean, std=std),
345
+ ])
346
+
347
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
348
+
349
+ # class embeddings and positional embeddings
350
+ scale = width ** -0.5
351
+ self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
352
+
353
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
354
+ act_layer = nn.GELU
355
+
356
+ self.ln_pre = norm_layer(width)
357
+ self.transformer = TransformerBlock(
358
+ width,
359
+ layers,
360
+ heads,
361
+ mlp_ratio,
362
+ act_layer=act_layer,
363
+ norm_layer=norm_layer,
364
+ )
365
+
366
+ self.attn_pool = Resampler(
367
+ grid_size=int(math.sqrt(n_queries)),
368
+ embed_dim=output_dim,
369
+ num_heads=output_dim // 128,
370
+ kv_dim=width,
371
+ norm_layer=norm_layer,
372
+ )
373
+ self.ln_post = norm_layer(output_dim)
374
+ self.proj = nn.Parameter((output_dim ** -0.5) * torch.randn(output_dim, output_dim))
375
+
376
+ def forward(self, x: torch.Tensor):
377
+ x = x.to(
378
+ dtype=self.transformer.get_cast_dtype(),
379
+ device=self.transformer.get_cast_device(),
380
+ )
381
+ # to patches
382
+ x = self.conv1(x) # shape = [*, width, grid, grid]
383
+ # shape = [*, width, grid ** 2]
384
+ x = x.reshape(x.shape[0], x.shape[1], -1)
385
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
386
+
387
+ x = x + get_abs_pos(self.positional_embedding, x.size(1))
388
+
389
+ x = self.ln_pre(x)
390
+
391
+ x = x.permute(1, 0, 2) # NLD -> LND
392
+ x = self.transformer(x)
393
+ x = x.permute(1, 0, 2) # LND -> NLD
394
+
395
+ x = self.attn_pool(x)
396
+ x = self.ln_post(x)
397
+ x = x @ self.proj
398
+
399
+ return x
400
+
401
+ def encode(self, image_paths: List[str]):
402
+ images = []
403
+ for image_path in image_paths:
404
+ if image_path.startswith("http://") or image_path.startswith("https://"):
405
+ image = Image.open(requests.get(image_path, stream=True).raw)
406
+ else:
407
+ image = Image.open(image_path)
408
+ image = image.convert("RGB")
409
+ images.append(self.image_transform(image))
410
+ images = torch.stack(images, dim=0)
411
+ return self(images)
412
+
413
+ @classmethod
414
+ def from_pretrained(cls, pretrained_model_path=None, **kawrgs):
415
+ model = cls(**kawrgs)
416
+ if pretrained_model_path is not None:
417
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
418
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
419
+ print('Load ckpt of qwen visual encoder')
420
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
421
+
422
+ return model
423
+
424
+
425
+ class VisionTransformer(nn.Module):
426
+
427
+ def __init__(self,
428
+ image_size: int,
429
+ patch_size: int,
430
+ width: int,
431
+ layers: int,
432
+ heads: int,
433
+ mlp_ratio: float,
434
+ n_queries: int = 256,
435
+ output_dim: int = 512,
436
+ **kwargs):
437
+ super().__init__()
438
+ image_height, image_width = self.image_size = (image_size, image_size)
439
+ patch_height, patch_width = self.patch_size = (patch_size, patch_size)
440
+ self.grid_size = (image_height // patch_height, image_width // patch_width)
441
+ self.output_dim = output_dim
442
+
443
+ mean = (0.48145466, 0.4578275, 0.40821073)
444
+ std = (0.26862954, 0.26130258, 0.27577711)
445
+ self.image_transform = transforms.Compose([
446
+ transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
447
+ transforms.ToTensor(),
448
+ transforms.Normalize(mean=mean, std=std),
449
+ ])
450
+
451
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
452
+
453
+ # class embeddings and positional embeddings
454
+ scale = width ** -0.5
455
+ self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
456
+
457
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
458
+ act_layer = nn.GELU
459
+
460
+ self.ln_pre = norm_layer(width)
461
+ self.transformer = TransformerBlock(
462
+ width,
463
+ layers,
464
+ heads,
465
+ mlp_ratio,
466
+ act_layer=act_layer,
467
+ norm_layer=norm_layer,
468
+ )
469
+
470
+ def forward(self, x: torch.Tensor):
471
+ x = x.to(
472
+ dtype=self.transformer.get_cast_dtype(),
473
+ device=self.transformer.get_cast_device(),
474
+ )
475
+ # to patches
476
+ x = self.conv1(x) # shape = [*, width, grid, grid]
477
+ # shape = [*, width, grid ** 2]
478
+ x = x.reshape(x.shape[0], x.shape[1], -1)
479
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
480
+
481
+ x = x + get_abs_pos(self.positional_embedding, x.size(1))
482
+
483
+ x = self.ln_pre(x)
484
+
485
+ x = x.permute(1, 0, 2) # NLD -> LND
486
+ x = self.transformer(x)
487
+ x = x.permute(1, 0, 2) # LND -> NLD
488
+
489
+ return x
490
+
491
+ def encode(self, image_paths: List[str]):
492
+ images = []
493
+ for image_path in image_paths:
494
+ if image_path.startswith("http://") or image_path.startswith("https://"):
495
+ image = Image.open(requests.get(image_path, stream=True).raw)
496
+ else:
497
+ image = Image.open(image_path)
498
+ image = image.convert("RGB")
499
+ images.append(self.image_transform(image))
500
+ images = torch.stack(images, dim=0)
501
+ return self(images)
src/models_clm/__init__.py ADDED
File without changes
src/models_clm/generation.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import LogitsProcessor
3
+
4
+ BOI_TOKEN = '<img>'
5
+ EOI_TOKEN = '</img>'
6
+ IMG_TOKEN = '<img_{:05d}>'
7
+
8
+
9
+ class AutoImageTokenGenerationProcessor(LogitsProcessor):
10
+
11
+ def __init__(self, tokenizer, num_img_gen_tokens=64) -> None:
12
+ super().__init__()
13
+ # self.boi_token_id = tokenizer.encode(BOI_TOKEN)[0]
14
+ # self.eoi_token_id = tokenizer.encode(EOI_TOKEN)[0]
15
+ img_all_token_str = ''.join([BOI_TOKEN] + [IMG_TOKEN.format(int(item))
16
+ for item in range(num_img_gen_tokens)] + [EOI_TOKEN])
17
+ self.img_ids_list = tokenizer.encode(img_all_token_str, add_special_tokens=False)
18
+
19
+ def __call__(self, input_ids, scores):
20
+ bz = input_ids.shape[0]
21
+ for i in range(bz):
22
+ cur_input_id = input_ids[i, -1].item()
23
+ if cur_input_id in self.img_ids_list[:-1]:
24
+
25
+ output_id = self.img_ids_list[self.img_ids_list.index(cur_input_id) + 1]
26
+ scores[i, ..., output_id] = scores[i, ...].max() + 10.
27
+ else:
28
+
29
+ scores[i, ..., torch.tensor(self.img_ids_list[1:]).to(dtype=torch.long)] = 0.0
30
+
31
+ return scores
src/models_clm/modeling_llama_4_35.py ADDED
@@ -0,0 +1,1236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ # coding=utf-8
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """ PyTorch LLaMA model."""
22
+ import math
23
+ import warnings
24
+ from typing import List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
35
+ SequenceClassifierOutputWithPast
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
38
+ from transformers.utils import (
39
+ add_start_docstrings,
40
+ add_start_docstrings_to_model_forward,
41
+ is_flash_attn_2_available,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from transformers.utils.import_utils import is_torch_fx_available
46
+ from transformers.models.llama.configuration_llama import LlamaConfig
47
+
48
+ if is_flash_attn_2_available():
49
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
50
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
51
+
52
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
53
+ # It means that the function will not be traced through and simply appear as a node in the graph.
54
+ if is_torch_fx_available():
55
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ _CONFIG_FOR_DOC = "LlamaConfig"
60
+
61
+
62
+ def _get_unpad_data(attention_mask):
63
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
64
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
65
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
66
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
67
+ return (
68
+ indices,
69
+ cu_seqlens,
70
+ max_seqlen_in_batch,
71
+ )
72
+
73
+
74
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
75
+ warnings.warn(
76
+ "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._prepare_4d_attention_mask" # yapf: disable # noqa
77
+
78
+ )
79
+ return AttentionMaskConverter._prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
80
+
81
+
82
+ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device,
83
+ past_key_values_length: int = 0):
84
+ warnings.warn(
85
+ "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask" # yapf: disable # noqa
86
+
87
+ )
88
+ return AttentionMaskConverter._make_causal_mask(input_ids_shape=input_ids_shape,
89
+ dtype=dtype,
90
+ device=device,
91
+ past_key_values_length=past_key_values_length)
92
+
93
+
94
+ class LlamaRMSNorm(nn.Module):
95
+
96
+ def __init__(self, hidden_size, eps=1e-6):
97
+ """
98
+ LlamaRMSNorm is equivalent to T5LayerNorm
99
+ """
100
+ super().__init__()
101
+ self.weight = nn.Parameter(torch.ones(hidden_size))
102
+ self.variance_epsilon = eps
103
+
104
+ def forward(self, hidden_states):
105
+ input_dtype = hidden_states.dtype
106
+ hidden_states = hidden_states.to(torch.float32)
107
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
108
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
109
+ return self.weight * hidden_states.to(input_dtype)
110
+
111
+
112
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
113
+
114
+
115
+ class LlamaRotaryEmbedding(nn.Module):
116
+
117
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
118
+ super().__init__()
119
+
120
+ self.dim = dim
121
+ self.max_position_embeddings = max_position_embeddings
122
+ self.base = base
123
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
124
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
125
+
126
+ # Build here to make `torch.jit.trace` work.
127
+ self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device,
128
+ dtype=torch.get_default_dtype())
129
+
130
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
131
+ self.max_seq_len_cached = seq_len
132
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
133
+
134
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
135
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
136
+ emb = torch.cat((freqs, freqs), dim=-1)
137
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
138
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
139
+
140
+ def forward(self, x, seq_len=None):
141
+ # x: [bs, num_attention_heads, seq_len, head_size]
142
+ if seq_len > self.max_seq_len_cached:
143
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
144
+
145
+ return (
146
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
147
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
148
+ )
149
+
150
+
151
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
152
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
153
+
154
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
155
+ self.scaling_factor = scaling_factor
156
+ super().__init__(dim, max_position_embeddings, base, device)
157
+
158
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
159
+ self.max_seq_len_cached = seq_len
160
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
161
+ t = t / self.scaling_factor
162
+
163
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
164
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
165
+ emb = torch.cat((freqs, freqs), dim=-1)
166
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
167
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
168
+
169
+
170
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
171
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
172
+
173
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
174
+ self.scaling_factor = scaling_factor
175
+ super().__init__(dim, max_position_embeddings, base, device)
176
+
177
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
178
+ self.max_seq_len_cached = seq_len
179
+
180
+ if seq_len > self.max_position_embeddings:
181
+ base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
182
+ (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
183
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
184
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
185
+
186
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
187
+
188
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
189
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
190
+ emb = torch.cat((freqs, freqs), dim=-1)
191
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
192
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
193
+
194
+
195
+ def rotate_half(x):
196
+ """Rotates half the hidden dims of the input."""
197
+ x1 = x[..., :x.shape[-1] // 2]
198
+ x2 = x[..., x.shape[-1] // 2:]
199
+ return torch.cat((-x2, x1), dim=-1)
200
+
201
+
202
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
203
+ """Applies Rotary Position Embedding to the query and key tensors.
204
+
205
+ Args:
206
+ q (`torch.Tensor`): The query tensor.
207
+ k (`torch.Tensor`): The key tensor.
208
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
209
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
210
+ position_ids (`torch.Tensor`):
211
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
212
+ used to pass offsetted position ids when working with a KV-cache.
213
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
214
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
215
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
216
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
217
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
218
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
219
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
220
+ Returns:
221
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
222
+ """
223
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
224
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
225
+ q_embed = (q * cos) + (rotate_half(q) * sin)
226
+ k_embed = (k * cos) + (rotate_half(k) * sin)
227
+ return q_embed, k_embed
228
+
229
+
230
+ class LlamaMLP(nn.Module):
231
+
232
+ def __init__(self, config):
233
+ super().__init__()
234
+ self.config = config
235
+ self.hidden_size = config.hidden_size
236
+ self.intermediate_size = config.intermediate_size
237
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
238
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
239
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
240
+ self.act_fn = ACT2FN[config.hidden_act]
241
+
242
+ def forward(self, x):
243
+ if self.config.pretraining_tp > 1:
244
+ slice = self.intermediate_size // self.config.pretraining_tp
245
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
246
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
247
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
248
+
249
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
250
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
251
+
252
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
253
+ down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in
254
+ range(self.config.pretraining_tp)]
255
+ down_proj = sum(down_proj)
256
+ else:
257
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
258
+
259
+ return down_proj
260
+
261
+
262
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
263
+ """
264
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
265
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
266
+ """
267
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
268
+ if n_rep == 1:
269
+ return hidden_states
270
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
271
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
272
+
273
+
274
+ class LlamaAttention(nn.Module):
275
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
276
+
277
+ def __init__(self, config: LlamaConfig):
278
+ super().__init__()
279
+ self.config = config
280
+ self.hidden_size = config.hidden_size
281
+ self.num_heads = config.num_attention_heads
282
+ self.head_dim = self.hidden_size // self.num_heads
283
+ self.num_key_value_heads = config.num_key_value_heads
284
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
285
+ self.max_position_embeddings = config.max_position_embeddings
286
+ self.rope_theta = config.rope_theta
287
+ self.is_causal = True
288
+
289
+ if (self.head_dim * self.num_heads) != self.hidden_size:
290
+ raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
291
+ f" and `num_heads`: {self.num_heads}).")
292
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
293
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
294
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
295
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
296
+ self._init_rope()
297
+
298
+ def _init_rope(self):
299
+ if self.config.rope_scaling is None:
300
+ self.rotary_emb = LlamaRotaryEmbedding(
301
+ self.head_dim,
302
+ max_position_embeddings=self.max_position_embeddings,
303
+ base=self.rope_theta,
304
+ )
305
+ else:
306
+ scaling_type = self.config.rope_scaling["type"]
307
+ scaling_factor = self.config.rope_scaling["factor"]
308
+ if scaling_type == "linear":
309
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
310
+ self.head_dim,
311
+ max_position_embeddings=self.max_position_embeddings,
312
+ scaling_factor=scaling_factor,
313
+ base=self.rope_theta,
314
+ )
315
+ elif scaling_type == "dynamic":
316
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
317
+ self.head_dim,
318
+ max_position_embeddings=self.max_position_embeddings,
319
+ scaling_factor=scaling_factor,
320
+ base=self.rope_theta,
321
+ )
322
+ else:
323
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
324
+
325
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
326
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
327
+
328
+ def forward(
329
+ self,
330
+ hidden_states: torch.Tensor,
331
+ attention_mask: Optional[torch.Tensor] = None,
332
+ position_ids: Optional[torch.LongTensor] = None,
333
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
334
+ output_attentions: bool = False,
335
+ use_cache: bool = False,
336
+ **kwargs,
337
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
338
+ if "padding_mask" in kwargs:
339
+ warnings.warn(
340
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
341
+ )
342
+
343
+ bsz, q_len, _ = hidden_states.size()
344
+
345
+ if self.config.pretraining_tp > 1:
346
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
347
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp,
348
+ dim=0)
349
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
350
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
351
+
352
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
353
+ query_states = torch.cat(query_states, dim=-1)
354
+
355
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
356
+ key_states = torch.cat(key_states, dim=-1)
357
+
358
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
359
+ value_states = torch.cat(value_states, dim=-1)
360
+
361
+ else:
362
+ query_states = self.q_proj(hidden_states)
363
+ key_states = self.k_proj(hidden_states)
364
+ value_states = self.v_proj(hidden_states)
365
+
366
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
367
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
368
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
369
+
370
+ kv_seq_len = key_states.shape[-2]
371
+ if past_key_value is not None:
372
+ kv_seq_len += past_key_value[0].shape[-2]
373
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
374
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
375
+
376
+ if past_key_value is not None:
377
+ # reuse k, v, self_attention
378
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
379
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
380
+
381
+ past_key_value = (key_states, value_states) if use_cache else None
382
+
383
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
384
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
385
+
386
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
387
+
388
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
389
+ raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
390
+ f" {attn_weights.size()}")
391
+
392
+ if attention_mask is not None:
393
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
394
+ raise ValueError(
395
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
396
+ attn_weights = attn_weights + attention_mask
397
+
398
+ # upcast attention to fp32
399
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
400
+ attn_output = torch.matmul(attn_weights, value_states)
401
+
402
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
403
+ raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
404
+ f" {attn_output.size()}")
405
+
406
+ attn_output = attn_output.transpose(1, 2).contiguous()
407
+
408
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
409
+
410
+ if self.config.pretraining_tp > 1:
411
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
412
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
413
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
414
+ else:
415
+ attn_output = self.o_proj(attn_output)
416
+
417
+ if not output_attentions:
418
+ attn_weights = None
419
+
420
+ return attn_output, attn_weights, past_key_value
421
+
422
+
423
+ class LlamaFlashAttention2(LlamaAttention):
424
+ """
425
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
426
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
427
+ flash attention and deal with padding tokens in case the input contains any of them.
428
+ """
429
+
430
+ def forward(
431
+ self,
432
+ hidden_states: torch.Tensor,
433
+ attention_mask: Optional[torch.LongTensor] = None,
434
+ position_ids: Optional[torch.LongTensor] = None,
435
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
436
+ output_attentions: bool = False,
437
+ use_cache: bool = False,
438
+ **kwargs,
439
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
440
+ # LlamaFlashAttention2 attention does not support output_attentions
441
+ if "padding_mask" in kwargs:
442
+ warnings.warn(
443
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
444
+ )
445
+
446
+ # overwrite attention_mask with padding_mask
447
+ attention_mask = kwargs.pop("padding_mask")
448
+
449
+ output_attentions = False
450
+
451
+ bsz, q_len, _ = hidden_states.size()
452
+
453
+ query_states = self.q_proj(hidden_states)
454
+ key_states = self.k_proj(hidden_states)
455
+ value_states = self.v_proj(hidden_states)
456
+
457
+ # Flash attention requires the input to have the shape
458
+ # batch_size x seq_length x head_dim x hidden_dim
459
+ # therefore we just need to keep the original shape
460
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
461
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
462
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
463
+
464
+ kv_seq_len = key_states.shape[-2]
465
+ if past_key_value is not None:
466
+ kv_seq_len += past_key_value[0].shape[-2]
467
+
468
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
469
+
470
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
471
+
472
+ if past_key_value is not None:
473
+ # reuse k, v, self_attention
474
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
475
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
476
+
477
+ past_key_value = (key_states, value_states) if use_cache else None
478
+
479
+ query_states = query_states.transpose(1, 2)
480
+ key_states = key_states.transpose(1, 2)
481
+ value_states = value_states.transpose(1, 2)
482
+
483
+ # TODO: llama does not have dropout in the config??
484
+ # It is recommended to use dropout with FA according to the docs
485
+ # when training.
486
+ dropout_rate = 0.0 # if not self.training else self.attn_dropout
487
+
488
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
489
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
490
+ # cast them back in the correct dtype just to be sure everything works as expected.
491
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
492
+ # in fp32. (LlamaRMSNorm handles it correctly)
493
+
494
+ input_dtype = query_states.dtype
495
+ if input_dtype == torch.float32:
496
+ # Handle the case where the model is quantized
497
+ if hasattr(self.config, "_pre_quantization_dtype"):
498
+ target_dtype = self.config._pre_quantization_dtype
499
+ else:
500
+ target_dtype = self.q_proj.weight.dtype
501
+
502
+ logger.warning_once(
503
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
504
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
505
+ f" {target_dtype}.")
506
+
507
+ query_states = query_states.to(target_dtype)
508
+ key_states = key_states.to(target_dtype)
509
+ value_states = value_states.to(target_dtype)
510
+
511
+ attn_output = self._flash_attention_forward(query_states,
512
+ key_states,
513
+ value_states,
514
+ attention_mask,
515
+ q_len,
516
+ dropout=dropout_rate)
517
+
518
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
519
+ attn_output = self.o_proj(attn_output)
520
+
521
+ if not output_attentions:
522
+ attn_weights = None
523
+
524
+ return attn_output, attn_weights, past_key_value
525
+
526
+ def _flash_attention_forward(self,
527
+ query_states,
528
+ key_states,
529
+ value_states,
530
+ attention_mask,
531
+ query_length,
532
+ dropout=0.0,
533
+ softmax_scale=None):
534
+ """
535
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
536
+ first unpad the input, then computes the attention scores and pad the final attention scores.
537
+
538
+ Args:
539
+ query_states (`torch.Tensor`):
540
+ Input query states to be passed to Flash Attention API
541
+ key_states (`torch.Tensor`):
542
+ Input key states to be passed to Flash Attention API
543
+ value_states (`torch.Tensor`):
544
+ Input value states to be passed to Flash Attention API
545
+ attention_mask (`torch.Tensor`):
546
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
547
+ position of padding tokens and 1 for the position of non-padding tokens.
548
+ dropout (`int`, *optional*):
549
+ Attention dropout
550
+ softmax_scale (`float`, *optional*):
551
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
552
+ """
553
+ # Contains at least one padding token in the sequence
554
+ if attention_mask is not None:
555
+ batch_size = query_states.shape[0]
556
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
557
+ query_states, key_states, value_states, attention_mask, query_length)
558
+
559
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
560
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
561
+
562
+ attn_output_unpad = flash_attn_varlen_func(
563
+ query_states,
564
+ key_states,
565
+ value_states,
566
+ cu_seqlens_q=cu_seqlens_q,
567
+ cu_seqlens_k=cu_seqlens_k,
568
+ max_seqlen_q=max_seqlen_in_batch_q,
569
+ max_seqlen_k=max_seqlen_in_batch_k,
570
+ dropout_p=dropout,
571
+ softmax_scale=softmax_scale,
572
+ causal=self.is_causal,
573
+ )
574
+
575
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
576
+ else:
577
+ attn_output = flash_attn_func(query_states,
578
+ key_states,
579
+ value_states,
580
+ dropout,
581
+ softmax_scale=softmax_scale,
582
+ causal=self.is_causal)
583
+
584
+ return attn_output
585
+
586
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
587
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
588
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
589
+
590
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
591
+ indices_k)
592
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
593
+ indices_k)
594
+ if query_length == kv_seq_len:
595
+ query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
596
+ indices_k)
597
+ cu_seqlens_q = cu_seqlens_k
598
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
599
+ indices_q = indices_k
600
+ elif query_length == 1:
601
+ max_seqlen_in_batch_q = 1
602
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32,
603
+ device=query_layer.device) # There is a memcpy here, that is very bad.
604
+ indices_q = cu_seqlens_q[:-1]
605
+ query_layer = query_layer.squeeze(1)
606
+ else:
607
+ # The -q_len: slice assumes left padding.
608
+ attention_mask = attention_mask[:, -query_length:]
609
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
610
+
611
+ return (
612
+ query_layer,
613
+ key_layer,
614
+ value_layer,
615
+ indices_q,
616
+ (cu_seqlens_q, cu_seqlens_k),
617
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
618
+ )
619
+
620
+
621
+ class LlamaDecoderLayer(nn.Module):
622
+
623
+ def __init__(self, config: LlamaConfig):
624
+ super().__init__()
625
+ self.hidden_size = config.hidden_size
626
+ self.self_attn = (LlamaAttention(
627
+ config=config) if not getattr(config, "_flash_attn_2_enabled", False) else LlamaFlashAttention2(
628
+ config=config))
629
+ self.mlp = LlamaMLP(config)
630
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
631
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
632
+
633
+ def forward(
634
+ self,
635
+ hidden_states: torch.Tensor,
636
+ attention_mask: Optional[torch.Tensor] = None,
637
+ position_ids: Optional[torch.LongTensor] = None,
638
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
639
+ output_attentions: Optional[bool] = False,
640
+ use_cache: Optional[bool] = False,
641
+ **kwargs,
642
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
643
+ """
644
+ Args:
645
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
646
+ attention_mask (`torch.FloatTensor`, *optional*):
647
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
648
+ query_sequence_length, key_sequence_length)` if default attention is used.
649
+ output_attentions (`bool`, *optional*):
650
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
651
+ returned tensors for more detail.
652
+ use_cache (`bool`, *optional*):
653
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
654
+ (see `past_key_values`).
655
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
656
+ """
657
+ if "padding_mask" in kwargs:
658
+ warnings.warn(
659
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
660
+ )
661
+
662
+ residual = hidden_states
663
+
664
+ hidden_states = self.input_layernorm(hidden_states)
665
+
666
+ # Self Attention
667
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
668
+ hidden_states=hidden_states,
669
+ attention_mask=attention_mask,
670
+ position_ids=position_ids,
671
+ past_key_value=past_key_value,
672
+ output_attentions=output_attentions,
673
+ use_cache=use_cache,
674
+ **kwargs,
675
+ )
676
+ hidden_states = residual + hidden_states
677
+
678
+ # Fully Connected
679
+ residual = hidden_states
680
+ hidden_states = self.post_attention_layernorm(hidden_states)
681
+ hidden_states = self.mlp(hidden_states)
682
+ hidden_states = residual + hidden_states
683
+
684
+ outputs = (hidden_states,)
685
+
686
+ if output_attentions:
687
+ outputs += (self_attn_weights,)
688
+
689
+ if use_cache:
690
+ outputs += (present_key_value,)
691
+
692
+ return outputs
693
+
694
+
695
+ LLAMA_START_DOCSTRING = r"""
696
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
697
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
698
+ etc.)
699
+
700
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
701
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
702
+ and behavior.
703
+
704
+ Parameters:
705
+ config ([`LlamaConfig`]):
706
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
707
+ load the weights associated with the model, only the configuration. Check out the
708
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
709
+ """
710
+
711
+
712
+ @add_start_docstrings(
713
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
714
+ LLAMA_START_DOCSTRING,
715
+ )
716
+ class LlamaPreTrainedModel(PreTrainedModel):
717
+ config_class = LlamaConfig
718
+ base_model_prefix = "model"
719
+ supports_gradient_checkpointing = True
720
+ _no_split_modules = ["LlamaDecoderLayer"]
721
+ _skip_keys_device_placement = "past_key_values"
722
+ _supports_flash_attn_2 = True
723
+
724
+ def _init_weights(self, module):
725
+ std = self.config.initializer_range
726
+ if isinstance(module, nn.Linear):
727
+ module.weight.data.normal_(mean=0.0, std=std)
728
+ if module.bias is not None:
729
+ module.bias.data.zero_()
730
+ elif isinstance(module, nn.Embedding):
731
+ module.weight.data.normal_(mean=0.0, std=std)
732
+ if module.padding_idx is not None:
733
+ module.weight.data[module.padding_idx].zero_()
734
+
735
+
736
+ LLAMA_INPUTS_DOCSTRING = r"""
737
+ Args:
738
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
739
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
740
+ it.
741
+
742
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
743
+ [`PreTrainedTokenizer.__call__`] for details.
744
+
745
+ [What are input IDs?](../glossary#input-ids)
746
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
747
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
748
+
749
+ - 1 for tokens that are **not masked**,
750
+ - 0 for tokens that are **masked**.
751
+
752
+ [What are attention masks?](../glossary#attention-mask)
753
+
754
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
755
+ [`PreTrainedTokenizer.__call__`] for details.
756
+
757
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
758
+ `past_key_values`).
759
+
760
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
761
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
762
+ information on the default strategy.
763
+
764
+ - 1 indicates the head is **not masked**,
765
+ - 0 indicates the head is **masked**.
766
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
767
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
768
+ config.n_positions - 1]`.
769
+
770
+ [What are position IDs?](../glossary#position-ids)
771
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
772
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
773
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
774
+ `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
775
+
776
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
777
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
778
+
779
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
780
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
781
+ of shape `(batch_size, sequence_length)`.
782
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
783
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
784
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
785
+ model's internal embedding lookup matrix.
786
+ use_cache (`bool`, *optional*):
787
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
788
+ `past_key_values`).
789
+ output_attentions (`bool`, *optional*):
790
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
791
+ tensors for more detail.
792
+ output_hidden_states (`bool`, *optional*):
793
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
794
+ more detail.
795
+ return_dict (`bool`, *optional*):
796
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
797
+ """
798
+
799
+
800
+ @add_start_docstrings(
801
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
802
+ LLAMA_START_DOCSTRING,
803
+ )
804
+ class LlamaModel(LlamaPreTrainedModel):
805
+ """
806
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
807
+
808
+ Args:
809
+ config: LlamaConfig
810
+ """
811
+
812
+ def __init__(self, config: LlamaConfig):
813
+ super().__init__(config)
814
+ self.padding_idx = config.pad_token_id
815
+ self.vocab_size = config.vocab_size
816
+
817
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
818
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
819
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
820
+
821
+ self.gradient_checkpointing = False
822
+ # Initialize weights and apply final processing
823
+ self.post_init()
824
+
825
+ def get_input_embeddings(self):
826
+ return self.embed_tokens
827
+
828
+ def set_input_embeddings(self, value):
829
+ self.embed_tokens = value
830
+
831
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
832
+ def forward(
833
+ self,
834
+ input_ids: torch.LongTensor = None,
835
+ attention_mask: Optional[torch.Tensor] = None,
836
+ position_ids: Optional[torch.LongTensor] = None,
837
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
838
+ inputs_embeds: Optional[torch.FloatTensor] = None,
839
+ use_cache: Optional[bool] = None,
840
+ output_attentions: Optional[bool] = None,
841
+ output_hidden_states: Optional[bool] = None,
842
+ return_dict: Optional[bool] = None,
843
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
844
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
845
+ output_hidden_states = (
846
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
847
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
848
+
849
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
850
+
851
+ # retrieve input_ids and inputs_embeds
852
+ if input_ids is not None and inputs_embeds is not None:
853
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
854
+ elif input_ids is not None:
855
+ batch_size, seq_length = input_ids.shape[:2]
856
+ elif inputs_embeds is not None:
857
+ batch_size, seq_length = inputs_embeds.shape[:2]
858
+ else:
859
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
860
+
861
+ past_key_values_length = 0
862
+ if past_key_values is not None:
863
+ past_key_values_length = past_key_values[0][0].shape[2]
864
+
865
+ if position_ids is None:
866
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
867
+ position_ids = torch.arange(past_key_values_length,
868
+ seq_length + past_key_values_length,
869
+ dtype=torch.long,
870
+ device=device)
871
+ position_ids = position_ids.unsqueeze(0)
872
+
873
+ if inputs_embeds is None:
874
+ inputs_embeds = self.embed_tokens(input_ids)
875
+
876
+ if getattr(self.config, "_flash_attn_2_enabled", False):
877
+ # 2d mask is passed through the layers
878
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
879
+ else:
880
+ # 4d mask is passed through the layers
881
+ attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
882
+ past_key_values_length)
883
+
884
+ # embed positions
885
+ hidden_states = inputs_embeds
886
+
887
+ if self.gradient_checkpointing and self.training:
888
+ if use_cache:
889
+ logger.warning_once(
890
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
891
+ use_cache = False
892
+
893
+ # decoder layers
894
+ all_hidden_states = () if output_hidden_states else None
895
+ all_self_attns = () if output_attentions else None
896
+ next_decoder_cache = () if use_cache else None
897
+
898
+ for idx, decoder_layer in enumerate(self.layers):
899
+ if output_hidden_states:
900
+ all_hidden_states += (hidden_states,)
901
+
902
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
903
+
904
+ if self.gradient_checkpointing and self.training:
905
+ layer_outputs = self._gradient_checkpointing_func(
906
+ decoder_layer.__call__,
907
+ hidden_states,
908
+ attention_mask,
909
+ position_ids,
910
+ past_key_value,
911
+ output_attentions,
912
+ use_cache,
913
+ )
914
+ else:
915
+ layer_outputs = decoder_layer(
916
+ hidden_states,
917
+ attention_mask=attention_mask,
918
+ position_ids=position_ids,
919
+ past_key_value=past_key_value,
920
+ output_attentions=output_attentions,
921
+ use_cache=use_cache,
922
+ )
923
+
924
+ hidden_states = layer_outputs[0]
925
+
926
+ if use_cache:
927
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
928
+
929
+ if output_attentions:
930
+ all_self_attns += (layer_outputs[1],)
931
+
932
+ hidden_states = self.norm(hidden_states)
933
+
934
+ # add hidden states from the last decoder layer
935
+ if output_hidden_states:
936
+ all_hidden_states += (hidden_states,)
937
+
938
+ next_cache = next_decoder_cache if use_cache else None
939
+ if not return_dict:
940
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
941
+ return BaseModelOutputWithPast(
942
+ last_hidden_state=hidden_states,
943
+ past_key_values=next_cache,
944
+ hidden_states=all_hidden_states,
945
+ attentions=all_self_attns,
946
+ )
947
+
948
+
949
+ class LlamaForCausalLM(LlamaPreTrainedModel):
950
+ _tied_weights_keys = ["lm_head.weight"]
951
+
952
+ def __init__(self, config):
953
+ super().__init__(config)
954
+ self.model = LlamaModel(config)
955
+ self.vocab_size = config.vocab_size
956
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
957
+
958
+ # Initialize weights and apply final processing
959
+ self.post_init()
960
+
961
+ def get_input_embeddings(self):
962
+ return self.model.embed_tokens
963
+
964
+ def set_input_embeddings(self, value):
965
+ self.model.embed_tokens = value
966
+
967
+ def get_output_embeddings(self):
968
+ return self.lm_head
969
+
970
+ def set_output_embeddings(self, new_embeddings):
971
+ self.lm_head = new_embeddings
972
+
973
+ def set_decoder(self, decoder):
974
+ self.model = decoder
975
+
976
+ def get_decoder(self):
977
+ return self.model
978
+
979
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
980
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
981
+ def forward(
982
+ self,
983
+ input_ids: torch.LongTensor = None,
984
+ attention_mask: Optional[torch.Tensor] = None,
985
+ position_ids: Optional[torch.LongTensor] = None,
986
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
987
+ inputs_embeds: Optional[torch.FloatTensor] = None,
988
+ labels: Optional[torch.LongTensor] = None,
989
+ use_cache: Optional[bool] = None,
990
+ output_attentions: Optional[bool] = None,
991
+ output_hidden_states: Optional[bool] = None,
992
+ return_dict: Optional[bool] = None,
993
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
994
+ r"""
995
+ Args:
996
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
997
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
998
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
999
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1000
+
1001
+ Returns:
1002
+
1003
+ Example:
1004
+
1005
+ ```python
1006
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1007
+
1008
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1009
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1010
+
1011
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1012
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1013
+
1014
+ >>> # Generate
1015
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1016
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1017
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1018
+ ```"""
1019
+
1020
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1021
+ output_hidden_states = (
1022
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
1023
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1024
+
1025
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1026
+ outputs = self.model(
1027
+ input_ids=input_ids,
1028
+ attention_mask=attention_mask,
1029
+ position_ids=position_ids,
1030
+ past_key_values=past_key_values,
1031
+ inputs_embeds=inputs_embeds,
1032
+ use_cache=use_cache,
1033
+ output_attentions=output_attentions,
1034
+ output_hidden_states=output_hidden_states,
1035
+ return_dict=return_dict,
1036
+ )
1037
+
1038
+ hidden_states = outputs[0]
1039
+ if self.config.pretraining_tp > 1:
1040
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1041
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1042
+ logits = torch.cat(logits, dim=-1)
1043
+ else:
1044
+ logits = self.lm_head(hidden_states)
1045
+ logits = logits.float()
1046
+
1047
+ loss = None
1048
+ if labels is not None:
1049
+ # Shift so that tokens < n predict n
1050
+ shift_logits = logits[..., :-1, :].contiguous()
1051
+ shift_labels = labels[..., 1:].contiguous()
1052
+ # Flatten the tokens
1053
+ loss_fct = CrossEntropyLoss()
1054
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1055
+ shift_labels = shift_labels.view(-1)
1056
+ # Enable model parallelism
1057
+ shift_labels = shift_labels.to(shift_logits.device)
1058
+ loss = loss_fct(shift_logits, shift_labels)
1059
+
1060
+ if not return_dict:
1061
+ output = (logits,) + outputs[1:]
1062
+ return (loss,) + output if loss is not None else output
1063
+
1064
+ return CausalLMOutputWithPast(
1065
+ loss=loss,
1066
+ logits=logits,
1067
+ past_key_values=outputs.past_key_values,
1068
+ hidden_states=outputs.hidden_states,
1069
+ attentions=outputs.attentions,
1070
+ )
1071
+
1072
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None,
1073
+ **kwargs):
1074
+ if past_key_values is not None:
1075
+ past_length = past_key_values[0][0].shape[2]
1076
+
1077
+ # Some generation methods already pass only the last input ID
1078
+ if input_ids.shape[1] > past_length:
1079
+ remove_prefix_length = past_length
1080
+ else:
1081
+ # Default to old behavior: keep only final ID
1082
+ remove_prefix_length = input_ids.shape[1] - 1
1083
+
1084
+ input_ids = input_ids[:, remove_prefix_length:]
1085
+
1086
+ position_ids = kwargs.get("position_ids", None)
1087
+ if attention_mask is not None and position_ids is None:
1088
+ # create position_ids on the fly for batch generation
1089
+ position_ids = attention_mask.long().cumsum(-1) - 1
1090
+ position_ids.masked_fill_(attention_mask == 0, 1)
1091
+ if past_key_values:
1092
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1093
+
1094
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1095
+ if inputs_embeds is not None and past_key_values is None:
1096
+ model_inputs = {"inputs_embeds": inputs_embeds}
1097
+ else:
1098
+ model_inputs = {"input_ids": input_ids}
1099
+
1100
+ model_inputs.update({
1101
+ "position_ids": position_ids,
1102
+ "past_key_values": past_key_values,
1103
+ "use_cache": kwargs.get("use_cache"),
1104
+ "attention_mask": attention_mask,
1105
+ })
1106
+ return model_inputs
1107
+
1108
+ @staticmethod
1109
+ def _reorder_cache(past_key_values, beam_idx):
1110
+ reordered_past = ()
1111
+ for layer_past in past_key_values:
1112
+ reordered_past += (
1113
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
1114
+ return reordered_past
1115
+
1116
+
1117
+ @add_start_docstrings(
1118
+ """
1119
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
1120
+
1121
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1122
+ (e.g. GPT-2) do.
1123
+
1124
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1125
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1126
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1127
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1128
+ each row of the batch).
1129
+ """,
1130
+ LLAMA_START_DOCSTRING,
1131
+ )
1132
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1133
+
1134
+ def __init__(self, config):
1135
+ super().__init__(config)
1136
+ self.num_labels = config.num_labels
1137
+ self.model = LlamaModel(config)
1138
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1139
+
1140
+ # Initialize weights and apply final processing
1141
+ self.post_init()
1142
+
1143
+ def get_input_embeddings(self):
1144
+ return self.model.embed_tokens
1145
+
1146
+ def set_input_embeddings(self, value):
1147
+ self.model.embed_tokens = value
1148
+
1149
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1150
+ def forward(
1151
+ self,
1152
+ input_ids: torch.LongTensor = None,
1153
+ attention_mask: Optional[torch.Tensor] = None,
1154
+ position_ids: Optional[torch.LongTensor] = None,
1155
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1156
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1157
+ labels: Optional[torch.LongTensor] = None,
1158
+ use_cache: Optional[bool] = None,
1159
+ output_attentions: Optional[bool] = None,
1160
+ output_hidden_states: Optional[bool] = None,
1161
+ return_dict: Optional[bool] = None,
1162
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1163
+ r"""
1164
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1165
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1166
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1167
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1168
+ """
1169
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1170
+
1171
+ transformer_outputs = self.model(
1172
+ input_ids,
1173
+ attention_mask=attention_mask,
1174
+ position_ids=position_ids,
1175
+ past_key_values=past_key_values,
1176
+ inputs_embeds=inputs_embeds,
1177
+ use_cache=use_cache,
1178
+ output_attentions=output_attentions,
1179
+ output_hidden_states=output_hidden_states,
1180
+ return_dict=return_dict,
1181
+ )
1182
+ hidden_states = transformer_outputs[0]
1183
+ logits = self.score(hidden_states)
1184
+
1185
+ if input_ids is not None:
1186
+ batch_size = input_ids.shape[0]
1187
+ else:
1188
+ batch_size = inputs_embeds.shape[0]
1189
+
1190
+ if self.config.pad_token_id is None and batch_size != 1:
1191
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1192
+ if self.config.pad_token_id is None:
1193
+ sequence_lengths = -1
1194
+ else:
1195
+ if input_ids is not None:
1196
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1197
+ logits.device)
1198
+ else:
1199
+ sequence_lengths = -1
1200
+
1201
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1202
+
1203
+ loss = None
1204
+ if labels is not None:
1205
+ labels = labels.to(logits.device)
1206
+ if self.config.problem_type is None:
1207
+ if self.num_labels == 1:
1208
+ self.config.problem_type = "regression"
1209
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1210
+ self.config.problem_type = "single_label_classification"
1211
+ else:
1212
+ self.config.problem_type = "multi_label_classification"
1213
+
1214
+ if self.config.problem_type == "regression":
1215
+ loss_fct = MSELoss()
1216
+ if self.num_labels == 1:
1217
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1218
+ else:
1219
+ loss = loss_fct(pooled_logits, labels)
1220
+ elif self.config.problem_type == "single_label_classification":
1221
+ loss_fct = CrossEntropyLoss()
1222
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1223
+ elif self.config.problem_type == "multi_label_classification":
1224
+ loss_fct = BCEWithLogitsLoss()
1225
+ loss = loss_fct(pooled_logits, labels)
1226
+ if not return_dict:
1227
+ output = (pooled_logits,) + transformer_outputs[1:]
1228
+ return ((loss,) + output) if loss is not None else output
1229
+
1230
+ return SequenceClassifierOutputWithPast(
1231
+ loss=loss,
1232
+ logits=pooled_logits,
1233
+ past_key_values=transformer_outputs.past_key_values,
1234
+ hidden_states=transformer_outputs.hidden_states,
1235
+ attentions=transformer_outputs.attentions,
1236
+ )
src/models_clm/modeling_llama_xformer.py ADDED
@@ -0,0 +1,992 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ # coding=utf-8
3
+ # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """ PyTorch LLaMA model."""
22
+ import math
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutputWithPast,
33
+ CausalLMOutputWithPast,
34
+ SequenceClassifierOutputWithPast,
35
+ )
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.utils import (
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ logging,
41
+ replace_return_docstrings,
42
+ )
43
+ from transformers.models.llama.configuration_llama import LlamaConfig
44
+ import xformers.ops as xops
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _CONFIG_FOR_DOC = "LlamaConfig"
49
+
50
+
51
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
52
+ def _make_causal_mask(
53
+ input_ids_shape: torch.Size,
54
+ dtype: torch.dtype,
55
+ device: torch.device,
56
+ past_key_values_length: int = 0,
57
+ ):
58
+ """
59
+ Make causal mask used for bi-directional self-attention.
60
+ """
61
+ bsz, tgt_len = input_ids_shape
62
+ mask = torch.full(
63
+ (tgt_len, tgt_len),
64
+ torch.tensor(torch.finfo(dtype).min, device=device),
65
+ device=device,
66
+ )
67
+ mask_cond = torch.arange(mask.size(-1), device=device)
68
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
69
+ mask = mask.to(dtype)
70
+
71
+ if past_key_values_length > 0:
72
+ mask = torch.cat(
73
+ [
74
+ torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
75
+ mask,
76
+ ],
77
+ dim=-1,
78
+ )
79
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
80
+
81
+
82
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
83
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
84
+ """
85
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
86
+ """
87
+ bsz, src_len = mask.size()
88
+ tgt_len = tgt_len if tgt_len is not None else src_len
89
+
90
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
91
+
92
+ inverted_mask = 1.0 - expanded_mask
93
+
94
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
95
+
96
+
97
+ class LlamaRMSNorm(nn.Module):
98
+
99
+ def __init__(self, hidden_size, eps=1e-6):
100
+ """
101
+ LlamaRMSNorm is equivalent to T5LayerNorm
102
+ """
103
+ super().__init__()
104
+ self.weight = nn.Parameter(torch.ones(hidden_size))
105
+ self.variance_epsilon = eps
106
+
107
+ def forward(self, hidden_states):
108
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
109
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
110
+
111
+ # convert into half-precision if necessary
112
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
113
+ hidden_states = hidden_states.to(self.weight.dtype)
114
+
115
+ return self.weight * hidden_states
116
+
117
+
118
+ class LlamaRotaryEmbedding(torch.nn.Module):
119
+
120
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
121
+ super().__init__()
122
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
123
+ self.register_buffer("inv_freq", inv_freq)
124
+
125
+ # Build here to make `torch.jit.trace` work.
126
+ self.max_seq_len_cached = max_position_embeddings
127
+ t = torch.arange(
128
+ self.max_seq_len_cached,
129
+ device=self.inv_freq.device,
130
+ dtype=self.inv_freq.dtype,
131
+ )
132
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
133
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
134
+ emb = torch.cat((freqs, freqs), dim=-1)
135
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
136
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
137
+
138
+ def forward(self, x, seq_len=None):
139
+ # x: [bs, num_attention_heads, seq_len, head_size]
140
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
141
+ if seq_len > self.max_seq_len_cached:
142
+ self.max_seq_len_cached = seq_len
143
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
144
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
145
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
146
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
147
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
148
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
149
+ return (
150
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
151
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
152
+ # self.cos_cached[:, :, :, ...].to(dtype=x.dtype),
153
+ # self.sin_cached[:, :, :, ...].to(dtype=x.dtype),
154
+
155
+ )
156
+
157
+
158
+ def rotate_half(x):
159
+ """Rotates half the hidden dims of the input."""
160
+ x1 = x[..., :x.shape[-1] // 2]
161
+ x2 = x[..., x.shape[-1] // 2:]
162
+ return torch.cat((-x2, x1), dim=-1)
163
+
164
+
165
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
166
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
167
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
168
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
169
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
170
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
171
+ q_embed = (q * cos) + (rotate_half(q) * sin)
172
+ k_embed = (k * cos) + (rotate_half(k) * sin)
173
+ return q_embed, k_embed
174
+
175
+
176
+ class LlamaMLP(nn.Module):
177
+
178
+ def __init__(
179
+ self,
180
+ hidden_size: int,
181
+ intermediate_size: int,
182
+ hidden_act: str,
183
+ ):
184
+ super().__init__()
185
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
186
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
187
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
188
+ self.act_fn = ACT2FN[hidden_act]
189
+
190
+ def forward(self, x):
191
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
192
+
193
+
194
+ class LlamaAttention(nn.Module):
195
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
196
+
197
+ def __init__(self, config: LlamaConfig):
198
+ super().__init__()
199
+ self.config = config
200
+ self.hidden_size = config.hidden_size
201
+ self.num_heads = config.num_attention_heads
202
+ self.head_dim = self.hidden_size // self.num_heads
203
+ self.max_position_embeddings = config.max_position_embeddings
204
+
205
+ if (self.head_dim * self.num_heads) != self.hidden_size:
206
+ raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
207
+ f" and `num_heads`: {self.num_heads}).")
208
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
209
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
210
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
211
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
212
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
213
+
214
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
215
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
216
+
217
+ def forward(
218
+ self,
219
+ hidden_states: torch.Tensor,
220
+ attention_mask: Optional[torch.Tensor] = None,
221
+ position_ids: Optional[torch.LongTensor] = None,
222
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
223
+ output_attentions: bool = False,
224
+ use_cache: bool = False,
225
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
226
+ bsz, q_len, _ = hidden_states.size()
227
+
228
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
229
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
230
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
231
+
232
+ kv_seq_len = key_states.shape[-2]
233
+ if past_key_value is not None:
234
+ kv_seq_len += past_key_value[0].shape[-2]
235
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
236
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
237
+ # [bsz, nh, t, hd]
238
+
239
+ if past_key_value is not None:
240
+ # reuse k, v, self_attention
241
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
242
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
243
+
244
+ past_key_value = (key_states, value_states) if use_cache else None
245
+
246
+ # attn_weights
247
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
248
+ if attention_mask is None:
249
+ def lower_triangular_from_bottom_right_mask(qlen, klen, device):
250
+ """
251
+ Create a lower triangular mask from the bottom-right corner of a matrix.
252
+
253
+ Args:
254
+ - qlen (int): Length of the query dimension.
255
+ - klen (int): Length of the key dimension.
256
+
257
+ Returns:
258
+ - torch.Tensor: A mask with shape (1, 1, qlen, klen) where the bottom-right triangle is True.
259
+ """
260
+ # Create a grid of indices where rows correspond to query indices and columns to key indices
261
+ q_indices = torch.arange(qlen - 1, -1, -1, device=device).unsqueeze(1) # Reverse the query indices
262
+ k_indices = torch.arange(klen - 1, -1, -1, device=device).unsqueeze(0) # Reverse the key indices
263
+
264
+ # Generate the mask where we compare query indices to key indices
265
+ # The condition q_indices >= k_indices creates a lower triangular mask from the top-left corner
266
+ # By reversing both indices, we get the lower triangular effect from the bottom-right
267
+ mask = q_indices >= k_indices
268
+
269
+ # Reshape to (1, 1, qlen, klen) as required
270
+ return mask.unsqueeze(0).unsqueeze(0)
271
+
272
+ attention_mask = lower_triangular_from_bottom_right_mask(attn_weights.shape[-2], attn_weights.shape[-1],
273
+ device=attn_weights.device)
274
+ attn_weights = attn_weights + attention_mask
275
+ attn_weights = attn_weights[:, 0, :, :]
276
+ # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
277
+
278
+ query_states = query_states.transpose(1, 2)
279
+ key_states = key_states.transpose(1, 2)
280
+ value_states = value_states.transpose(1, 2)
281
+ if self.training:
282
+ attn_output = xops.memory_efficient_attention(
283
+ query_states,
284
+ key_states,
285
+ value_states,
286
+ attn_bias=xops.LowerTriangularMask(),
287
+ )
288
+ else:
289
+ xops_attention_mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
290
+ attn_output = xops.memory_efficient_attention(
291
+ query_states,
292
+ key_states,
293
+ value_states,
294
+ attn_bias=xops_attention_mask,
295
+ )
296
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
297
+ attn_output = self.o_proj(attn_output)
298
+
299
+ if not output_attentions:
300
+ attn_weights = None
301
+ return attn_output, attn_weights, past_key_value
302
+
303
+
304
+ class LlamaDecoderLayer(nn.Module):
305
+
306
+ def __init__(self, config: LlamaConfig):
307
+ super().__init__()
308
+ self.hidden_size = config.hidden_size
309
+ self.self_attn = LlamaAttention(config=config)
310
+ self.mlp = LlamaMLP(
311
+ hidden_size=self.hidden_size,
312
+ intermediate_size=config.intermediate_size,
313
+ hidden_act=config.hidden_act,
314
+ )
315
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
316
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
317
+
318
+ def forward(
319
+ self,
320
+ hidden_states: torch.Tensor,
321
+ attention_mask: Optional[torch.Tensor] = None,
322
+ position_ids: Optional[torch.LongTensor] = None,
323
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
324
+ output_attentions: Optional[bool] = False,
325
+ use_cache: Optional[bool] = False,
326
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
327
+ """
328
+ Args:
329
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
330
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
331
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
332
+ output_attentions (`bool`, *optional*):
333
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
334
+ returned tensors for more detail.
335
+ use_cache (`bool`, *optional*):
336
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
337
+ (see `past_key_values`).
338
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
339
+ """
340
+
341
+ residual = hidden_states
342
+
343
+ hidden_states = self.input_layernorm(hidden_states)
344
+ # Self Attention
345
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
346
+ hidden_states=hidden_states,
347
+ attention_mask=attention_mask,
348
+ position_ids=position_ids,
349
+ past_key_value=past_key_value,
350
+ output_attentions=output_attentions,
351
+ use_cache=use_cache,
352
+ )
353
+ hidden_states = residual + hidden_states
354
+
355
+ # Fully Connected
356
+ residual = hidden_states
357
+ hidden_states = self.post_attention_layernorm(hidden_states)
358
+ hidden_states = self.mlp(hidden_states)
359
+ hidden_states = residual + hidden_states
360
+
361
+ outputs = (hidden_states,)
362
+
363
+ if output_attentions:
364
+ outputs += (self_attn_weights,)
365
+
366
+ if use_cache:
367
+ outputs += (present_key_value,)
368
+ return outputs
369
+
370
+
371
+ LLAMA_START_DOCSTRING = r"""
372
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
373
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
374
+ etc.)
375
+
376
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
377
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
378
+ and behavior.
379
+
380
+ Parameters:
381
+ config ([`LlamaConfig`]):
382
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
383
+ load the weights associated with the model, only the configuration. Check out the
384
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
385
+ """
386
+
387
+
388
+ @add_start_docstrings(
389
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
390
+ LLAMA_START_DOCSTRING,
391
+ )
392
+ class LlamaPreTrainedModel(PreTrainedModel):
393
+ config_class = LlamaConfig
394
+ base_model_prefix = "model"
395
+ supports_gradient_checkpointing = True
396
+ _no_split_modules = ["LlamaDecoderLayer"]
397
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
398
+
399
+ def _init_weights(self, module):
400
+ std = self.config.initializer_range
401
+ if isinstance(module, nn.Linear):
402
+ module.weight.data.normal_(mean=0.0, std=std)
403
+ if module.bias is not None:
404
+ module.bias.data.zero_()
405
+ elif isinstance(module, nn.Embedding):
406
+ module.weight.data.normal_(mean=0.0, std=std)
407
+ if module.padding_idx is not None:
408
+ module.weight.data[module.padding_idx].zero_()
409
+
410
+ def _set_gradient_checkpointing(self, module, value=False):
411
+ if isinstance(module, LlamaModel):
412
+ module.gradient_checkpointing = value
413
+
414
+
415
+ LLAMA_INPUTS_DOCSTRING = r"""
416
+ Args:
417
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
418
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
419
+ it.
420
+
421
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
422
+ [`PreTrainedTokenizer.__call__`] for details.
423
+
424
+ [What are input IDs?](../glossary#input-ids)
425
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
426
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
427
+
428
+ - 1 for tokens that are **not masked**,
429
+ - 0 for tokens that are **masked**.
430
+
431
+ [What are attention masks?](../glossary#attention-mask)
432
+
433
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
434
+ [`PreTrainedTokenizer.__call__`] for details.
435
+
436
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
437
+ `past_key_values`).
438
+
439
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
440
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
441
+ information on the default strategy.
442
+
443
+ - 1 indicates the head is **not masked**,
444
+ - 0 indicates the head is **masked**.
445
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
446
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
447
+ config.n_positions - 1]`.
448
+
449
+ [What are position IDs?](../glossary#position-ids)
450
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
451
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
452
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
453
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
454
+
455
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
456
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
457
+
458
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
459
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
460
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
461
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
462
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
463
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
464
+ model's internal embedding lookup matrix.
465
+ use_cache (`bool`, *optional*):
466
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
467
+ `past_key_values`).
468
+ output_attentions (`bool`, *optional*):
469
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
470
+ tensors for more detail.
471
+ output_hidden_states (`bool`, *optional*):
472
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
473
+ more detail.
474
+ return_dict (`bool`, *optional*):
475
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
476
+ """
477
+
478
+
479
+ @add_start_docstrings(
480
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
481
+ LLAMA_START_DOCSTRING,
482
+ )
483
+ class LlamaModel(LlamaPreTrainedModel):
484
+ """
485
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
486
+
487
+ Args:
488
+ config: LlamaConfig
489
+ """
490
+
491
+ def __init__(self, config: LlamaConfig):
492
+ super().__init__(config)
493
+ self.padding_idx = config.pad_token_id
494
+ self.vocab_size = config.vocab_size
495
+
496
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
497
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
498
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
499
+
500
+ self.gradient_checkpointing = False
501
+ # Initialize weights and apply final processing
502
+ self.post_init()
503
+
504
+ def get_input_embeddings(self):
505
+ return self.embed_tokens
506
+
507
+ def set_input_embeddings(self, value):
508
+ self.embed_tokens = value
509
+
510
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
511
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
512
+ # create causal mask
513
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
514
+ combined_attention_mask = None
515
+ if input_shape[-1] > 1:
516
+ combined_attention_mask = _make_causal_mask(
517
+ input_shape,
518
+ inputs_embeds.dtype,
519
+ device=inputs_embeds.device,
520
+ past_key_values_length=past_key_values_length,
521
+ )
522
+
523
+ if attention_mask is not None:
524
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
525
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
526
+ tgt_len=input_shape[-1]).to(inputs_embeds.device)
527
+ combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
528
+
529
+ return combined_attention_mask
530
+
531
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
532
+ def forward(
533
+ self,
534
+ input_ids: torch.LongTensor = None,
535
+ attention_mask: Optional[torch.Tensor] = None,
536
+ position_ids: Optional[torch.LongTensor] = None,
537
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
538
+ inputs_embeds: Optional[torch.FloatTensor] = None,
539
+ use_cache: Optional[bool] = None,
540
+ output_attentions: Optional[bool] = None,
541
+ output_hidden_states: Optional[bool] = None,
542
+ return_dict: Optional[bool] = None,
543
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
544
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
545
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
546
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
547
+
548
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
549
+
550
+ # retrieve input_ids and inputs_embeds
551
+ # if input_ids is not None and inputs_embeds is not None:
552
+ # raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
553
+ # elif input_ids is not None:
554
+ if input_ids is not None:
555
+ batch_size, seq_length = input_ids.shape
556
+ elif inputs_embeds is not None:
557
+ batch_size, seq_length, _ = inputs_embeds.shape
558
+ else:
559
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
560
+
561
+ seq_length_with_past = seq_length
562
+ past_key_values_length = 0
563
+
564
+ if past_key_values is not None:
565
+ past_key_values_length = past_key_values[0][0].shape[2]
566
+ seq_length_with_past = seq_length_with_past + past_key_values_length
567
+
568
+ if position_ids is None:
569
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
570
+ position_ids = torch.arange(
571
+ past_key_values_length,
572
+ seq_length + past_key_values_length,
573
+ dtype=torch.long,
574
+ device=device,
575
+ )
576
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
577
+ else:
578
+ position_ids = position_ids.view(-1, seq_length).long()
579
+
580
+ if inputs_embeds is None:
581
+ inputs_embeds = self.embed_tokens(input_ids)
582
+ # embed positions
583
+
584
+ # rm when use streaming
585
+ # if attention_mask is None:
586
+ # attention_mask = torch.ones(
587
+ # (batch_size, seq_length_with_past),
588
+ # dtype=torch.bool,
589
+ # device=inputs_embeds.device,
590
+ # )
591
+ attention_mask = self._prepare_decoder_attention_mask(
592
+ attention_mask,
593
+ (batch_size, seq_length),
594
+ inputs_embeds,
595
+ past_key_values_length,
596
+ )
597
+
598
+ hidden_states = inputs_embeds
599
+
600
+ if self.gradient_checkpointing and self.training:
601
+ if use_cache:
602
+ logger.warning_once(
603
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
604
+ use_cache = False
605
+
606
+ # decoder layers
607
+ all_hidden_states = () if output_hidden_states else None
608
+ all_self_attns = () if output_attentions else None
609
+ next_decoder_cache = () if use_cache else None
610
+
611
+ for idx, decoder_layer in enumerate(self.layers):
612
+ if output_hidden_states:
613
+ all_hidden_states += (hidden_states,)
614
+
615
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
616
+
617
+ if self.gradient_checkpointing and self.training:
618
+
619
+ def create_custom_forward(module):
620
+
621
+ def custom_forward(*inputs):
622
+ # None for past_key_value
623
+ return module(*inputs, output_attentions, None)
624
+
625
+ return custom_forward
626
+
627
+ layer_outputs = torch.utils.checkpoint.checkpoint(
628
+ create_custom_forward(decoder_layer),
629
+ hidden_states,
630
+ attention_mask,
631
+ position_ids,
632
+ None,
633
+ )
634
+ else:
635
+ layer_outputs = decoder_layer(
636
+ hidden_states,
637
+ attention_mask=attention_mask,
638
+ position_ids=position_ids,
639
+ past_key_value=past_key_value,
640
+ output_attentions=output_attentions,
641
+ use_cache=use_cache,
642
+ )
643
+
644
+ hidden_states = layer_outputs[0]
645
+
646
+ if use_cache:
647
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
648
+
649
+ if output_attentions:
650
+ all_self_attns += (layer_outputs[1],)
651
+
652
+ hidden_states = self.norm(hidden_states)
653
+
654
+ # add hidden states from the last decoder layer
655
+ if output_hidden_states:
656
+ all_hidden_states += (hidden_states,)
657
+
658
+ next_cache = next_decoder_cache if use_cache else None
659
+ if not return_dict:
660
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
661
+ return BaseModelOutputWithPast(
662
+ last_hidden_state=hidden_states,
663
+ past_key_values=next_cache,
664
+ hidden_states=all_hidden_states,
665
+ attentions=all_self_attns,
666
+ )
667
+
668
+
669
+ class LlamaForCausalLM(LlamaPreTrainedModel):
670
+
671
+ def __init__(self, config):
672
+ super().__init__(config)
673
+ self.model = LlamaModel(config)
674
+
675
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
676
+ self.past_key_values = None
677
+ self.kv_cache_head = None
678
+ self.use_kv_cache_head = True
679
+ # self.position_ids = None
680
+ # Initialize weights and apply final processing
681
+ self.post_init()
682
+
683
+ def get_input_embeddings(self):
684
+ return self.model.embed_tokens
685
+
686
+ def set_input_embeddings(self, value):
687
+ self.model.embed_tokens = value
688
+
689
+ def get_output_embeddings(self):
690
+ return self.lm_head
691
+
692
+ def set_output_embeddings(self, new_embeddings):
693
+ self.lm_head = new_embeddings
694
+
695
+ def set_decoder(self, decoder):
696
+ self.model = decoder
697
+
698
+ def get_decoder(self):
699
+ return self.model
700
+
701
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
702
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
703
+ def forward(
704
+ self,
705
+ input_ids: torch.LongTensor = None,
706
+ attention_mask: Optional[torch.Tensor] = None,
707
+ position_ids: Optional[torch.LongTensor] = None,
708
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
709
+ inputs_embeds: Optional[torch.FloatTensor] = None,
710
+ labels: Optional[torch.LongTensor] = None,
711
+ use_cache: Optional[bool] = None,
712
+ output_attentions: Optional[bool] = None,
713
+ output_hidden_states: Optional[bool] = None,
714
+ return_dict: Optional[bool] = None,
715
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
716
+ r"""
717
+ Args:
718
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
719
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
720
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
721
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
722
+
723
+ Returns:
724
+
725
+ Example:
726
+
727
+ ```python
728
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
729
+
730
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
731
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
732
+
733
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
734
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
735
+
736
+ >>> # Generate
737
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
738
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
739
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
740
+ ```"""
741
+
742
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
743
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
744
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
745
+
746
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
747
+ outputs = self.model(
748
+ input_ids=input_ids,
749
+ attention_mask=attention_mask,
750
+ position_ids=position_ids,
751
+ past_key_values=past_key_values,
752
+ inputs_embeds=inputs_embeds,
753
+ use_cache=use_cache,
754
+ output_attentions=output_attentions,
755
+ output_hidden_states=output_hidden_states,
756
+ return_dict=return_dict,
757
+ )
758
+ hidden_states = outputs[0]
759
+ logits = self.lm_head(hidden_states)
760
+
761
+ loss = None
762
+ if labels is not None:
763
+ # Shift so that tokens < n predict n
764
+ shift_logits = logits[..., :-1, :].contiguous()
765
+ shift_labels = labels[..., 1:].contiguous()
766
+ # Flatten the tokens
767
+ loss_fct = CrossEntropyLoss()
768
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
769
+ shift_labels = shift_labels.view(-1)
770
+ # Enable model parallelism
771
+ shift_labels = shift_labels.to(shift_logits.device)
772
+ loss = loss_fct(shift_logits, shift_labels)
773
+
774
+ if not return_dict:
775
+ output = (logits,) + outputs[1:]
776
+ return (loss,) + output if loss is not None else output
777
+
778
+ self.past_key_values = outputs.past_key_values
779
+
780
+ if self.use_kv_cache_head and not self.training:
781
+ if self.kv_cache_head is None:
782
+ self.kv_cache_head = input_ids.shape[1]
783
+ else:
784
+ self.kv_cache_head += input_ids.shape[1]
785
+
786
+ # new_position_ids = torch.ones((1, 1), device=self.position_ids.device) * (self.position_ids[0, -1].item() + 1)
787
+ # self.position_ids = torch.cat((self.position_ids, new_position_ids), dim=1)
788
+ return CausalLMOutputWithPast(
789
+ loss=loss,
790
+ logits=logits,
791
+ past_key_values=outputs.past_key_values,
792
+ hidden_states=outputs.hidden_states,
793
+ attentions=outputs.attentions,
794
+ )
795
+
796
+ def prepare_inputs_for_generation(
797
+ self,
798
+ input_ids,
799
+ past_key_values=None,
800
+ attention_mask=None,
801
+ inputs_embeds=None,
802
+ **kwargs,
803
+ ):
804
+ if self.use_kv_cache_head and not self.training:
805
+ if past_key_values:
806
+ input_ids = input_ids[:, self.kv_cache_head:]
807
+ if inputs_embeds is not None:
808
+ inputs_embeds = inputs_embeds[:, self.kv_cache_head:]
809
+
810
+ position_ids = kwargs.get("position_ids", None)
811
+ if attention_mask is not None and position_ids is None:
812
+ # create position_ids on the fly for batch generation
813
+ position_ids = attention_mask.long().cumsum(-1) - 1
814
+ position_ids.masked_fill_(attention_mask == 0, 1)
815
+ if past_key_values:
816
+ position_ids = position_ids[:, self.kv_cache_head:].unsqueeze(-1)
817
+
818
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
819
+ if inputs_embeds is not None and past_key_values is None:
820
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
821
+ elif past_key_values is not None and input_ids.shape[1] > 1:
822
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
823
+ else:
824
+ model_inputs = {"input_ids": input_ids}
825
+
826
+ attention_mask = None
827
+ else:
828
+ if past_key_values:
829
+ input_ids = input_ids[:, -1:]
830
+
831
+ position_ids = kwargs.get("position_ids", None)
832
+ if attention_mask is not None and position_ids is None:
833
+ # create position_ids on the fly for batch generation
834
+ position_ids = attention_mask.long().cumsum(-1) - 1
835
+ position_ids.masked_fill_(attention_mask == 0, 1)
836
+ if past_key_values:
837
+ position_ids = position_ids[:, -1].unsqueeze(-1)
838
+
839
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
840
+ if inputs_embeds is not None and past_key_values is None:
841
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
842
+ else:
843
+ model_inputs = {"input_ids": input_ids}
844
+ attention_mask = None
845
+
846
+ model_inputs.update({
847
+ "position_ids": position_ids,
848
+ "past_key_values": past_key_values,
849
+ "use_cache": kwargs.get("use_cache"),
850
+ "attention_mask": attention_mask,
851
+ })
852
+ return model_inputs
853
+
854
+ @staticmethod
855
+ def _reorder_cache(past_key_values, beam_idx):
856
+ reordered_past = ()
857
+ for layer_past in past_key_values:
858
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
859
+ return reordered_past
860
+
861
+
862
+ @add_start_docstrings(
863
+ """
864
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
865
+
866
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
867
+ (e.g. GPT-2) do.
868
+
869
+ Since it does classification on the last token, it requires to know the position of the last token. If a
870
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
871
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
872
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
873
+ each row of the batch).
874
+ """,
875
+ LLAMA_START_DOCSTRING,
876
+ )
877
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
878
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
879
+
880
+ def __init__(self, config):
881
+ super().__init__(config)
882
+ self.num_labels = config.num_labels
883
+ self.model = LlamaModel(config)
884
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
885
+
886
+ # Initialize weights and apply final processing
887
+ self.post_init()
888
+
889
+ def get_input_embeddings(self):
890
+ return self.model.embed_tokens
891
+
892
+ def set_input_embeddings(self, value):
893
+ self.model.embed_tokens = value
894
+
895
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
896
+ def forward(
897
+ self,
898
+ input_ids: torch.LongTensor = None,
899
+ attention_mask: Optional[torch.Tensor] = None,
900
+ position_ids: Optional[torch.LongTensor] = None,
901
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
902
+ inputs_embeds: Optional[torch.FloatTensor] = None,
903
+ labels: Optional[torch.LongTensor] = None,
904
+ use_cache: Optional[bool] = None,
905
+ output_attentions: Optional[bool] = None,
906
+ output_hidden_states: Optional[bool] = None,
907
+ return_dict: Optional[bool] = None,
908
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
909
+ r"""
910
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
911
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
912
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
913
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
914
+ """
915
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
916
+
917
+ transformer_outputs = self.model(
918
+ input_ids,
919
+ attention_mask=attention_mask,
920
+ position_ids=position_ids,
921
+ past_key_values=past_key_values,
922
+ inputs_embeds=inputs_embeds,
923
+ use_cache=use_cache,
924
+ output_attentions=output_attentions,
925
+ output_hidden_states=output_hidden_states,
926
+ return_dict=return_dict,
927
+ )
928
+ hidden_states = transformer_outputs[0]
929
+ logits = self.score(hidden_states)
930
+
931
+ if input_ids is not None:
932
+ batch_size = input_ids.shape[0]
933
+ else:
934
+ batch_size = inputs_embeds.shape[0]
935
+
936
+ if self.config.pad_token_id is None and batch_size != 1:
937
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
938
+ if self.config.pad_token_id is None:
939
+ sequence_lengths = -1
940
+ else:
941
+ if input_ids is not None:
942
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
943
+ else:
944
+ sequence_lengths = -1
945
+
946
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
947
+
948
+ loss = None
949
+ if labels is not None:
950
+ labels = labels.to(logits.device)
951
+ if self.config.problem_type is None:
952
+ if self.num_labels == 1:
953
+ self.config.problem_type = "regression"
954
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
955
+ self.config.problem_type = "single_label_classification"
956
+ else:
957
+ self.config.problem_type = "multi_label_classification"
958
+
959
+ if self.config.problem_type == "regression":
960
+ loss_fct = MSELoss()
961
+ if self.num_labels == 1:
962
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
963
+ else:
964
+ loss = loss_fct(pooled_logits, labels)
965
+ elif self.config.problem_type == "single_label_classification":
966
+ loss_fct = CrossEntropyLoss()
967
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
968
+ elif self.config.problem_type == "multi_label_classification":
969
+ loss_fct = BCEWithLogitsLoss()
970
+ loss = loss_fct(pooled_logits, labels)
971
+ if not return_dict:
972
+ output = (pooled_logits,) + transformer_outputs[1:]
973
+ return ((loss,) + output) if loss is not None else output
974
+
975
+ return SequenceClassifierOutputWithPast(
976
+ loss=loss,
977
+ logits=pooled_logits,
978
+ past_key_values=transformer_outputs.past_key_values,
979
+ hidden_states=transformer_outputs.hidden_states,
980
+ attentions=transformer_outputs.attentions,
981
+ )
982
+
983
+
984
+ if __name__ == "__main__":
985
+ from transformers import LlamaTokenizer
986
+
987
+ model = LlamaForCausalLM.from_pretrained("luodian/llama-7b-hf", device_map="auto")
988
+ tokenizer = LlamaTokenizer.from_pretrained("luodian/llama-7b-hf")
989
+ prompt = "Hey, are you consciours? Can you talk to me?"
990
+ inputs = tokenizer(prompt, return_tensors="pt")
991
+ generate_ids = model.generate(inputs.input_ids, max_length=30)
992
+ print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
src/models_clm/models.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import LlamaForCausalLM, LlamaConfig
4
+ from transformers import LogitsProcessor, LogitsProcessorList
5
+ from .generation import AutoImageTokenGenerationProcessor
6
+ import torch.nn.functional as F
7
+
8
+ BOI_TOKEN = '<img>'
9
+ EOI_TOKEN = '</img>'
10
+ IMG_TOKEN = '<img_{:05d}>'
11
+
12
+
13
+ def cosine_loss(rec, target):
14
+ target = target / target.norm(dim=-1, keepdim=True)
15
+ rec = rec / rec.norm(dim=-1, keepdim=True)
16
+ rec_loss = (1 - (target * rec).sum(-1)).mean()
17
+ return rec_loss
18
+
19
+
20
+ class ContinuousLVLM(nn.Module):
21
+
22
+ def __init__(self, llm, input_resampler, output_resampler, lm_loss_scale=1.0, rec_loss_scale=1.0) -> None:
23
+ super().__init__()
24
+ self.llm = llm
25
+ self.input_resampler = input_resampler
26
+ self.output_resampler = output_resampler
27
+ self.lm_loss_scale = lm_loss_scale
28
+ self.rec_loss_scale = rec_loss_scale
29
+
30
+ # input_resampler.requires_grad_(False)
31
+ # output_resampler.requires_grad_(False)
32
+
33
+ def forward(self, input_ids, attention_mask, labels, image_embeds, embeds_gen_mask, embeds_cmp_mask, ids_gen_mask,
34
+ ids_cmp_mask, return_recon_image_embeds=False):
35
+
36
+ input_embeds = self.llm.get_input_embeddings()(input_ids) # bz x seq_len x dim, 4 x 160 x 4096
37
+
38
+ bz, sq, dim = input_embeds.shape
39
+
40
+ if image_embeds is not None:
41
+ image_embeds_lm = self.input_resampler(image_embeds) # num_imgs_in_batch x nq x dim, 4 x 64 x 4096
42
+ has_image = True
43
+ else:
44
+ image_embeds = torch.randn(bz, self.output_resampler.num_queries,
45
+ self.output_resampler.embed_dim).to(input_embeds.device,
46
+ dtype=input_embeds.dtype)
47
+ image_embeds_lm = self.input_resampler(image_embeds)
48
+ has_image = False
49
+
50
+ has_image_input = has_image and embeds_cmp_mask.sum().item() > 0
51
+ has_image_output = has_image and embeds_gen_mask.sum().item() > 0
52
+
53
+ if has_image_input:
54
+ input_embeds[ids_cmp_mask] = image_embeds_lm[embeds_cmp_mask].view(-1, dim) # eg, 128 x 4096
55
+ # zero_loss = 0.0
56
+ else:
57
+ min_bz = min(input_embeds.shape[0], image_embeds_lm.shape[0])
58
+ input_embeds[:min_bz, :self.input_resampler.
59
+ num_queries, :] = input_embeds[:min_bz, :self.input_resampler.
60
+ num_queries, :] + 0.0 * image_embeds_lm[:min_bz, :, :]
61
+
62
+ output_lm = self.llm(attention_mask=attention_mask,
63
+ inputs_embeds=input_embeds,
64
+ labels=labels,
65
+ output_hidden_states=True,
66
+ return_dict=True)
67
+ lm_loss = output_lm['loss']
68
+
69
+ last_hidden_state = output_lm.hidden_states[-1] # 4 x 160 x 4096
70
+
71
+ if has_image_output:
72
+ target_embeds = image_embeds[embeds_gen_mask] # num_imgs_gen_target x nq_in x dim_in, 2 x 256 x 4096
73
+ num_imgs_for_rec = target_embeds.shape[0]
74
+ output_image_embeds = last_hidden_state[ids_gen_mask].view(num_imgs_for_rec, -1,
75
+ dim) # 128 x 4096 -> 2 x 64 x 4096
76
+
77
+ recon_image_embeds = self.output_resampler(output_image_embeds) # 2 x 256 x 4096
78
+
79
+ rec_loss = cosine_loss(recon_image_embeds, target_embeds)
80
+ else:
81
+ output_image_embeds = torch.randn(bz, self.input_resampler.num_queries,
82
+ self.input_resampler.embed_dim).to(input_embeds.device,
83
+ dtype=input_embeds.dtype)
84
+ recon_image_embeds = self.output_resampler(output_image_embeds)
85
+ target_embeds = torch.randn(bz, self.output_resampler.num_queries,
86
+ self.output_resampler.embed_dim).to(input_embeds.device,
87
+ dtype=input_embeds.dtype)
88
+ rec_loss = cosine_loss(recon_image_embeds, target_embeds) * 0.0
89
+
90
+ total_loss = self.lm_loss_scale * lm_loss + self.rec_loss_scale * rec_loss
91
+
92
+ if return_recon_image_embeds and has_image_output:
93
+ return {'total_loss': total_loss, 'lm_loss': lm_loss, 'rec_loss': rec_loss,
94
+ 'recon_image_embeds': recon_image_embeds}
95
+ else:
96
+ return {'total_loss': total_loss, 'lm_loss': lm_loss, 'rec_loss': rec_loss}
97
+
98
+ def generate(self,
99
+ tokenizer,
100
+ prompt=None,
101
+ input_ids=None,
102
+ image_embeds=None,
103
+ embeds_cmp_mask=None,
104
+ ids_cmp_mask=None,
105
+ logits_processor=None,
106
+ num_img_gen_tokens=64,
107
+ temperature=0.7,
108
+ num_beams=1,
109
+ max_new_tokens=120,
110
+ top_p=0.5,
111
+ past_key_values=None,
112
+ # position_ids=None,
113
+ dtype=torch.float16,
114
+ device='cuda'):
115
+ if logits_processor is None:
116
+ logits_processor = LogitsProcessorList()
117
+ logits_processor.append(
118
+ AutoImageTokenGenerationProcessor(tokenizer=tokenizer, num_img_gen_tokens=num_img_gen_tokens))
119
+
120
+ if prompt is not None:
121
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
122
+
123
+ if isinstance(input_ids, list):
124
+ input_ids = torch.tensor(input_ids)
125
+
126
+ input_ids = input_ids.to(device=device)
127
+ input_embeds = self.llm.get_input_embeddings()(input_ids)
128
+ bz, sq, dim = input_embeds.shape
129
+
130
+ if image_embeds is not None:
131
+ assert embeds_cmp_mask is not None and ids_cmp_mask is not None
132
+ with torch.no_grad():
133
+ image_embeds_lm = self.input_resampler(image_embeds)
134
+
135
+ input_embeds[ids_cmp_mask] = image_embeds_lm[embeds_cmp_mask].view(-1, dim)
136
+
137
+ generation_config = {
138
+ 'temperature': temperature,
139
+ 'num_beams': num_beams,
140
+ 'max_new_tokens': max_new_tokens,
141
+ 'top_p': top_p,
142
+ 'do_sample': False
143
+ }
144
+
145
+ # generate_ids = self.llm.generate(input_ids=input_ids, **generation_config)
146
+ output = self.llm.generate(input_ids=input_ids,
147
+ inputs_embeds=input_embeds,
148
+ output_hidden_states=True,
149
+ return_dict_in_generate=True,
150
+ logits_processor=logits_processor,
151
+ past_key_values=past_key_values,
152
+ # position_ids=position_ids,
153
+ **generation_config)
154
+ # self.llm.base_model.model.position_ids = self.llm.base_model.model.position_ids[:, :-2]
155
+
156
+ output_past_key_values = self.llm.past_key_values
157
+ generate_ids = output.sequences[0][input_ids.shape[1]:]
158
+ generate_id_list = generate_ids.tolist()
159
+ boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
160
+ eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
161
+
162
+ attn_weights = ()
163
+
164
+ def merge_attn_weights(attn_weights):
165
+ merged_attn_weights = attn_weights[0]
166
+
167
+ # Iterate through the remaining attention weight tensors
168
+ for i, attn_weight in enumerate(attn_weights[1:]):
169
+ merged_attn_weights = F.pad(merged_attn_weights, (0, 1), "constant", float('nan'))
170
+ # Concatenate the expanded tensor to the merged tensor along the kv_len dimension
171
+ merged_attn_weights = torch.cat([merged_attn_weights, attn_weight], dim=1)
172
+
173
+ return merged_attn_weights
174
+
175
+ if output.attentions is not None:
176
+ # for idx in [0, 1, 2, 9, 16, 23, 31]:
177
+ for idx in range(32):
178
+ attn_weights += (
179
+ merge_attn_weights([output.attentions[j][idx] for j in range(len(output.attentions))]),)
180
+
181
+ # for skip image multi turn kvcache
182
+ last_hidden_states = torch.cat([hidden_state[-1] for hidden_state in output.hidden_states], dim=1)
183
+ if past_key_values is None:
184
+ last_hidden_states = last_hidden_states[0, input_ids.shape[1]:, :]
185
+ eoi_indices = torch.where(generate_ids == eoi_token_id)[0].tolist()
186
+ else:
187
+ last_hidden_states = last_hidden_states[0, :, :]
188
+ hidden_len = last_hidden_states.shape[0]
189
+ eoi_indices = torch.where(output.sequences[0][-hidden_len:] == eoi_token_id)[0].tolist()
190
+
191
+ num_gen_imgs = 1 if len(eoi_indices) > 0 else 0
192
+
193
+ text_mask = torch.ones_like(generate_ids, dtype=torch.bool)
194
+ has_img_output = num_gen_imgs > 0
195
+ if has_img_output:
196
+ img_gen_feats = []
197
+ img_gen_feats.append(last_hidden_states[eoi_indices[-1] - num_img_gen_tokens:eoi_indices[-1]])
198
+ text_mask[eoi_indices[-1] - num_img_gen_tokens:eoi_indices[-1]] = False
199
+
200
+ # for eoi_idx in eoi_indices:
201
+ # img_gen_feats.append(last_hidden_states[eoi_idx - num_img_gen_tokens:eoi_idx])
202
+ # text_mask[eoi_idx - num_img_gen_tokens:eoi_idx] = False
203
+
204
+ img_gen_feats = torch.stack(img_gen_feats)
205
+ img_gen_feat = self.output_resampler(img_gen_feats)
206
+ else:
207
+ img_gen_feat = None
208
+
209
+ text_mask[generate_ids == boi_token_id] = False
210
+ # generate_ids = generate_ids[text_mask]
211
+ generate_text = tokenizer.decode(generate_ids, skip_special_tokens=False)
212
+
213
+ return {
214
+ 'text': generate_text,
215
+ 'generate_ids': generate_ids,
216
+ 'has_img_output': has_img_output,
217
+ 'img_gen_feat': img_gen_feat,
218
+ 'num_gen_imgs': num_gen_imgs,
219
+ 'attn_weights': attn_weights,
220
+ 'past_key_values': output_past_key_values
221
+ }
222
+
223
+ @classmethod
224
+ def from_pretrained(cls, llm, input_resampler, output_resampler, pretrained_model_path=None, **kwargs):
225
+ model = cls(llm=llm, input_resampler=input_resampler, output_resampler=output_resampler, **kwargs)
226
+ if pretrained_model_path is not None:
227
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
228
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
229
+ print('agent model, missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
230
+ return model
231
+
232
+
233
+ class SEEDLLaMAAlignGeneration(nn.Module):
234
+
235
+ def __init__(self, llm, output_resampler) -> None:
236
+ super().__init__()
237
+
238
+ self.llm = llm
239
+ self.output_resampler = output_resampler
240
+ # self.rec_loss_scale = rec_loss_scale
241
+
242
+ self.llm.requires_grad_(False)
243
+
244
+ def forward(self, input_ids, attention_mask, labels, image_embeds, embeds_gen_mask, embeds_cmp_mask, ids_gen_mask,
245
+ ids_cmp_mask):
246
+
247
+ input_embeds = self.llm.get_input_embeddings()(input_ids) # bz x seq_len x dim, 4 x 160 x 4096
248
+
249
+ bz, sq, dim = input_embeds.shape
250
+
251
+ output_lm = self.llm(attention_mask=attention_mask,
252
+ inputs_embeds=input_embeds,
253
+ labels=labels,
254
+ output_hidden_states=True,
255
+ return_dict=True)
256
+
257
+ last_hidden_state = output_lm.hidden_states[-1] # 4 x 160 x 4096
258
+
259
+ target_embeds = image_embeds[embeds_gen_mask] # num_imgs_gen_target x nq_in x dim_in, 2 x 256 x 4096
260
+ num_imgs_for_rec = target_embeds.shape[0]
261
+ output_image_embeds = last_hidden_state[ids_gen_mask].view(num_imgs_for_rec, -1,
262
+ dim) # 128 x 4096 -> 2 x 64 x 4096
263
+
264
+ recon_image_embeds = self.output_resampler(output_image_embeds) # 2 x 256 x 4096
265
+
266
+ rec_loss = cosine_loss(recon_image_embeds, target_embeds)
267
+
268
+ return {'total_loss': rec_loss, 'rec_loss': rec_loss}
269
+
270
+ @classmethod
271
+ def from_pretrained(cls, llm, output_resampler, pretrained_model_path=None, **kwargs):
272
+ model = cls(llm=llm, output_resampler=output_resampler, **kwargs)
273
+ if pretrained_model_path is not None:
274
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
275
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
276
+ print('agent model, missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
277
+ return model
278
+
279
+ def generate(self,
280
+ tokenizer,
281
+ input_ids=None,
282
+ temperature=0.7,
283
+ num_beams=1,
284
+ max_new_tokens=120,
285
+ num_img_gen_tokens=64,
286
+ top_p=0.5,
287
+ dtype=torch.float16,
288
+ device='cuda'):
289
+ input_ids = input_ids.to(device=device)
290
+ input_embeds = self.llm.get_input_embeddings()(input_ids) # bz x seq_len x dim, 4 x 160 x 4096
291
+
292
+ generation_config = {
293
+ 'temperature': temperature,
294
+ 'num_beams': num_beams,
295
+ 'max_new_tokens': max_new_tokens,
296
+ 'top_p': top_p,
297
+ 'do_sample': False
298
+ }
299
+ output = self.llm.generate(input_ids=input_ids,
300
+ inputs_embeds=input_embeds,
301
+ output_hidden_states=True,
302
+ return_dict_in_generate=True,
303
+ **generation_config)
304
+
305
+ generate_ids = output.sequences[0][input_ids.shape[1]:]
306
+ generate_id_list = generate_ids.tolist()
307
+ # boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
308
+ eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
309
+
310
+ # print('output ids: ', generate_ids, generate_ids.shape)
311
+ # last_hidden_states = output.hidden_states[-1]
312
+
313
+ last_hidden_states = torch.cat([hidden_state[-1] for hidden_state in output.hidden_states],
314
+ dim=1)[:1, input_ids.shape[1]:, :]
315
+
316
+ has_img_output = eoi_token_id in generate_id_list
317
+
318
+ if has_img_output:
319
+ # print(boi_token_id, generate_id_list, generate_id_list.index(boi_token_id))
320
+ # boi_idx = generate_id_list.index(boi_token_id)
321
+ eoi_idx = generate_id_list.index(eoi_token_id)
322
+ print(len(generate_id_list), generate_id_list, eoi_idx)
323
+ # print(generate_id_list[boi_idx + 1:boi_idx + 1 + num_img_gen_tokens])
324
+
325
+ # img_gen_feat = last_hidden_states[:, eoi_idx - num_img_gen_tokens:eoi_idx]
326
+ img_gen_feat = last_hidden_states[:, 0:eoi_idx]
327
+ print('img_gen_feat', img_gen_feat.shape, last_hidden_states.shape, num_img_gen_tokens)
328
+ img_gen_feat = self.output_resampler(img_gen_feat)
329
+
330
+ else:
331
+ img_gen_feat = None
332
+
333
+ generate_text = tokenizer.decode(generate_ids, skip_special_tokens=False)
334
+ # print('output keys: ', output.keys())
335
+
336
+ return {'text': generate_text, 'has_img_output': has_img_output, 'img_gen_feat': img_gen_feat}
src/models_clm/peft_models.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from peft import (
2
+ LoraConfig,
3
+ PeftModel,
4
+ LoraModel,
5
+ PeftModelForCausalLM,
6
+ get_peft_model,
7
+ get_peft_model_state_dict,
8
+ prepare_model_for_int8_training,
9
+ set_peft_model_state_dict,
10
+ )
11
+ from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
12
+ from peft.utils import _set_trainable, PromptLearningConfig
13
+ from peft.utils import PeftConfig
14
+
15
+ import torch
16
+ from transformers import LlamaForCausalLM
17
+ from omegaconf import DictConfig
18
+ import hydra
19
+
20
+
21
+ def get_peft_model_with_resize_embedding(
22
+ model,
23
+ peft_config=None,
24
+ model_id=None,
25
+ vocab_size=None,
26
+ torch_dtype='bf16'
27
+ ):
28
+ if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
29
+ torch_dtype = torch.bfloat16
30
+ elif torch_dtype == 'fp16' or torch_dtype == 'float16':
31
+ torch_dtype = torch.float16
32
+ else:
33
+ torch_dtype = torch.float32
34
+
35
+ if isinstance(model, DictConfig):
36
+ model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)
37
+
38
+ # model.gradient_checkpointing_enable()
39
+
40
+ assert (peft_config is None) + (model_id is None) == 1
41
+
42
+ # print(type(peft_config.target_modules))
43
+ if vocab_size is not None:
44
+ print(f'Length of tokenizer and resize embedding: {vocab_size}')
45
+ model.resize_token_embeddings(vocab_size)
46
+
47
+ if peft_config is not None:
48
+ print('peft config: ', peft_config)
49
+ peft_model = get_peft_model(model=model, peft_config=peft_config)
50
+ peft_model.get_input_embeddings().requires_grad_(True)
51
+ peft_model.get_output_embeddings().requires_grad_(True)
52
+
53
+ peft_model.print_trainable_parameters()
54
+
55
+ # param_count = 0
56
+ # if peft_model.modules_to_save is not None:
57
+ # for name, param in peft_model.named_parameters():
58
+ # if any(module_name in name for module_name in peft_model.modules_to_save):
59
+ # param_count += param.numel()
60
+ # print(name, param.numel())
61
+
62
+ else:
63
+ peft_model = PeftModel.from_pretrained(model=model, model_id=model_id)
64
+
65
+ return peft_model
66
+
67
+
68
+ def get_model_with_resize_embedding(model, vocab_size=None, torch_dtype='bf16'):
69
+ if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
70
+ torch_dtype = torch.bfloat16
71
+ elif torch_dtype == 'fp16' or torch_dtype == 'float16':
72
+ torch_dtype = torch.float16
73
+ else:
74
+ torch_dtype = torch.float32
75
+
76
+ if isinstance(model, DictConfig):
77
+ model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)
78
+
79
+ model.requires_grad_(False)
80
+ if vocab_size is not None:
81
+ print(f'Length of tokenizer and resize embedding: {vocab_size}')
82
+ model.resize_token_embeddings(vocab_size)
83
+ model.get_input_embeddings().requires_grad_(True)
84
+ model.get_output_embeddings().requires_grad_(True)
85
+
86
+ return model
87
+
88
+
89
+ def get_full_model_with_resize_embedding(model, vocab_size=None, torch_dtype='bf16'):
90
+ if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
91
+ torch_dtype = torch.bfloat16
92
+ elif torch_dtype == 'fp16' or torch_dtype == 'float16':
93
+ torch_dtype = torch.float16
94
+ else:
95
+ torch_dtype = torch.float32
96
+
97
+ if isinstance(model, DictConfig):
98
+ model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)
99
+
100
+ if vocab_size is not None:
101
+ print(f'Length of tokenizer and resize embedding: {vocab_size}')
102
+ model.resize_token_embeddings(vocab_size)
103
+
104
+ return model
src/models_ipa/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
src/models_ipa/adapter_modules.py ADDED
@@ -0,0 +1,920 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import itertools
4
+ import torch.nn.functional as F
5
+ from typing import List
6
+ from diffusers import (
7
+ StableDiffusionPipeline,
8
+ StableDiffusionXLPipeline,
9
+ StableDiffusionXLInstructPix2PixPipeline,
10
+ StableDiffusionInstructPix2PixPipeline,
11
+ )
12
+ from PIL import Image
13
+ from .ipa_utils import is_torch2_available
14
+
15
+ if is_torch2_available():
16
+ from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
17
+ else:
18
+ from .attention_processor import IPAttnProcessor, AttnProcessor
19
+
20
+ from diffusers.loaders import LoraLoaderMixin
21
+ from diffusers.models.lora import LoRALinearLayer
22
+ from diffusers.models.unet_2d_blocks import DownBlock2D
23
+
24
+
25
+ # from .pipeline_stable_diffusion_xl_t2i_edit import StableDiffusionXLText2ImageAndEditPipeline
26
+ # from .pipeline_stable_diffusion_t2i_edit import StableDiffusionText2ImageAndEditPipeline
27
+
28
+
29
+ class IPAdapterSD(nn.Module):
30
+
31
+ def __init__(self, unet, resampler) -> None:
32
+ super().__init__()
33
+ self.unet = unet
34
+ self.resampler = resampler
35
+ self.set_ip_adapter()
36
+ self.set_trainable()
37
+
38
+ def set_ip_adapter(self):
39
+ attn_procs = {}
40
+ unet_sd = self.unet.state_dict()
41
+ for name in self.unet.attn_processors.keys():
42
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
43
+ if name.startswith("mid_block"):
44
+ hidden_size = self.unet.config.block_out_channels[-1]
45
+ elif name.startswith("up_blocks"):
46
+ block_id = int(name[len("up_blocks.")])
47
+ hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
48
+ elif name.startswith("down_blocks"):
49
+ block_id = int(name[len("down_blocks.")])
50
+ hidden_size = self.unet.config.block_out_channels[block_id]
51
+ if cross_attention_dim is None:
52
+ attn_procs[name] = AttnProcessor()
53
+ else:
54
+ layer_name = name.split(".processor")[0]
55
+ weights = {
56
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
57
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
58
+ }
59
+ attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
60
+ attn_procs[name].load_state_dict(weights)
61
+ self.unet.set_attn_processor(attn_procs)
62
+ self.adapter = torch.nn.ModuleList(self.unet.attn_processors.values())
63
+
64
+ def set_trainable(self):
65
+ self.unet.requires_grad_(False)
66
+ self.resampler.requires_grad_(True)
67
+ self.adapter.requires_grad_(True)
68
+
69
+ def params_to_opt(self):
70
+ return itertools.chain(self.resampler.parameters(), self.adapter.parameters())
71
+
72
+ def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise):
73
+
74
+ image_embeds = self.resampler(image_embeds)
75
+ # image_embeds = image_embeds.to(dtype=text_embeds.dtype)
76
+
77
+ text_embeds = torch.cat([text_embeds, image_embeds], dim=1)
78
+ # Predict the noise residual and compute loss
79
+ noise_pred = self.unet(noisy_latents, timesteps, text_embeds).sample
80
+
81
+ # if noise is not None:
82
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
83
+ # else:
84
+ # loss = torch.tensor(0.0, device=noisy_latents)
85
+
86
+ return {'total_loss': loss, 'noise_pred': noise_pred}
87
+
88
+ def encode_image_embeds(self, image_embeds):
89
+ dtype = image_embeds.dtype
90
+ image_embeds = self.resampler(image_embeds)
91
+ image_embeds = image_embeds.to(dtype=dtype)
92
+ return image_embeds
93
+
94
+ @classmethod
95
+ def from_pretrained(cls,
96
+ unet,
97
+ resampler,
98
+ pretrained_model_path=None,
99
+ pretrained_resampler_path=None,
100
+ pretrained_adapter_path=None):
101
+ model = cls(unet=unet, resampler=resampler)
102
+ if pretrained_model_path is not None:
103
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
104
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
105
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
106
+ if pretrained_resampler_path is not None:
107
+ ckpt = torch.load(pretrained_resampler_path, map_location='cpu')
108
+ missing, unexpected = model.resampler.load_state_dict(ckpt, strict=True)
109
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
110
+ if pretrained_adapter_path is not None:
111
+ ckpt = torch.load(pretrained_adapter_path, map_location='cpu')
112
+ missing, unexpected = model.adapter.load_state_dict(ckpt, strict=True)
113
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
114
+ return model
115
+
116
+ @classmethod
117
+ def from_pretrained_legacy(cls, unet, resampler, pretrained_model_path=None):
118
+ model = cls(unet=unet, resampler=resampler)
119
+ if pretrained_model_path is not None:
120
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
121
+ ckpt_image_proj = {}
122
+ ckpt_ip_layers = {}
123
+
124
+ for key, value in ckpt.items():
125
+ if key.startswith('image_proj_model'):
126
+ new_key = key.replace('image_proj_model.', '')
127
+ ckpt_image_proj[new_key] = value
128
+ elif key.startswith('adapter_modules.'):
129
+ new_key = key.replace('adapter_modules.', '')
130
+ ckpt_ip_layers[new_key] = value
131
+
132
+ missing, unexpected = model.resampler.load_state_dict(ckpt_image_proj, strict=True)
133
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
134
+ missing, unexpected = model.adapter.load_state_dict(ckpt_ip_layers, strict=True)
135
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
136
+
137
+ return model
138
+
139
+
140
+ class IPAdapterSDPipe(nn.Module):
141
+
142
+ def __init__(
143
+ self,
144
+ ip_adapter,
145
+ discrete_model,
146
+ vae,
147
+ visual_encoder,
148
+ text_encoder,
149
+ tokenizer,
150
+ scheduler,
151
+ image_transform,
152
+ device,
153
+ dtype,
154
+ ) -> None:
155
+ super().__init__()
156
+
157
+ self.ip_adapter = ip_adapter
158
+ self.vae = vae
159
+ self.visual_encoder = visual_encoder
160
+ self.text_encoder = text_encoder
161
+ self.tokenizer = tokenizer
162
+ self.scheduler = scheduler
163
+ self.image_transform = image_transform
164
+ self.discrete_model = discrete_model
165
+ self.device = device
166
+ self.dtype = dtype
167
+
168
+ self.sd_pipe = StableDiffusionPipeline(vae=vae,
169
+ text_encoder=text_encoder,
170
+ tokenizer=tokenizer,
171
+ unet=ip_adapter.unet,
172
+ scheduler=scheduler,
173
+ safety_checker=None,
174
+ feature_extractor=None,
175
+ requires_safety_checker=False)
176
+
177
+ def set_scale(self, scale):
178
+ for attn_processor in self.sd_pipe.unet.attn_processors.values():
179
+ if isinstance(attn_processor, IPAttnProcessor):
180
+ attn_processor.scale = scale
181
+
182
+ @torch.inference_mode()
183
+ def get_image_embeds(self, image_pil=None, image_tensor=None, return_negative=True):
184
+ assert int(image_pil is not None) + int(image_tensor is not None) == 1
185
+ if image_pil is not None:
186
+ image_tensor = self.image_transform(image_pil).unsqueeze(0).to(self.device, dtype=self.dtype)
187
+ if return_negative:
188
+ image_tensor_neg = torch.zeros_like(image_tensor)
189
+ image_tensor = torch.cat([image_tensor, image_tensor_neg], dim=0)
190
+ with torch.cuda.amp.autocast(dtype=self.dtype):
191
+ image_embeds = self.visual_encoder(image_tensor)
192
+ image_embeds = self.discrete_model.encode_image_embeds(image_embeds)
193
+ image_embeds = self.ip_adapter.encode_image_embeds(image_embeds)
194
+
195
+ if return_negative:
196
+ # bz = image_embeds.shape[0]
197
+ # image_embeds_neg = image_embeds[bz // 2:]
198
+ # image_embeds = image_embeds[0:bz // 2]
199
+ image_embeds, image_embeds_neg = image_embeds.chunk(2)
200
+ else:
201
+ image_embeds_neg = None
202
+
203
+ return image_embeds, image_embeds_neg
204
+
205
+ def generate(self,
206
+ image_pil=None,
207
+ image_tensor=None,
208
+ prompt=None,
209
+ negative_prompt=None,
210
+ scale=1.0,
211
+ num_samples=1,
212
+ seed=42,
213
+ guidance_scale=7.5,
214
+ num_inference_steps=30,
215
+ **kwargs):
216
+ self.set_scale(scale)
217
+ assert int(image_pil is not None) + int(image_tensor is not None) == 1
218
+
219
+ if image_pil is not None:
220
+ assert isinstance(image_pil, Image.Image)
221
+ num_prompts = 1
222
+ else:
223
+ num_prompts = image_tensor.shape[0]
224
+
225
+ if prompt is None:
226
+ # prompt = "best quality, high quality"
227
+ prompt = ""
228
+ if negative_prompt is None:
229
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
230
+
231
+ if not isinstance(prompt, List):
232
+ prompt = [prompt] * num_prompts
233
+ if not isinstance(negative_prompt, List):
234
+ negative_prompt = [negative_prompt] * num_prompts
235
+
236
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
237
+ image_pil=image_pil,
238
+ image_tensor=image_tensor,
239
+ return_negative=True,
240
+ )
241
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
242
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
243
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
244
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
245
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
246
+
247
+ with torch.inference_mode():
248
+ prompt_embeds, negative_prompt_embeds = self.sd_pipe.encode_prompt(
249
+ prompt,
250
+ device=self.device,
251
+ num_images_per_prompt=num_samples,
252
+ do_classifier_free_guidance=True,
253
+ negative_prompt=negative_prompt,
254
+ )
255
+
256
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
257
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
258
+
259
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
260
+ images = self.sd_pipe(
261
+ prompt_embeds=prompt_embeds,
262
+ negative_prompt_embeds=negative_prompt_embeds,
263
+ guidance_scale=guidance_scale,
264
+ num_inference_steps=num_inference_steps,
265
+ generator=generator,
266
+ **kwargs,
267
+ ).images
268
+
269
+ return images
270
+
271
+
272
+ def compute_time_ids(original_size, crops_coords_top_left, target_resolution):
273
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
274
+ target_size = (target_resolution, target_resolution)
275
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
276
+ add_time_ids = torch.tensor([add_time_ids])
277
+ # add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
278
+ return add_time_ids
279
+
280
+
281
+ class SDXLAdapter(nn.Module):
282
+
283
+ def __init__(self, unet, resampler, full_ft=False) -> None:
284
+ super().__init__()
285
+ self.unet = unet
286
+ self.resampler = resampler
287
+ self.full_ft = full_ft
288
+ self.set_trainable_v2()
289
+ # self.set_adapter()
290
+
291
+ # self.set_trainable()
292
+
293
+ # def set_adapter(self):
294
+
295
+ # adapter = []
296
+ # for name, module in self.unet.named_modules():
297
+ # if name.endswith('to_k') or name.endswith('to_v'):
298
+ # if module is not None:
299
+ # adapter.append(module)
300
+
301
+ # self.adapter = torch.nn.ModuleList(adapter)
302
+ # print(f'adapter: {self.adapter}')
303
+
304
+ # def set_trainable(self):
305
+ # self.unet.requires_grad_(False)
306
+ # self.resampler.requires_grad_(True)
307
+ # self.adapter.requires_grad_(True)
308
+
309
+ def set_trainable_v2(self):
310
+ self.resampler.requires_grad_(True)
311
+ adapter_parameters = []
312
+ if self.full_ft:
313
+ self.unet.requires_grad_(True)
314
+ adapter_parameters.extend(self.unet.parameters())
315
+ else:
316
+ self.unet.requires_grad_(False)
317
+ for name, module in self.unet.named_modules():
318
+ if name.endswith('to_k') or name.endswith('to_v'):
319
+ if module is not None:
320
+ adapter_parameters.extend(module.parameters())
321
+ self.adapter_parameters = adapter_parameters
322
+ for param in self.adapter_parameters:
323
+ param.requires_grad_(True)
324
+
325
+ # def params_to_opt(self):
326
+ # return itertools.chain(self.resampler.parameters(), self.adapter.parameters())
327
+ def params_to_opt(self):
328
+ return itertools.chain(self.resampler.parameters(), self.adapter_parameters)
329
+
330
+ def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise, time_ids):
331
+
332
+ image_embeds, pooled_image_embeds = self.resampler(image_embeds)
333
+
334
+ unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_image_embeds}
335
+
336
+ noise_pred = self.unet(noisy_latents, timesteps, image_embeds, added_cond_kwargs=unet_added_conditions).sample
337
+
338
+ # if noise is not None:
339
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
340
+ # else:
341
+ # loss = torch.tensor(0.0, device=noisy_latents)
342
+
343
+ return {'total_loss': loss, 'noise_pred': noise_pred}
344
+
345
+ def encode_image_embeds(self, image_embeds):
346
+ image_embeds, pooled_image_embeds = self.resampler(image_embeds)
347
+
348
+ return image_embeds, pooled_image_embeds
349
+
350
+ @classmethod
351
+ def from_pretrained(cls, unet, resampler, pretrained_model_path=None, **kwargs):
352
+ model = cls(unet=unet, resampler=resampler, **kwargs)
353
+ if pretrained_model_path is not None:
354
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
355
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
356
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
357
+ return model
358
+
359
+ def init_pipe(self,
360
+ vae,
361
+ scheduler,
362
+ visual_encoder,
363
+ image_transform,
364
+ discrete_model=None,
365
+ dtype=torch.float16,
366
+ device='cuda'):
367
+ self.device = device
368
+ self.dtype = dtype
369
+ sdxl_pipe = StableDiffusionXLPipeline(tokenizer=None,
370
+ tokenizer_2=None,
371
+ text_encoder=None,
372
+ text_encoder_2=None,
373
+ vae=vae,
374
+ unet=self.unet,
375
+ scheduler=scheduler)
376
+
377
+ self.sdxl_pipe = sdxl_pipe # .to(self.device, dtype=self.dtype)
378
+ # print(sdxl_pipe.text_encoder_2, sdxl_pipe.text_encoder)
379
+
380
+ self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype)
381
+ if discrete_model is not None:
382
+ self.discrete_model = discrete_model.to(self.device, dtype=self.dtype)
383
+ else:
384
+ self.discrete_model = None
385
+ self.image_transform = image_transform
386
+
387
+ @torch.inference_mode()
388
+ def get_image_embeds(self,
389
+ image_pil=None,
390
+ image_tensor=None,
391
+ image_embeds=None,
392
+ return_negative=True,
393
+ image_size=448
394
+ ):
395
+ assert int(image_pil is not None) + int(image_tensor is not None) + int(image_embeds is not None) == 1
396
+
397
+ if image_pil is not None:
398
+ image_tensor = self.image_transform(image_pil).unsqueeze(0).to(self.device, dtype=self.dtype)
399
+
400
+ if image_tensor is not None:
401
+ if return_negative:
402
+ image_tensor_neg = torch.zeros_like(image_tensor)
403
+ image_tensor = torch.cat([image_tensor, image_tensor_neg], dim=0)
404
+
405
+ image_embeds = self.visual_encoder(image_tensor)
406
+ elif return_negative:
407
+ image_tensor_neg = torch.zeros(
408
+ 1, 3,
409
+ image_size, image_size
410
+ ).to(
411
+ image_embeds.device, dtype=image_embeds.dtype
412
+ )
413
+ image_embeds_neg = self.visual_encoder(image_tensor_neg)
414
+ image_embeds = torch.cat([image_embeds, image_embeds_neg], dim=0)
415
+
416
+ if self.discrete_model is not None:
417
+ image_embeds = self.discrete_model.encode_image_embeds(image_embeds)
418
+ image_embeds, pooled_image_embeds = self.encode_image_embeds(image_embeds)
419
+
420
+ if return_negative:
421
+ image_embeds, image_embeds_neg = image_embeds.chunk(2)
422
+ pooled_image_embeds, pooled_image_embeds_neg = pooled_image_embeds.chunk(2)
423
+
424
+ else:
425
+ image_embeds_neg = None
426
+ pooled_image_embeds_neg = None
427
+
428
+ return image_embeds, image_embeds_neg, pooled_image_embeds, pooled_image_embeds_neg
429
+
430
+ def generate(self,
431
+ image_pil=None,
432
+ image_tensor=None,
433
+ image_embeds=None,
434
+ seed=42,
435
+ height=1024,
436
+ width=1024,
437
+ guidance_scale=7.5,
438
+ num_inference_steps=30,
439
+ input_image_size=448,
440
+ **kwargs):
441
+ if image_pil is not None:
442
+ assert isinstance(image_pil, Image.Image)
443
+
444
+ image_prompt_embeds, uncond_image_prompt_embeds, pooled_image_prompt_embeds, \
445
+ pooled_uncond_image_prompt_embeds = self.get_image_embeds(
446
+ image_pil=image_pil,
447
+ image_tensor=image_tensor,
448
+ image_embeds=image_embeds,
449
+ return_negative=True,
450
+ image_size=input_image_size,
451
+ )
452
+ # print(image_prompt_embeds.shape, pooled_image_prompt_embeds.shape)
453
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
454
+
455
+ images = self.sdxl_pipe(
456
+ prompt_embeds=image_prompt_embeds,
457
+ negative_prompt_embeds=uncond_image_prompt_embeds,
458
+ pooled_prompt_embeds=pooled_image_prompt_embeds,
459
+ negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds,
460
+ guidance_scale=guidance_scale,
461
+ num_inference_steps=num_inference_steps,
462
+ generator=generator,
463
+ height=height,
464
+ width=width,
465
+ **kwargs,
466
+ ).images
467
+
468
+ return images
469
+
470
+
471
+ class SDXLText2ImageAndEditAdapter(nn.Module):
472
+
473
+ def __init__(self, unet, resampler, lora_rank=16, fully_ft=False) -> None:
474
+ super().__init__()
475
+
476
+ self.unet = unet
477
+ self.resampler = resampler
478
+ self.lora_rank = lora_rank
479
+
480
+ if fully_ft:
481
+ self.set_fully_trainable()
482
+ else:
483
+ self.set_adapter()
484
+
485
+ def set_adapter(self):
486
+ self.unet.requires_grad_(False)
487
+ adapter_parameters = []
488
+
489
+ in_channels = 8
490
+ out_channels = self.unet.conv_in.out_channels
491
+ self.unet.register_to_config(in_channels=in_channels)
492
+
493
+ with torch.no_grad():
494
+ new_conv_in = nn.Conv2d(in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride,
495
+ self.unet.conv_in.padding)
496
+
497
+ new_conv_in.weight.zero_()
498
+ new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
499
+ self.unet.conv_in = new_conv_in
500
+ self.unet.conv_in.requires_grad_(True)
501
+ print('Make conv_in trainable.')
502
+ adapter_parameters.extend(self.unet.conv_in.parameters())
503
+
504
+ for name, module in self.unet.named_modules():
505
+ if isinstance(module, DownBlock2D):
506
+ module.requires_grad_(True)
507
+ adapter_parameters.extend(module.parameters())
508
+ print('Make DownBlock2D trainable.')
509
+
510
+ for attn_processor_name, attn_processor in self.unet.attn_processors.items():
511
+ # Parse the attention module.
512
+ attn_module = self.unet
513
+ for n in attn_processor_name.split(".")[:-1]:
514
+ attn_module = getattr(attn_module, n)
515
+
516
+ # Set the `lora_layer` attribute of the attention-related matrices.
517
+ attn_module.to_q.set_lora_layer(
518
+ LoRALinearLayer(in_features=attn_module.to_q.in_features,
519
+ out_features=attn_module.to_q.out_features,
520
+ rank=self.lora_rank))
521
+ # attn_module.to_k.set_lora_layer(
522
+ # LoRALinearLayer(in_features=attn_module.to_k.in_features,
523
+ # out_features=attn_module.to_k.out_features,
524
+ # rank=self.lora_rank))
525
+ # attn_module.to_v.set_lora_layer(
526
+ # LoRALinearLayer(in_features=attn_module.to_v.in_features,
527
+ # out_features=attn_module.to_v.out_features,
528
+ # rank=self.lora_rank))
529
+ attn_module.to_out[0].set_lora_layer(
530
+ LoRALinearLayer(
531
+ in_features=attn_module.to_out[0].in_features,
532
+ out_features=attn_module.to_out[0].out_features,
533
+ rank=self.lora_rank,
534
+ ))
535
+
536
+ attn_module.to_k.requires_grad_(True)
537
+ attn_module.to_v.requires_grad_(True)
538
+
539
+ adapter_parameters.extend(attn_module.to_q.lora_layer.parameters())
540
+ adapter_parameters.extend(attn_module.to_k.parameters())
541
+ adapter_parameters.extend(attn_module.to_v.parameters())
542
+ adapter_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
543
+
544
+ self.adapter_parameters = adapter_parameters
545
+
546
+ def set_fully_trainable(self):
547
+
548
+ in_channels = 8
549
+ out_channels = self.unet.conv_in.out_channels
550
+ self.unet.register_to_config(in_channels=in_channels)
551
+ with torch.no_grad():
552
+ new_conv_in = nn.Conv2d(in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride,
553
+ self.unet.conv_in.padding)
554
+
555
+ new_conv_in.weight.zero_()
556
+ new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
557
+ self.unet.conv_in = new_conv_in
558
+
559
+ self.unet.requires_grad_(True)
560
+ self.adapter_parameters = self.unet.parameters()
561
+
562
+ def params_to_opt(self):
563
+ return itertools.chain(self.resampler.parameters(), self.adapter_parameters)
564
+
565
+ def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise, time_ids, pooled_text_embeds=None):
566
+
567
+ text_embeds, pooled_text_embeds = self.resampler(text_embeds, pooled_text_embeds=pooled_text_embeds)
568
+ unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_text_embeds}
569
+
570
+ noise_pred = self.unet(noisy_latents, timesteps, text_embeds, added_cond_kwargs=unet_added_conditions).sample
571
+
572
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
573
+ return {'total_loss': loss, 'noise_pred': noise_pred}
574
+
575
+ def encode_text_embeds(self, text_embeds, pooled_text_embeds=None):
576
+ text_embeds, pooled_text_embeds = self.resampler(text_embeds, pooled_text_embeds=pooled_text_embeds)
577
+
578
+ return text_embeds, pooled_text_embeds
579
+
580
+ @classmethod
581
+ def from_pretrained(cls, unet, resampler, pretrained_model_path=None, **kwargs):
582
+ model = cls(unet=unet, resampler=resampler, **kwargs)
583
+ if pretrained_model_path is not None:
584
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
585
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
586
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
587
+ return model
588
+
589
+ def init_pipe(self,
590
+ vae,
591
+ scheduler,
592
+ text_encoder,
593
+ text_encoder_2,
594
+ tokenizer,
595
+ tokenizer_2,
596
+ dtype=torch.float16,
597
+ device='cuda'):
598
+ self.device = device
599
+ self.dtype = dtype
600
+
601
+ sdxl_pipe = StableDiffusionXLText2ImageAndEditPipeline(
602
+ tokenizer=None,
603
+ tokenizer_2=None,
604
+ text_encoder=None,
605
+ text_encoder_2=None,
606
+ vae=vae,
607
+ unet=self.unet,
608
+ scheduler=scheduler,
609
+ )
610
+
611
+ self.sdxl_pipe = sdxl_pipe
612
+ self.sdxl_pipe.to(device, dtype=dtype)
613
+
614
+ self.tokenizer = tokenizer
615
+ self.tokenizer_2 = tokenizer_2
616
+ self.text_encoder = text_encoder
617
+ self.text_encoder_2 = text_encoder_2
618
+
619
+ @torch.inference_mode()
620
+ def get_text_embeds(self, prompt=None, negative_prompt='', text_embeds=None):
621
+ assert int(prompt is not None) + int(text_embeds is not None) == 1
622
+
623
+ if prompt is not None:
624
+ text_input_ids = self.tokenizer([prompt, negative_prompt],
625
+ max_length=self.tokenizer.model_max_length,
626
+ padding="max_length",
627
+ truncation=True,
628
+ return_tensors="pt").input_ids
629
+ text_input_ids_2 = self.tokenizer_2([prompt, negative_prompt],
630
+ max_length=self.tokenizer.model_max_length,
631
+ padding="max_length",
632
+ truncation=True,
633
+ return_tensors="pt").input_ids
634
+ encoder_output = self.text_encoder(text_input_ids.to(self.device), output_hidden_states=True)
635
+ text_embeds = encoder_output.hidden_states[-2]
636
+
637
+ encoder_output_2 = self.text_encoder_2(text_input_ids_2.to(self.device), output_hidden_states=True)
638
+ pooled_text_embeds = encoder_output_2[0]
639
+ text_embeds_2 = encoder_output_2.hidden_states[-2]
640
+
641
+ text_embeds = torch.cat([text_embeds, text_embeds_2], dim=-1)
642
+ else:
643
+ text_input_ids = self.tokenizer(negative_prompt,
644
+ max_length=self.tokenizer.model_max_length,
645
+ padding="max_length",
646
+ truncation=True,
647
+ return_tensors="pt").input_ids
648
+ text_input_ids_2 = self.tokenizer_2(negative_prompt,
649
+ max_length=self.tokenizer.model_max_length,
650
+ padding="max_length",
651
+ truncation=True,
652
+ return_tensors="pt").input_ids
653
+ encoder_output = self.text_encoder(text_input_ids.to(self.device), output_hidden_states=True)
654
+ text_embeds_neg = encoder_output.hidden_states[-2]
655
+
656
+ encoder_output_2 = self.text_encoder_2(text_input_ids_2.to(self.device), output_hidden_states=True)
657
+ text_embeds_neg_2 = encoder_output_2.hidden_states[-2]
658
+ pooled_text_embeds = encoder_output_2[0]
659
+
660
+ text_embeds_neg = torch.cat([text_embeds_neg, text_embeds_neg_2], dim=-1)
661
+
662
+ text_embeds = torch.cat([text_embeds, text_embeds_neg], dim=0)
663
+
664
+ text_embeds, pooled_text_embeds = self.encode_text_embeds(text_embeds, pooled_text_embeds=pooled_text_embeds)
665
+ text_embeds, text_embeds_neg = text_embeds.chunk(2)
666
+ pooled_text_embeds, pooled_text_embeds_neg = pooled_text_embeds.chunk(2)
667
+
668
+ return text_embeds, text_embeds_neg, pooled_text_embeds, pooled_text_embeds_neg
669
+
670
+ def generate(self,
671
+ prompt=None,
672
+ negative_prompt='',
673
+ image=None,
674
+ text_embeds=None,
675
+ seed=42,
676
+ height=1024,
677
+ width=1024,
678
+ guidance_scale=7.5,
679
+ num_inference_steps=30,
680
+ **kwargs):
681
+
682
+ text_embeds, text_embeds_neg, pooled_text_embeds, pooled_text_embeds_neg = self.get_text_embeds(
683
+ prompt=prompt, negative_prompt=negative_prompt, text_embeds=text_embeds)
684
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
685
+
686
+ images = self.sdxl_pipe(
687
+ image=image,
688
+ prompt_embeds=text_embeds,
689
+ negative_prompt_embeds=text_embeds_neg,
690
+ pooled_prompt_embeds=pooled_text_embeds,
691
+ negative_pooled_prompt_embeds=pooled_text_embeds_neg,
692
+ guidance_scale=guidance_scale,
693
+ num_inference_steps=num_inference_steps,
694
+ generator=generator,
695
+ height=height,
696
+ width=width,
697
+ **kwargs,
698
+ ).images
699
+
700
+ return images
701
+
702
+
703
+ class SD21Text2ImageAndEditAdapter(SDXLText2ImageAndEditAdapter):
704
+
705
+ def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise):
706
+
707
+ text_embeds, _ = self.resampler(text_embeds)
708
+ # unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_text_embeds}
709
+
710
+ noise_pred = self.unet(noisy_latents, timesteps, text_embeds).sample
711
+
712
+ loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
713
+ return {'total_loss': loss, 'noise_pred': noise_pred}
714
+
715
+ def init_pipe(self,
716
+ vae,
717
+ scheduler,
718
+ text_encoder,
719
+ tokenizer,
720
+ feature_extractor,
721
+ dtype=torch.float16,
722
+ device='cuda'):
723
+ self.device = device
724
+ self.dtype = dtype
725
+
726
+ sd_pipe = StableDiffusionText2ImageAndEditPipeline(
727
+ tokenizer=tokenizer,
728
+ text_encoder=text_encoder,
729
+ vae=vae,
730
+ unet=self.unet,
731
+ feature_extractor=feature_extractor,
732
+ safety_checker=None,
733
+ requires_safety_checker=False,
734
+ scheduler=scheduler,
735
+ )
736
+
737
+ self.sd_pipe = sd_pipe
738
+ self.sd_pipe.to(device, dtype=dtype)
739
+
740
+ self.tokenizer = tokenizer
741
+ self.text_encoder = text_encoder
742
+
743
+ @torch.inference_mode()
744
+ def get_text_embeds(self, prompt=None, negative_prompt='', text_embeds=None):
745
+ assert int(prompt is not None) + int(text_embeds is not None) == 1
746
+
747
+ if prompt is not None:
748
+ text_input_ids = self.tokenizer([prompt, negative_prompt],
749
+ max_length=self.tokenizer.model_max_length,
750
+ padding="max_length",
751
+ truncation=True,
752
+ return_tensors="pt").input_ids
753
+ encoder_output = self.text_encoder(text_input_ids.to(self.device))
754
+ text_embeds = encoder_output[0]
755
+
756
+ else:
757
+ text_input_ids = self.tokenizer(negative_prompt,
758
+ max_length=self.tokenizer.model_max_length,
759
+ padding="max_length",
760
+ truncation=True,
761
+ return_tensors="pt").input_ids
762
+ encoder_output = self.text_encoder(text_input_ids.to(self.device))
763
+ text_embeds_neg = encoder_output[0]
764
+
765
+ text_embeds = torch.cat([text_embeds, text_embeds_neg], dim=0)
766
+
767
+ text_embeds, _ = self.encode_text_embeds(text_embeds)
768
+ text_embeds, text_embeds_neg = text_embeds.chunk(2)
769
+
770
+ return text_embeds, text_embeds_neg
771
+
772
+ def generate(self,
773
+ prompt=None,
774
+ negative_prompt='',
775
+ image=None,
776
+ text_embeds=None,
777
+ seed=42,
778
+ height=1024,
779
+ width=1024,
780
+ guidance_scale=7.5,
781
+ num_inference_steps=30,
782
+ **kwargs):
783
+
784
+ text_embeds, text_embeds_neg = self.get_text_embeds(
785
+ prompt=prompt, negative_prompt=negative_prompt, text_embeds=text_embeds)
786
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
787
+
788
+ print(f'text_embeds: {text_embeds.shape}')
789
+ print(f'text_embeds_neg: {text_embeds_neg.shape}')
790
+ images = self.sd_pipe(
791
+ image=image,
792
+ prompt_embeds=text_embeds,
793
+ negative_prompt_embeds=text_embeds_neg,
794
+ guidance_scale=guidance_scale,
795
+ num_inference_steps=num_inference_steps,
796
+ generator=generator,
797
+ height=height,
798
+ width=width,
799
+ **kwargs,
800
+ ).images
801
+
802
+ return images
803
+
804
+
805
+ class SDXLAdapterWithLatentImage(SDXLAdapter):
806
+ def __init__(self, unet, resampler, full_ft=False, set_trainable_late=False) -> None:
807
+ nn.Module.__init__(self)
808
+ self.unet = unet
809
+ self.resampler = resampler
810
+ self.full_ft = full_ft
811
+ if not set_trainable_late:
812
+ self.set_trainable()
813
+
814
+ def set_trainable(self):
815
+ self.resampler.requires_grad_(True)
816
+ adapter_parameters = []
817
+
818
+ in_channels = 8
819
+ out_channels = self.unet.conv_in.out_channels
820
+ self.unet.register_to_config(in_channels=in_channels)
821
+ self.unet.requires_grad_(False)
822
+ with torch.no_grad():
823
+ new_conv_in = nn.Conv2d(in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride,
824
+ self.unet.conv_in.padding)
825
+
826
+ new_conv_in.weight.zero_()
827
+ new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
828
+ self.unet.conv_in = new_conv_in
829
+ self.unet.conv_in.requires_grad_(True)
830
+
831
+ if self.full_ft:
832
+ self.unet.requires_grad_(True)
833
+ adapter_parameters.extend(self.unet.parameters())
834
+ else:
835
+ adapter_parameters.extend(self.unet.conv_in.parameters())
836
+ for name, module in self.unet.named_modules():
837
+ if name.endswith('to_k') or name.endswith('to_v'):
838
+ if module is not None:
839
+ adapter_parameters.extend(module.parameters())
840
+ self.adapter_parameters = adapter_parameters
841
+
842
+ @classmethod
843
+ def from_pretrained(cls, unet, resampler, pretrained_model_path=None, set_trainable_late=False, **kwargs):
844
+ model = cls(unet=unet, resampler=resampler, set_trainable_late=set_trainable_late, **kwargs)
845
+ if pretrained_model_path is not None:
846
+ ckpt = torch.load(pretrained_model_path, map_location='cpu')
847
+ missing, unexpected = model.load_state_dict(ckpt, strict=False)
848
+ print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
849
+ if set_trainable_late:
850
+ model.set_trainable()
851
+ return model
852
+
853
+ def init_pipe(self,
854
+ vae,
855
+ scheduler,
856
+ visual_encoder,
857
+ image_transform,
858
+ dtype=torch.float16,
859
+ device='cuda'):
860
+ self.device = device
861
+ self.dtype = dtype
862
+
863
+ sdxl_pipe = StableDiffusionXLText2ImageAndEditPipeline(
864
+ tokenizer=None,
865
+ tokenizer_2=None,
866
+ text_encoder=None,
867
+ text_encoder_2=None,
868
+ vae=vae,
869
+ unet=self.unet,
870
+ scheduler=scheduler,
871
+ )
872
+
873
+ self.sdxl_pipe = sdxl_pipe
874
+ self.sdxl_pipe.to(device, dtype=dtype)
875
+ self.discrete_model = None
876
+
877
+ self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype)
878
+ self.image_transform = image_transform
879
+
880
+ def generate(self,
881
+ image_pil=None,
882
+ image_tensor=None,
883
+ image_embeds=None,
884
+ latent_image=None,
885
+ seed=42,
886
+ height=1024,
887
+ width=1024,
888
+ guidance_scale=7.5,
889
+ num_inference_steps=30,
890
+ input_image_size=448,
891
+ **kwargs):
892
+ if image_pil is not None:
893
+ assert isinstance(image_pil, Image.Image)
894
+
895
+ image_prompt_embeds, uncond_image_prompt_embeds, \
896
+ pooled_image_prompt_embeds, pooled_uncond_image_prompt_embeds = self.get_image_embeds(
897
+ image_pil=image_pil,
898
+ image_tensor=image_tensor,
899
+ image_embeds=image_embeds,
900
+ return_negative=True,
901
+ image_size=input_image_size,
902
+ )
903
+ # print(image_prompt_embeds.shape, pooled_image_prompt_embeds.shape)
904
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
905
+
906
+ images = self.sdxl_pipe(
907
+ image=latent_image,
908
+ prompt_embeds=image_prompt_embeds,
909
+ negative_prompt_embeds=uncond_image_prompt_embeds,
910
+ pooled_prompt_embeds=pooled_image_prompt_embeds,
911
+ negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds,
912
+ guidance_scale=guidance_scale,
913
+ num_inference_steps=num_inference_steps,
914
+ generator=generator,
915
+ height=height,
916
+ width=width,
917
+ **kwargs,
918
+ ).images
919
+
920
+ return images
src/models_ipa/attention_processor.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class AttnProcessor(nn.Module):
8
+ r"""
9
+ Default processor for performing attention-related computations.
10
+ """
11
+
12
+ def __init__(
13
+ self,
14
+ hidden_size=None,
15
+ cross_attention_dim=None,
16
+ ):
17
+ super().__init__()
18
+
19
+ def __call__(
20
+ self,
21
+ attn,
22
+ hidden_states,
23
+ encoder_hidden_states=None,
24
+ attention_mask=None,
25
+ temb=None,
26
+ ):
27
+ residual = hidden_states
28
+
29
+ if attn.spatial_norm is not None:
30
+ hidden_states = attn.spatial_norm(hidden_states, temb)
31
+
32
+ input_ndim = hidden_states.ndim
33
+
34
+ if input_ndim == 4:
35
+ batch_size, channel, height, width = hidden_states.shape
36
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
37
+
38
+ batch_size, sequence_length, _ = (
39
+ hidden_states.shape
40
+ if encoder_hidden_states is None
41
+ else encoder_hidden_states.shape
42
+ )
43
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
44
+
45
+ if attn.group_norm is not None:
46
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
47
+
48
+ query = attn.to_q(hidden_states)
49
+
50
+ if encoder_hidden_states is None:
51
+ encoder_hidden_states = hidden_states
52
+ elif attn.norm_cross:
53
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
54
+
55
+ key = attn.to_k(encoder_hidden_states)
56
+ value = attn.to_v(encoder_hidden_states)
57
+
58
+ query = attn.head_to_batch_dim(query)
59
+ key = attn.head_to_batch_dim(key)
60
+ value = attn.head_to_batch_dim(value)
61
+
62
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
63
+ hidden_states = torch.bmm(attention_probs, value)
64
+ hidden_states = attn.batch_to_head_dim(hidden_states)
65
+
66
+ # linear proj
67
+ hidden_states = attn.to_out[0](hidden_states)
68
+ # dropout
69
+ hidden_states = attn.to_out[1](hidden_states)
70
+
71
+ if input_ndim == 4:
72
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
73
+
74
+ if attn.residual_connection:
75
+ hidden_states = hidden_states + residual
76
+
77
+ hidden_states = hidden_states / attn.rescale_output_factor
78
+
79
+ return hidden_states
80
+
81
+
82
+ class IPAttnProcessor(nn.Module):
83
+ r"""
84
+ Attention processor for IP-Adapater.
85
+ Args:
86
+ hidden_size (`int`):
87
+ The hidden size of the attention layer.
88
+ cross_attention_dim (`int`):
89
+ The number of channels in the `encoder_hidden_states`.
90
+ text_context_len (`int`, defaults to 77):
91
+ The context length of the text features.
92
+ scale (`float`, defaults to 1.0):
93
+ the weight scale of image prompt.
94
+ """
95
+
96
+ def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
97
+ super().__init__()
98
+
99
+ self.hidden_size = hidden_size
100
+ self.cross_attention_dim = cross_attention_dim
101
+ self.text_context_len = text_context_len
102
+ self.scale = scale
103
+
104
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
105
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
106
+
107
+ def __call__(
108
+ self,
109
+ attn,
110
+ hidden_states,
111
+ encoder_hidden_states=None,
112
+ attention_mask=None,
113
+ temb=None,
114
+ ):
115
+ residual = hidden_states
116
+
117
+ if attn.spatial_norm is not None:
118
+ hidden_states = attn.spatial_norm(hidden_states, temb)
119
+
120
+ input_ndim = hidden_states.ndim
121
+
122
+ if input_ndim == 4:
123
+ batch_size, channel, height, width = hidden_states.shape
124
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
125
+
126
+ batch_size, sequence_length, _ = (
127
+ hidden_states.shape
128
+ if encoder_hidden_states is None
129
+ else encoder_hidden_states.shape
130
+ )
131
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
132
+
133
+ if attn.group_norm is not None:
134
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
135
+
136
+ query = attn.to_q(hidden_states)
137
+
138
+ if encoder_hidden_states is None:
139
+ encoder_hidden_states = hidden_states
140
+ elif attn.norm_cross:
141
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
142
+
143
+ # split hidden states
144
+ encoder_hidden_states, \
145
+ ip_hidden_states = \
146
+ encoder_hidden_states[:, :self.text_context_len, :], \
147
+ encoder_hidden_states[:, self.text_context_len:, :]
148
+
149
+ key = attn.to_k(encoder_hidden_states)
150
+ value = attn.to_v(encoder_hidden_states)
151
+
152
+ query = attn.head_to_batch_dim(query)
153
+ key = attn.head_to_batch_dim(key)
154
+ value = attn.head_to_batch_dim(value)
155
+
156
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
157
+ hidden_states = torch.bmm(attention_probs, value)
158
+ hidden_states = attn.batch_to_head_dim(hidden_states)
159
+
160
+ # for ip-adapter
161
+ ip_key = self.to_k_ip(ip_hidden_states)
162
+ ip_value = self.to_v_ip(ip_hidden_states)
163
+
164
+ ip_key = attn.head_to_batch_dim(ip_key)
165
+ ip_value = attn.head_to_batch_dim(ip_value)
166
+
167
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
168
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
169
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
170
+
171
+ hidden_states = hidden_states + self.scale * ip_hidden_states
172
+
173
+ # linear proj
174
+ hidden_states = attn.to_out[0](hidden_states)
175
+ # dropout
176
+ hidden_states = attn.to_out[1](hidden_states)
177
+
178
+ if input_ndim == 4:
179
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
180
+
181
+ if attn.residual_connection:
182
+ hidden_states = hidden_states + residual
183
+
184
+ hidden_states = hidden_states / attn.rescale_output_factor
185
+
186
+ return hidden_states
187
+
188
+
189
+ class AttnProcessor2_0(torch.nn.Module):
190
+ r"""
191
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
192
+ """
193
+
194
+ def __init__(
195
+ self,
196
+ hidden_size=None,
197
+ cross_attention_dim=None,
198
+ ):
199
+ super().__init__()
200
+ if not hasattr(F, "scaled_dot_product_attention"):
201
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
202
+
203
+ def __call__(
204
+ self,
205
+ attn,
206
+ hidden_states,
207
+ encoder_hidden_states=None,
208
+ attention_mask=None,
209
+ temb=None,
210
+ ):
211
+ residual = hidden_states
212
+
213
+ if attn.spatial_norm is not None:
214
+ hidden_states = attn.spatial_norm(hidden_states, temb)
215
+
216
+ input_ndim = hidden_states.ndim
217
+
218
+ if input_ndim == 4:
219
+ batch_size, channel, height, width = hidden_states.shape
220
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
221
+
222
+ batch_size, sequence_length, _ = (
223
+ hidden_states.shape
224
+ if encoder_hidden_states is None
225
+ else encoder_hidden_states.shape
226
+ )
227
+
228
+ if attention_mask is not None:
229
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
230
+ # scaled_dot_product_attention expects attention_mask shape to be
231
+ # (batch, heads, source_length, target_length)
232
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
233
+
234
+ if attn.group_norm is not None:
235
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
236
+
237
+ query = attn.to_q(hidden_states)
238
+
239
+ if encoder_hidden_states is None:
240
+ encoder_hidden_states = hidden_states
241
+ elif attn.norm_cross:
242
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
243
+
244
+ key = attn.to_k(encoder_hidden_states)
245
+ value = attn.to_v(encoder_hidden_states)
246
+
247
+ inner_dim = key.shape[-1]
248
+ head_dim = inner_dim // attn.heads
249
+
250
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
251
+
252
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
253
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
254
+
255
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
256
+ # TODO: add support for attn.scale when we move to Torch 2.1
257
+ hidden_states = F.scaled_dot_product_attention(query,
258
+ key,
259
+ value,
260
+ attn_mask=attention_mask,
261
+ dropout_p=0.0,
262
+ is_causal=False)
263
+
264
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
265
+ hidden_states = hidden_states.to(query.dtype)
266
+
267
+ # linear proj
268
+ hidden_states = attn.to_out[0](hidden_states)
269
+ # dropout
270
+ hidden_states = attn.to_out[1](hidden_states)
271
+
272
+ if input_ndim == 4:
273
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
274
+
275
+ if attn.residual_connection:
276
+ hidden_states = hidden_states + residual
277
+
278
+ hidden_states = hidden_states / attn.rescale_output_factor
279
+
280
+ return hidden_states
281
+
282
+
283
+ class IPAttnProcessor2_0(torch.nn.Module):
284
+ r"""
285
+ Attention processor for IP-Adapater for PyTorch 2.0.
286
+ Args:
287
+ hidden_size (`int`):
288
+ The hidden size of the attention layer.
289
+ cross_attention_dim (`int`):
290
+ The number of channels in the `encoder_hidden_states`.
291
+ text_context_len (`int`, defaults to 77):
292
+ The context length of the text features.
293
+ scale (`float`, defaults to 1.0):
294
+ the weight scale of image prompt.
295
+ """
296
+
297
+ def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
298
+ super().__init__()
299
+
300
+ if not hasattr(F, "scaled_dot_product_attention"):
301
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
302
+
303
+ self.hidden_size = hidden_size
304
+ self.cross_attention_dim = cross_attention_dim
305
+ self.text_context_len = text_context_len
306
+ self.scale = scale
307
+
308
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
309
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
310
+
311
+ def __call__(
312
+ self,
313
+ attn,
314
+ hidden_states,
315
+ encoder_hidden_states=None,
316
+ attention_mask=None,
317
+ temb=None,
318
+ ):
319
+ residual = hidden_states
320
+
321
+ if attn.spatial_norm is not None:
322
+ hidden_states = attn.spatial_norm(hidden_states, temb)
323
+
324
+ input_ndim = hidden_states.ndim
325
+
326
+ if input_ndim == 4:
327
+ batch_size, channel, height, width = hidden_states.shape
328
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
329
+
330
+ batch_size, sequence_length, _ = (
331
+ hidden_states.shape
332
+ if encoder_hidden_states is None
333
+ else encoder_hidden_states.shape
334
+ )
335
+ if attention_mask is not None:
336
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
337
+ # scaled_dot_product_attention expects attention_mask shape to be
338
+ # (batch, heads, source_length, target_length)
339
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
340
+
341
+ if attn.group_norm is not None:
342
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
343
+
344
+ query = attn.to_q(hidden_states)
345
+
346
+ if encoder_hidden_states is None:
347
+ encoder_hidden_states = hidden_states
348
+ elif attn.norm_cross:
349
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
350
+
351
+ # split hidden states
352
+ encoder_hidden_states, \
353
+ ip_hidden_states = \
354
+ encoder_hidden_states[:, :self.text_context_len, :], \
355
+ encoder_hidden_states[:, self.text_context_len:, :]
356
+
357
+ key = attn.to_k(encoder_hidden_states)
358
+ value = attn.to_v(encoder_hidden_states)
359
+
360
+ inner_dim = key.shape[-1]
361
+ head_dim = inner_dim // attn.heads
362
+
363
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
364
+
365
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
366
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
367
+
368
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
369
+ # TODO: add support for attn.scale when we move to Torch 2.1
370
+ hidden_states = F.scaled_dot_product_attention(query,
371
+ key,
372
+ value,
373
+ attn_mask=attention_mask,
374
+ dropout_p=0.0,
375
+ is_causal=False)
376
+
377
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
378
+ hidden_states = hidden_states.to(query.dtype)
379
+
380
+ # for ip-adapter
381
+ ip_key = self.to_k_ip(ip_hidden_states)
382
+ ip_value = self.to_v_ip(ip_hidden_states)
383
+
384
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
385
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
386
+
387
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
388
+ # TODO: add support for attn.scale when we move to Torch 2.1
389
+ ip_hidden_states = F.scaled_dot_product_attention(query,
390
+ ip_key,
391
+ ip_value,
392
+ attn_mask=None,
393
+ dropout_p=0.0,
394
+ is_causal=False)
395
+
396
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
397
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
398
+
399
+ hidden_states = hidden_states + self.scale * ip_hidden_states
400
+
401
+ # linear proj
402
+ hidden_states = attn.to_out[0](hidden_states)
403
+ # dropout
404
+ hidden_states = attn.to_out[1](hidden_states)
405
+
406
+ if input_ndim == 4:
407
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
408
+
409
+ if attn.residual_connection:
410
+ hidden_states = hidden_states + residual
411
+
412
+ hidden_states = hidden_states / attn.rescale_output_factor
413
+
414
+ return hidden_states
src/models_ipa/ipa_utils.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+
3
+
4
+ def is_torch2_available():
5
+ return hasattr(F, "scaled_dot_product_attention")
src/models_ipa/resampler.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ # FFN
10
+ def FeedForward(dim, mult=4):
11
+ inner_dim = int(dim * mult)
12
+ return nn.Sequential(
13
+ nn.LayerNorm(dim),
14
+ nn.Linear(dim, inner_dim, bias=False),
15
+ nn.GELU(),
16
+ nn.Linear(inner_dim, dim, bias=False),
17
+ )
18
+
19
+
20
+ def reshape_tensor(x, heads):
21
+ bs, length, width = x.shape
22
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
23
+ x = x.view(bs, length, heads, -1)
24
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
25
+ x = x.transpose(1, 2)
26
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
27
+ x = x.reshape(bs, heads, length, -1)
28
+ return x
29
+
30
+
31
+ class PerceiverAttention(nn.Module):
32
+
33
+ def __init__(self, *, dim, dim_head=64, heads=8):
34
+ super().__init__()
35
+ self.scale = dim_head ** -0.5
36
+ self.dim_head = dim_head
37
+ self.heads = heads
38
+ inner_dim = dim_head * heads
39
+
40
+ self.norm1 = nn.LayerNorm(dim)
41
+ self.norm2 = nn.LayerNorm(dim)
42
+
43
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
44
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
45
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
46
+
47
+ def forward(self, x, latents):
48
+ """
49
+ Args:
50
+ x (torch.Tensor): image features
51
+ shape (b, n1, D)
52
+ latent (torch.Tensor): latent features
53
+ shape (b, n2, D)
54
+ """
55
+ x = self.norm1(x)
56
+ latents = self.norm2(latents)
57
+
58
+ b, l, _ = latents.shape
59
+
60
+ q = self.to_q(latents)
61
+ kv_input = torch.cat((x, latents), dim=-2)
62
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
63
+
64
+ q = reshape_tensor(q, self.heads)
65
+ k = reshape_tensor(k, self.heads)
66
+ v = reshape_tensor(v, self.heads)
67
+
68
+ # attention
69
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
70
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
71
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
72
+ out = weight @ v
73
+
74
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
75
+
76
+ return self.to_out(out)
77
+
78
+
79
+ class AttentionPool2d(nn.Module):
80
+
81
+ def __init__(self, seq_len: int, embed_dim: int, num_heads: int, output_dim: int = None):
82
+ super().__init__()
83
+ self.positional_embedding = nn.Parameter(torch.randn(seq_len + 1, embed_dim) / embed_dim ** 0.5)
84
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
85
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
86
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
87
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
88
+ self.num_heads = num_heads
89
+
90
+ def forward(self, x, return_all_tokens=False):
91
+ # x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
92
+ x = x.permute(1, 0, 2) # (N(HW)C) => (HW)NC
93
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
94
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
95
+ x, _ = F.multi_head_attention_forward(query=x,
96
+ key=x,
97
+ value=x,
98
+ embed_dim_to_check=x.shape[-1],
99
+ num_heads=self.num_heads,
100
+ q_proj_weight=self.q_proj.weight,
101
+ k_proj_weight=self.k_proj.weight,
102
+ v_proj_weight=self.v_proj.weight,
103
+ in_proj_weight=None,
104
+ in_proj_bias=torch.cat(
105
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
106
+ bias_k=None,
107
+ bias_v=None,
108
+ add_zero_attn=False,
109
+ dropout_p=0,
110
+ out_proj_weight=self.c_proj.weight,
111
+ out_proj_bias=self.c_proj.bias,
112
+ use_separate_proj_weight=True,
113
+ training=self.training,
114
+ need_weights=False)
115
+ if return_all_tokens:
116
+ return x
117
+ else:
118
+ return x[0]
119
+
120
+
121
+ class Resampler(nn.Module):
122
+
123
+ def __init__(
124
+ self,
125
+ dim=1024,
126
+ depth=8,
127
+ dim_head=64,
128
+ heads=16,
129
+ num_queries=8,
130
+ embedding_dim=768,
131
+ output_dim=1024,
132
+ ff_mult=4,
133
+ ):
134
+ super().__init__()
135
+
136
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
137
+
138
+ self.proj_in = nn.Linear(embedding_dim, dim)
139
+
140
+ self.proj_out = nn.Linear(dim, output_dim)
141
+ self.norm_out = nn.LayerNorm(output_dim)
142
+
143
+ self.in_dim = dim
144
+ self.out_dim = output_dim
145
+
146
+ self.layers = nn.ModuleList([])
147
+ for _ in range(depth):
148
+ self.layers.append(
149
+ nn.ModuleList([
150
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
151
+ FeedForward(dim=dim, mult=ff_mult),
152
+ ]))
153
+
154
+ def forward(self, x):
155
+
156
+ latents = self.latents.repeat(x.size(0), 1, 1)
157
+
158
+ x = self.proj_in(x)
159
+
160
+ for attn, ff in self.layers:
161
+ latents = attn(x, latents) + latents
162
+ latents = ff(latents) + latents
163
+
164
+ latents = self.proj_out(latents)
165
+ output_embeds = self.norm_out(latents)
166
+
167
+ return output_embeds
168
+
169
+
170
+ class ResamplerXL(nn.Module):
171
+
172
+ def __init__(
173
+ self,
174
+ dim=1024,
175
+ depth=8,
176
+ dim_head=64,
177
+ heads=16,
178
+ num_queries=8,
179
+ embedding_dim=768,
180
+ output1_dim=768,
181
+ output2_dim=1280,
182
+ ff_mult=4,
183
+ ):
184
+ super().__init__()
185
+
186
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
187
+
188
+ self.proj_in = nn.Linear(embedding_dim, dim)
189
+
190
+ # self.proj_out = nn.Linear(dim, output_dim)
191
+ self.norm_out = nn.LayerNorm(dim)
192
+
193
+ self.in_dim = dim
194
+ self.out_dim = output1_dim + output2_dim
195
+
196
+ self.layers = nn.ModuleList([])
197
+ for _ in range(depth):
198
+ self.layers.append(
199
+ nn.ModuleList([
200
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
201
+ FeedForward(dim=dim, mult=ff_mult),
202
+ ]))
203
+
204
+ self.unet_proj_1 = nn.Linear(self.in_dim, output1_dim)
205
+ self.unet_proj_2 = nn.Linear(self.in_dim, output2_dim)
206
+ self.unet_attnpool = AttentionPool2d(num_queries, self.in_dim, heads, output2_dim)
207
+
208
+ def forward(self, x):
209
+
210
+ latents = self.latents.repeat(x.size(0), 1, 1)
211
+
212
+ x = self.proj_in(x)
213
+
214
+ for attn, ff in self.layers:
215
+ latents = attn(x, latents) + latents
216
+ latents = ff(latents) + latents
217
+
218
+ hidden_embeds = self.norm_out(latents)
219
+
220
+ encoder_hidden_1 = self.unet_proj_1(hidden_embeds) # [bs, 256, 768]
221
+ encoder_hidden_2 = self.unet_proj_2(hidden_embeds) # [bs, 256, 1280]
222
+ prompt_embeds = torch.cat([encoder_hidden_1, encoder_hidden_2], dim=-1) # [bs, 256, 2048]
223
+ pooled_prompt_embeds = self.unet_attnpool(hidden_embeds) # [bs, 1280]
224
+
225
+ return prompt_embeds, pooled_prompt_embeds
226
+
227
+
228
+ class ResamplerXLV2(nn.Module):
229
+
230
+ def __init__(
231
+ self,
232
+ dim=1024,
233
+ depth=8,
234
+ dim_head=64,
235
+ heads=16,
236
+ num_queries=8,
237
+ embedding_dim=768,
238
+ output1_dim=768,
239
+ output2_dim=1280,
240
+ ff_mult=4,
241
+ ):
242
+ super().__init__()
243
+
244
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
245
+
246
+ self.proj_in = nn.Linear(embedding_dim, dim)
247
+
248
+ # self.proj_out = nn.Linear(dim, output_dim)
249
+ self.norm_out = nn.LayerNorm(dim)
250
+
251
+ self.in_dim = dim
252
+ self.out_dim = output1_dim + output2_dim
253
+
254
+ self.layers = nn.ModuleList([])
255
+ for _ in range(depth):
256
+ self.layers.append(
257
+ nn.ModuleList([
258
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
259
+ FeedForward(dim=dim, mult=ff_mult),
260
+ ]))
261
+
262
+ self.unet_proj_1 = nn.Linear(self.in_dim, output1_dim)
263
+ self.unet_proj_2 = nn.Linear(self.in_dim, output2_dim)
264
+ self.unet_attnpool = AttentionPool2d(num_queries, self.in_dim, heads, output2_dim)
265
+
266
+ def forward(self, x, pooled_text_embeds=None):
267
+
268
+ latents = self.latents.repeat(x.size(0), 1, 1)
269
+ x = F.normalize(x)
270
+
271
+ x = self.proj_in(x)
272
+
273
+ for attn, ff in self.layers:
274
+ latents = attn(x, latents) + latents
275
+ latents = ff(latents) + latents
276
+
277
+ hidden_embeds = self.norm_out(latents)
278
+
279
+ encoder_hidden_1 = self.unet_proj_1(hidden_embeds) # [bs, 256, 768]
280
+ encoder_hidden_2 = self.unet_proj_2(hidden_embeds) # [bs, 256, 1280]
281
+ prompt_embeds = torch.cat([encoder_hidden_1, encoder_hidden_2], dim=-1) # [bs, 256, 2048]
282
+ pooled_prompt_embeds = self.unet_attnpool(hidden_embeds) # [bs, 1280]
283
+
284
+ return prompt_embeds, pooled_prompt_embeds
285
+
286
+
287
+ class ResamplerXLIdentity(nn.Module):
288
+ def __init__(self) -> None:
289
+ super().__init__()
290
+
291
+ def forward(self, x, pooled_text_embeds=None):
292
+ return x, pooled_text_embeds
293
+
294
+
295
+ if __name__ == '__main__':
296
+ image_proj_model = Resampler(dim=1024,
297
+ depth=4,
298
+ dim_head=64,
299
+ heads=12,
300
+ num_queries=1024,
301
+ embedding_dim=1024,
302
+ output_dim=1024,
303
+ ff_mult=4)
304
+ numel = 0
305
+ for name, param in image_proj_model.named_parameters():
306
+ numel += param.numel()
307
+
308
+ print(f'Total params: {numel}')
src/processer/tokenizer.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizer
2
+
3
+
4
+ def bert_tokenizer(pretrained_model_name_or_path):
5
+ tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path,
6
+ truncation_side='right')
7
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
8
+ return tokenizer
src/processer/transforms.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+
3
+
4
+ def get_transform(type='clip', keep_ratio=True, image_size=224):
5
+ if type == 'clip':
6
+ transform = []
7
+ if keep_ratio:
8
+ transform.extend([
9
+ transforms.Resize(image_size),
10
+ transforms.CenterCrop(image_size),
11
+ ])
12
+ else:
13
+ transform.append(transforms.Resize((image_size, image_size)))
14
+ transform.extend([
15
+ transforms.ToTensor(),
16
+ transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
17
+ ])
18
+
19
+ return transforms.Compose(transform)
20
+ elif type == 'clipa':
21
+ transform = []
22
+ if keep_ratio:
23
+ transform.extend([
24
+ transforms.Resize(image_size),
25
+ transforms.CenterCrop(image_size),
26
+ ])
27
+ else:
28
+ transform.append(transforms.Resize((image_size, image_size)))
29
+ transform.extend(
30
+ [transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])
31
+
32
+ return transforms.Compose(transform)
33
+ elif type == 'sd':
34
+ transform = []
35
+ if keep_ratio:
36
+ transform.extend([
37
+ transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
38
+ transforms.CenterCrop(image_size),
39
+ ])
40
+ else:
41
+ transform.append(
42
+ transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC))
43
+ transform.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
44
+
45
+ return transforms.Compose(transform)
46
+ else:
47
+ raise NotImplementedError
src/tools/reload_qwen_vit.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM
3
+
4
+ torch.manual_seed(1234)
5
+
6
+ qwen_model_path = 'pretrained/Qwen-VL-Chat'
7
+ save_path = 'pretrained/QwenViT/qwen_vit_G.pt'
8
+
9
+ model = AutoModelForCausalLM.from_pretrained(qwen_model_path, device_map="cpu", trust_remote_code=True).eval()
10
+
11
+ visual_encoder = model.transformer.visual
12
+ print(visual_encoder)
13
+
14
+ torch.save(visual_encoder.state_dict(), save_path)
src/train/dist_utils.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+
5
+ def all_gather(tensor):
6
+ world_size = dist.get_world_size()
7
+ tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
8
+ dist.all_gather(tensor_list, tensor)
9
+ return tensor_list
10
+
11
+
12
+ def is_dist_avail_and_initialized():
13
+ if not dist.is_available():
14
+ return False
15
+ if not dist.is_initialized():
16
+ return False
17
+ return True
18
+
19
+
20
+ @torch.no_grad()
21
+ def concat_all_gather(tensor):
22
+ """
23
+ Performs all_gather operation on the provided tensors.
24
+ *** Warning ***: torch.distributed.all_gather has no gradient.
25
+ """
26
+ # if use distributed training
27
+ if not is_dist_avail_and_initialized():
28
+ return tensor
29
+
30
+ tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
31
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
32
+
33
+ output = torch.cat(tensors_gather, dim=0)
34
+ return output
src/train/schedular.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from functools import partial
4
+ from typing import Callable, Iterable, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.optim import Optimizer
9
+ from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
10
+ from transformers.trainer_utils import SchedulerType
11
+ from transformers.utils import logging
12
+
13
+ from transformers.optimization import get_linear_schedule_with_warmup, \
14
+ get_cosine_with_hard_restarts_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, \
15
+ get_constant_schedule, get_constant_schedule_with_warmup, get_inverse_sqrt_schedule, get_reduce_on_plateau_schedule
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
+ def _get_cosine_schedule_with_warmup_lr_lambda(current_step: int,
21
+ *,
22
+ num_warmup_steps: int,
23
+ num_training_steps: int,
24
+ num_cycles: float,
25
+ min_lr_ratio: float = 0.0):
26
+ if current_step < num_warmup_steps:
27
+ return float(current_step) / float(max(1, num_warmup_steps))
28
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
29
+ # return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
30
+ return max(0.0,
31
+ 0.5 * ((1.0 + min_lr_ratio) + (1.0 - min_lr_ratio) * math.cos(
32
+ math.pi * float(num_cycles) * 2.0 * progress)))
33
+
34
+
35
+ def get_cosine_schedule_with_warmup(optimizer: Optimizer,
36
+ num_warmup_steps: int,
37
+ num_training_steps: int,
38
+ num_cycles: float = 0.5,
39
+ last_epoch: int = -1,
40
+ min_lr_ratio: float = 0.0):
41
+ """
42
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
43
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
44
+ initial lr set in the optimizer.
45
+
46
+ Args:
47
+ optimizer ([`~torch.optim.Optimizer`]):
48
+ The optimizer for which to schedule the learning rate.
49
+ num_warmup_steps (`int`):
50
+ The number of steps for the warmup phase.
51
+ num_training_steps (`int`):
52
+ The total number of training steps.
53
+ num_cycles (`float`, *optional*, defaults to 0.5):
54
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
55
+ following a half-cosine).
56
+ last_epoch (`int`, *optional*, defaults to -1):
57
+ The index of the last epoch when resuming training.
58
+
59
+ Return:
60
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
61
+ """
62
+
63
+ lr_lambda = partial(
64
+ _get_cosine_schedule_with_warmup_lr_lambda,
65
+ num_warmup_steps=num_warmup_steps,
66
+ num_training_steps=num_training_steps,
67
+ num_cycles=num_cycles,
68
+ min_lr_ratio=min_lr_ratio,
69
+ )
70
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
71
+
72
+
73
+ TYPE_TO_SCHEDULER_FUNCTION = {
74
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
75
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
76
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
77
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
78
+ SchedulerType.CONSTANT: get_constant_schedule,
79
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
80
+ SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
81
+ SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
82
+ }
83
+
84
+
85
+ def get_scheduler(
86
+ name: Union[str, SchedulerType],
87
+ optimizer: Optimizer,
88
+ num_warmup_steps: Optional[int] = None,
89
+ num_training_steps: Optional[int] = None,
90
+ min_lr_ratio: Optional[float] = 0.0,
91
+ ):
92
+ """
93
+ Unified API to get any scheduler from its name.
94
+
95
+ Args:
96
+ name (`str` or `SchedulerType`):
97
+ The name of the scheduler to use.
98
+ optimizer (`torch.optim.Optimizer`):
99
+ The optimizer that will be used during training.
100
+ num_warmup_steps (`int`, *optional*):
101
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
102
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
103
+ num_training_steps (`int``, *optional*):
104
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
105
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
106
+ """
107
+ name = SchedulerType(name)
108
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
109
+ if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU:
110
+ return schedule_func(optimizer)
111
+
112
+ # All other schedulers require `num_warmup_steps`
113
+ if num_warmup_steps is None:
114
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
115
+
116
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
117
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
118
+
119
+ if name == SchedulerType.INVERSE_SQRT:
120
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
121
+
122
+ # All other schedulers require `num_training_steps`
123
+ if num_training_steps is None:
124
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
125
+
126
+ logger.info(f'Initialize lr scheduler with min_lr_ratio: {min_lr_ratio}')
127
+ return schedule_func(optimizer,
128
+ num_warmup_steps=num_warmup_steps,
129
+ num_training_steps=num_training_steps,
130
+ min_lr_ratio=min_lr_ratio)
src/train/train.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ import hydra
3
+
4
+ import pyrootutils
5
+ import os
6
+ import torch
7
+ from accelerate import Accelerator
8
+ from accelerate.logging import get_logger
9
+ from accelerate.utils import ProjectConfiguration
10
+
11
+ from tqdm.auto import tqdm
12
+ from omegaconf import OmegaConf
13
+ from omegaconf.dictconfig import DictConfig
14
+ import argparse
15
+ from flask import Flask, request
16
+ from typing import List, Union
17
+ import json
18
+ from typing import Optional
19
+ import transformers
20
+ from dataclasses import dataclass, field, asdict, is_dataclass
21
+ from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, DistributedReadingService, \
22
+ SequentialReadingService
23
+ import logging
24
+
25
+ pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
26
+ from src.train.schedular import get_scheduler
27
+ from src.train.dist_utils import all_gather
28
+
29
+ # logger = get_logger(__name__, log_level='info')
30
+ log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
31
+ logging.basicConfig(level=logging.INFO, format=log_format)
32
+
33
+ logger = logging.getLogger(__name__)
34
+ os.environ["WANDB_MODE"] = "offline"
35
+
36
+
37
+ @dataclass
38
+ class ConfigPathArguments:
39
+ image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
40
+ tokenizer: Optional[str] = field(default=None,
41
+ metadata={"help": "config path of tokenizer used to initialize tokenizer"})
42
+ # model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
43
+ visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
44
+ text_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
45
+ discrete_model: Optional[str] = field(default=None, metadata={"help": "config path of discrete model"})
46
+ train_dataset: Optional[str] = field(default=None, metadata={"help": "config path of training dataset"})
47
+
48
+
49
+ @dataclass
50
+ class TrainingArguments:
51
+ output_dir: str = field(
52
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, )
53
+ resume_from_checkpoint: Optional[str] = field(
54
+ default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."})
55
+ resume_steps: Optional[int] = field(default=None, metadata={"help": "The training sterps of saved checkpoint"})
56
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
57
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
58
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
59
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
60
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
61
+ max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
62
+ gradient_accumulation_steps: int = field(
63
+ default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."})
64
+ mixed_precision: Optional[str] = field(
65
+ default='no',
66
+ metadata={
67
+ "help":
68
+ "Whether to use mixed precision. \
69
+ Choose between fp16 and bf16 (bfloat16). \
70
+ Bf16 requires PyTorch >=1.10.and an Nvidia Ampere GPU."
71
+ })
72
+ num_train_epochs: int = field(default=3, metadata={"help": "Total number of training epochs to perform."})
73
+ max_steps: int = field(default=-1, metadata={"help": "Total number of training steps to perform. "})
74
+ save_steps: int = field(default=10000, metadata={"help": "Number of updates steps before two checkpoint saves."})
75
+ lr_scheduler_type: str = field(default="cosine", metadata={"help": "The scheduler type to use."})
76
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
77
+ min_lr_ratio: float = field(default=0.01, metadata={"help": "Minimal learning rate ratio."})
78
+ dataloader_num_workers: int = field(default=8, metadata={"help": "The number of workers to use for data loading."})
79
+ project_name: str = field(default="DiscreteLearning", metadata={"help": "The name of experiment"})
80
+ expr_name: str = field(default="", metadata={"help": "The name of experiment"})
81
+
82
+
83
+ def build_dataloader(dataset_cfg, image_transform, tokenizer, dataloader_num_workers=4):
84
+ dataset = hydra.utils.instantiate(dataset_cfg, image_transform=image_transform, tokenizer=tokenizer)
85
+ mp_service = MultiProcessingReadingService(num_workers=dataloader_num_workers)
86
+ dist_service = DistributedReadingService()
87
+ reading_service = SequentialReadingService(dist_service, mp_service)
88
+ dataloader = DataLoader2(dataset, reading_service=reading_service)
89
+ return dataloader
90
+
91
+
92
+ def get_metric(output):
93
+ metric = {}
94
+ for key, value in output.items():
95
+ if 'loss' in key:
96
+ metric[key] = value.item()
97
+ return metric
98
+
99
+
100
+ def get_code_usage(indices):
101
+ indices_list = all_gather(indices)
102
+ indices = torch.cat(indices_list, dim=0)
103
+ code_usage = indices.unique().numel()
104
+ return code_usage
105
+
106
+
107
+ def merge_config(**kwargs):
108
+ config = {}
109
+ for key, value in kwargs.items():
110
+ if isinstance(value, argparse.Namespace):
111
+ config[key] = vars(value)
112
+ elif isinstance(value, DictConfig):
113
+ config[key] = OmegaConf.to_object(value)
114
+ elif is_dataclass(value):
115
+ config[key] = asdict(value)
116
+ elif isinstance(value, dict):
117
+ config[key] = value
118
+ else:
119
+ logger.error(f'key: {key}, value: {value} will not be merged.')
120
+ return config
121
+
122
+
123
+ def trainable_params(model):
124
+ count = 0
125
+ for name, param in model.named_parameters():
126
+ count += param.numel()
127
+ return count
128
+
129
+
130
+ def train():
131
+ parser = transformers.HfArgumentParser((ConfigPathArguments, TrainingArguments))
132
+ cfg_path, args = parser.parse_args_into_dataclasses()
133
+
134
+ project_config = ProjectConfiguration(project_dir=args.output_dir,
135
+ logging_dir=os.path.join(args.output_dir, 'logs'))
136
+
137
+ accelerator = Accelerator(
138
+ mixed_precision=args.mixed_precision,
139
+ log_with=['tensorboard', 'wandb'],
140
+ project_config=project_config,
141
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
142
+ step_scheduler_with_optimizer=False,
143
+ )
144
+ logger.info('Init accelerator done.')
145
+
146
+ os.makedirs(args.output_dir, exist_ok=True)
147
+
148
+ visual_encoder_cfg = OmegaConf.load(cfg_path.visual_encoder)
149
+ visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
150
+ logger.info('Load visual encoder done.')
151
+
152
+ discrete_model_cfg = OmegaConf.load(cfg_path.discrete_model)
153
+ discrete_model = hydra.utils.instantiate(discrete_model_cfg)
154
+ logger.info('Load discrete model done.')
155
+
156
+ train_dataset_cfg = OmegaConf.load(cfg_path.train_dataset)
157
+
158
+ if cfg_path.text_encoder is not None:
159
+ text_encoder_cfg = OmegaConf.load(cfg_path.text_encoder)
160
+ text_encoder = hydra.utils.instantiate(text_encoder_cfg)
161
+ else:
162
+ text_encoder_cfg = None
163
+ text_encoder = None
164
+
165
+ if cfg_path.image_transform is not None:
166
+ image_transform_cfg = OmegaConf.load(cfg_path.image_transform)
167
+ image_transform = hydra.utils.instantiate(image_transform_cfg)
168
+ else:
169
+ image_transform_cfg = None
170
+ image_transform = None
171
+
172
+ if cfg_path.tokenizer is not None:
173
+ tokenizer_cfg = OmegaConf.load(cfg_path.tokenizer)
174
+ tokenizer = hydra.utils.instantiate(tokenizer_cfg)
175
+ else:
176
+ tokenizer_cfg = None
177
+ tokenizer = None
178
+
179
+ weight_dtype = torch.float32
180
+ if accelerator.mixed_precision == "fp16":
181
+ weight_dtype = torch.float16
182
+ elif accelerator.mixed_precision == "bf16":
183
+ weight_dtype = torch.bfloat16
184
+
185
+ visual_encoder.to(accelerator.device, dtype=weight_dtype)
186
+ logger.info('Freeze visual encoder...')
187
+ visual_encoder.requires_grad_(False)
188
+ if text_encoder is not None:
189
+ logger.info('Freeze text encoder...')
190
+ text_encoder.requires_grad_(False)
191
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
192
+ discrete_model.to(accelerator.device, dtype=weight_dtype)
193
+
194
+ discrete_model = accelerator.prepare(discrete_model)
195
+ optimizer = torch.optim.AdamW(discrete_model.parameters(),
196
+ lr=args.learning_rate,
197
+ betas=[args.adam_beta1, args.adam_beta2],
198
+ eps=args.adam_epsilon,
199
+ weight_decay=args.weight_decay)
200
+ logger.info('Init optimizer done.')
201
+ scheduler = get_scheduler(name=args.lr_scheduler_type,
202
+ optimizer=optimizer,
203
+ num_warmup_steps=args.warmup_steps,
204
+ num_training_steps=args.max_steps,
205
+ min_lr_ratio=args.min_lr_ratio)
206
+ # accelerator.register_for_checkpointing(scheduler)
207
+
208
+ optimizer, scheduler = accelerator.prepare(optimizer, scheduler)
209
+ logger.info('Prepare accelerator done.')
210
+
211
+ config_record = merge_config(discrete_model=discrete_model_cfg,
212
+ visual_encoder=visual_encoder_cfg,
213
+ text_encoder=text_encoder_cfg,
214
+ image_transform=image_transform_cfg,
215
+ tokenizer=tokenizer_cfg,
216
+ train_dataset=train_dataset_cfg,
217
+ train_args=args)
218
+ accelerator.init_trackers(project_name=args.project_name,
219
+ init_kwargs={"wandb": {
220
+ "config": config_record,
221
+ "name": args.expr_name,
222
+ "dir": args.output_dir
223
+ }})
224
+ if args.resume_from_checkpoint is not None:
225
+ logger.info(f'Load checkpoint from {args.resume_from_checkpoint}')
226
+ accelerator.load_state(args.resume_from_checkpoint)
227
+
228
+ num_params = trainable_params(discrete_model)
229
+ logger.info("***** Running training *****")
230
+ logger.info(f" Total optimization steps = {args.max_steps}")
231
+ logger.info(f" Total trainable params = {num_params}")
232
+ # Only show the progress bar once on each machine.
233
+ progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_main_process)
234
+ progress_bar.set_description("Steps")
235
+ global_step = 0
236
+ if args.resume_steps is not None:
237
+ global_step = args.resume_steps
238
+ progress_bar.update(args.resume_steps)
239
+
240
+ train_dataloader = build_dataloader(dataset_cfg=train_dataset_cfg,
241
+ image_transform=image_transform,
242
+ tokenizer=tokenizer,
243
+ dataloader_num_workers=args.dataloader_num_workers)
244
+ for epoch in range(args.num_train_epochs):
245
+ discrete_model.train()
246
+ logger.info('Start new epoch')
247
+
248
+ for step, batch in enumerate(train_dataloader):
249
+ with accelerator.accumulate(discrete_model):
250
+ with torch.no_grad():
251
+ image_embeds = visual_encoder(batch['images'].to(accelerator.device, dtype=weight_dtype))
252
+ if text_encoder is not None:
253
+ text_embeds = text_encoder(batch['text_input_ids'].to(accelerator.device))
254
+ else:
255
+ text_embeds = None
256
+
257
+ output = discrete_model(image_embeds=image_embeds, text_embeds=text_embeds)
258
+
259
+ loss = output['total_loss']
260
+ accelerator.backward(loss)
261
+ if accelerator.sync_gradients:
262
+ accelerator.clip_grad_norm_(discrete_model.parameters(), max_norm=args.max_grad_norm)
263
+ optimizer.step()
264
+ scheduler.step()
265
+ optimizer.zero_grad()
266
+
267
+ if accelerator.sync_gradients:
268
+ progress_bar.update(1)
269
+ global_step += 1
270
+
271
+ if global_step % args.save_steps == 0:
272
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
273
+ accelerator.save_state(save_path)
274
+
275
+ metric = get_metric(output)
276
+ metric['lr'] = optimizer.param_groups[0]['lr']
277
+ metric['code_usage'] = get_code_usage(output['indices'])
278
+ metric = {key: (format(value, ".6f") if isinstance(value, float) else value) for key, value in
279
+ metric.items()}
280
+ accelerator.log(metric, step=global_step)
281
+ if accelerator.is_main_process:
282
+ tqdm.write(str(metric))
283
+ # print(metric)
284
+ if global_step >= args.max_steps:
285
+ break
286
+
287
+ accelerator.end_training()
288
+
289
+
290
+ if __name__ == '__main__':
291
+ train()
src/train/train_clm_sft.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ import hydra
3
+
4
+ import pyrootutils
5
+ import os
6
+ import torch
7
+ from accelerate import Accelerator
8
+ from accelerate.logging import get_logger
9
+ from accelerate.utils import ProjectConfiguration
10
+ from torch.utils.data import DataLoader
11
+
12
+ from deepspeed.runtime.engine import DummyOptim
13
+ from tqdm.auto import tqdm
14
+ from omegaconf import OmegaConf
15
+ from omegaconf.dictconfig import DictConfig
16
+ import argparse
17
+ from flask import Flask, request
18
+ from typing import List, Union
19
+ import json
20
+ from typing import Optional
21
+ import transformers
22
+ from dataclasses import dataclass, field, asdict, is_dataclass
23
+ from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, DistributedReadingService, \
24
+ SequentialReadingService
25
+ import gc
26
+ import logging
27
+ from accelerate import FullyShardedDataParallelPlugin, DistributedDataParallelKwargs
28
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
29
+
30
+ pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
31
+ from src.train.schedular import get_scheduler
32
+ from src.train.dist_utils import all_gather
33
+
34
+ # logger = get_logger(__name__, log_level='info')
35
+ log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
36
+ logging.basicConfig(level=logging.INFO, format=log_format)
37
+
38
+ logger = logging.getLogger(__name__)
39
+ os.environ["WANDB_MODE"] = "offline"
40
+
41
+
42
+ @dataclass
43
+ class ConfigPathArguments:
44
+ image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
45
+ tokenizer: Optional[str] = field(default=None,
46
+ metadata={"help": "config path of tokenizer used to initialize tokenizer"})
47
+ # model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
48
+ visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
49
+ llm_model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
50
+ agent_model: Optional[str] = field(default=None, metadata={"help": "config path of agent"})
51
+ train_dataset: Optional[str] = field(default=None, metadata={"help": "config path of training dataset"})
52
+ fsdp_plugin: Optional[str] = field(default=None, metadata={"help": "config path of fsdp plugin"})
53
+ deepspeed_plugin: Optional[str] = field(default=None, metadata={"help": "config path of deepspeed plugin"})
54
+
55
+
56
+ @dataclass
57
+ class TrainingArguments:
58
+ output_dir: str = field(
59
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, )
60
+ resume_from_checkpoint: Optional[str] = field(
61
+ default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."})
62
+ resume_steps: Optional[int] = field(default=None, metadata={"help": "The training sterps of saved checkpoint"})
63
+ batch_size: Optional[int] = field(default=60, metadata={"help": "The training batch size"})
64
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
65
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
66
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
67
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
68
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
69
+ max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
70
+ gradient_accumulation_steps: int = field(
71
+ default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."})
72
+ mixed_precision: Optional[str] = field(
73
+ default='no',
74
+ metadata={
75
+ "help":
76
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=1.10.and an Nvidia Ampere GPU."
77
+ })
78
+ num_train_epochs: int = field(default=3, metadata={"help": "Total number of training epochs to perform."})
79
+ max_steps: int = field(default=-1, metadata={"help": "Total number of training steps to perform. "})
80
+ save_steps: int = field(default=10000, metadata={"help": "Number of updates steps before two checkpoint saves."})
81
+ lr_scheduler_type: str = field(default="cosine", metadata={"help": "The scheduler type to use."})
82
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
83
+ min_lr_ratio: float = field(default=0.01, metadata={"help": "Minimal learning rate ratio."})
84
+ dataloader_num_workers: int = field(default=8, metadata={"help": "The number of workers to use for data loading."})
85
+ project_name: str = field(default="ContinuousVLM", metadata={"help": "The name of experiment"})
86
+ expr_name: str = field(default="", metadata={"help": "The name of experiment"})
87
+
88
+
89
+ def build_dataloader(dataset_cfg, image_transform, tokenizer, batch_size, dataloader_num_workers=4):
90
+ dataset = hydra.utils.instantiate(dataset_cfg, image_transform=image_transform, tokenizer=tokenizer)
91
+ mp_service = MultiProcessingReadingService(num_workers=dataloader_num_workers)
92
+ dist_service = DistributedReadingService()
93
+ reading_service = SequentialReadingService(dist_service, mp_service)
94
+ dataloader = DataLoader2(dataset, reading_service=reading_service)
95
+ # dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=dataloader_num_workers)
96
+ return dataloader
97
+
98
+
99
+ def get_metric(output):
100
+ metric = {}
101
+ for key, value in output.items():
102
+ if 'loss' in key:
103
+ gathered_metric = torch.stack(all_gather(value)).mean()
104
+ # metric[key] = value.item()
105
+ metric[key] = gathered_metric.item()
106
+ if 'acc' in key:
107
+ metric[key] = value.item()
108
+ return metric
109
+
110
+
111
+ def merge_config(**kwargs):
112
+ config = {}
113
+ for key, value in kwargs.items():
114
+ if isinstance(value, argparse.Namespace):
115
+ config[key] = vars(value)
116
+ elif isinstance(value, DictConfig):
117
+ config[key] = OmegaConf.to_object(value)
118
+ elif is_dataclass(value):
119
+ config[key] = asdict(value)
120
+ elif isinstance(value, (int, str, float, dict)) or value is None:
121
+ config[key] = value
122
+ else:
123
+ logger.error(f'key: {key}, value: {value} will not be merged.')
124
+ return config
125
+
126
+
127
+ def trainable_params(model):
128
+ count = 0
129
+ for name, param in model.named_parameters():
130
+ if param.requires_grad:
131
+ count += param.numel()
132
+ return count
133
+
134
+
135
+ def train():
136
+ parser = transformers.HfArgumentParser((ConfigPathArguments, TrainingArguments))
137
+ cfg_path, args = parser.parse_args_into_dataclasses()
138
+
139
+ project_config = ProjectConfiguration(project_dir=args.output_dir,
140
+ logging_dir=os.path.join(args.output_dir, 'logs'))
141
+
142
+ assert int(cfg_path.fsdp_plugin is not None) + int(cfg_path.deepspeed_plugin is not None) <= 1
143
+ if cfg_path.fsdp_plugin is not None:
144
+ fsdp_plugin_cfg = OmegaConf.load(cfg_path.fsdp_plugin)
145
+ fsdp_plugin = hydra.utils.instantiate(fsdp_plugin_cfg)
146
+ logger.info('Use FSDP plugin')
147
+ else:
148
+ fsdp_plugin = None
149
+
150
+ if cfg_path.deepspeed_plugin is not None:
151
+ deepspeed_plugin_cfg = OmegaConf.load(cfg_path.deepspeed_plugin)
152
+ deepspeed_plugin = hydra.utils.instantiate(deepspeed_plugin_cfg)
153
+ logger.info('Use deepspeed plugin')
154
+ else:
155
+ deepspeed_plugin = None
156
+
157
+ # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
158
+ accelerator = Accelerator(
159
+ mixed_precision=args.mixed_precision,
160
+ log_with=['tensorboard', 'wandb'],
161
+ project_config=project_config,
162
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
163
+ step_scheduler_with_optimizer=False,
164
+ fsdp_plugin=fsdp_plugin,
165
+ deepspeed_plugin=deepspeed_plugin,
166
+ # kwargs_handlers=[ddp_kwargs],
167
+ )
168
+ accelerator.wait_for_everyone()
169
+ logger.info('Init accelerator done.')
170
+
171
+ if cfg_path.deepspeed_plugin is not None:
172
+ accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 8
173
+
174
+ # print('deepspeed config: ', accelerator.state.deepspeed_plugin.deepspeed_config)
175
+
176
+ os.makedirs(args.output_dir, exist_ok=True)
177
+
178
+ # if cfg_path.image_transform is not None:
179
+ image_transform_cfg = OmegaConf.load(cfg_path.image_transform)
180
+ image_transform = hydra.utils.instantiate(image_transform_cfg)
181
+ # else:
182
+ # image_transform_cfg = None
183
+ # image_transform = None
184
+
185
+ # if cfg_path.tokenizer is not None:
186
+ tokenizer_cfg = OmegaConf.load(cfg_path.tokenizer)
187
+ tokenizer = hydra.utils.instantiate(tokenizer_cfg)
188
+ # else:
189
+ # tokenizer_cfg = None
190
+ # tokenizer = None
191
+ train_dataset_cfg = OmegaConf.load(cfg_path.train_dataset)
192
+
193
+ visual_encoder_cfg = OmegaConf.load(cfg_path.visual_encoder)
194
+ visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
195
+ logger.info('Load visual encoder done.')
196
+
197
+ llm_model_cfg = OmegaConf.load(cfg_path.llm_model)
198
+ llm_model = hydra.utils.instantiate(llm_model_cfg)
199
+ llm_model.gradient_checkpointing_enable()
200
+ llm_model.config.use_cache = False
201
+ logger.info('Load llm model done.')
202
+
203
+ agent_model_cfg = OmegaConf.load(cfg_path.agent_model)
204
+ agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm_model)
205
+ logger.info('Load agent model done.')
206
+
207
+ weight_dtype = torch.float32
208
+ if accelerator.mixed_precision == "fp16":
209
+ weight_dtype = torch.float16
210
+ elif accelerator.mixed_precision == "bf16":
211
+ weight_dtype = torch.bfloat16
212
+
213
+ visual_encoder.to(accelerator.device, dtype=weight_dtype)
214
+ logger.info('Freeze visual encoder...')
215
+ visual_encoder.requires_grad_(False)
216
+
217
+ if cfg_path.fsdp_plugin is not None:
218
+ agent_model = accelerator.prepare(agent_model)
219
+
220
+ optimizer = torch.optim.AdamW(agent_model.parameters(),
221
+ lr=args.learning_rate,
222
+ betas=[args.adam_beta1, args.adam_beta2],
223
+ eps=args.adam_epsilon,
224
+ weight_decay=args.weight_decay)
225
+ logger.info('Init optimizer done.')
226
+ scheduler = get_scheduler(name=args.lr_scheduler_type,
227
+ optimizer=optimizer,
228
+ num_warmup_steps=args.warmup_steps,
229
+ num_training_steps=args.max_steps,
230
+ min_lr_ratio=args.min_lr_ratio)
231
+ # accelerator.register_for_checkpointing(scheduler)
232
+ train_dataloader = build_dataloader(dataset_cfg=train_dataset_cfg,
233
+ image_transform=image_transform,
234
+ tokenizer=tokenizer,
235
+ batch_size=args.batch_size,
236
+ dataloader_num_workers=args.dataloader_num_workers)
237
+ if cfg_path.fsdp_plugin is not None:
238
+ optimizer, scheduler = accelerator.prepare(optimizer, scheduler)
239
+ else:
240
+ agent_model, optimizer, scheduler = accelerator.prepare(agent_model, optimizer, scheduler)
241
+ logger.info('Prepare accelerator done.')
242
+
243
+ config_record = merge_config(agent_model=agent_model_cfg,
244
+ llm_model=llm_model,
245
+ visual_encoder=visual_encoder_cfg,
246
+ image_transform=image_transform_cfg,
247
+ tokenizer=tokenizer_cfg,
248
+ train_dataset=train_dataset_cfg,
249
+ train_args=args)
250
+ accelerator.init_trackers(project_name=args.project_name,
251
+ init_kwargs={"wandb": {
252
+ "config": config_record,
253
+ "name": args.expr_name,
254
+ "dir": args.output_dir
255
+ }})
256
+ if args.resume_from_checkpoint is not None:
257
+ logger.info(f'Load checkpoint from {args.resume_from_checkpoint}')
258
+ accelerator.load_state(args.resume_from_checkpoint)
259
+ torch.cuda.empty_cache()
260
+ gc.collect()
261
+
262
+ num_params = trainable_params(agent_model)
263
+ logger.info("***** Running training *****")
264
+ logger.info(f" Total optimization steps = {args.max_steps}")
265
+ logger.info(f" Total trainable params = {num_params}")
266
+ # Only show the progress bar once on each machine.
267
+ progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_main_process)
268
+ progress_bar.set_description("Steps")
269
+ global_step = 0
270
+ if args.resume_steps is not None:
271
+ global_step = args.resume_steps
272
+ progress_bar.update(args.resume_steps)
273
+
274
+ for epoch in range(args.num_train_epochs):
275
+ agent_model.train()
276
+ logger.info('Start new epoch')
277
+
278
+ for step, batch in enumerate(train_dataloader):
279
+ with accelerator.accumulate(agent_model):
280
+ # accelerator.wait_for_everyone()
281
+ # print('1')
282
+ with torch.no_grad():
283
+ if batch['images'] is not None:
284
+ image_embeds = visual_encoder(batch['images'].to(accelerator.device, dtype=weight_dtype))
285
+ # image_embeds = visual_encoder(batch['images'])
286
+ else:
287
+ image_embeds = None
288
+ # accelerator.wait_for_everyone()
289
+ # print('2')
290
+ output = agent_model(input_ids=batch['input_ids'].to(accelerator.device),
291
+ attention_mask=batch['attention_mask'].to(accelerator.device),
292
+ labels=batch['labels'].to(accelerator.device),
293
+ image_embeds=image_embeds,
294
+ embeds_gen_mask=batch['embeds_gen_mask'].to(accelerator.device)
295
+ if batch['embeds_gen_mask'] is not None else None,
296
+ embeds_cmp_mask=batch['embeds_cmp_mask'].to(accelerator.device)
297
+ if batch['embeds_cmp_mask'] is not None else None,
298
+ ids_gen_mask=batch['ids_gen_mask'].to(accelerator.device),
299
+ ids_cmp_mask=batch['ids_cmp_mask'].to(accelerator.device))
300
+ # output = agent_model(
301
+ # input_ids=batch['input_ids'], #.squeeze(0),
302
+ # attention_mask=batch['attention_mask'], # .squeeze(0),
303
+ # labels=batch['labels'], # .squeeze(0),
304
+ # image_embeds=image_embeds,
305
+ # embeds_gen_mask=batch['embeds_gen_mask'], #.squeeze(0),
306
+ # embeds_cmp_mask=batch['embeds_cmp_mask'], #.squeeze(0),
307
+ # ids_gen_mask=batch['ids_gen_mask'], #.squeeze(0),
308
+ # ids_cmp_mask=batch['ids_cmp_mask']) #.squeeze(0))
309
+ loss = output['total_loss']
310
+ # accelerator.wait_for_everyone()
311
+ # print('3')
312
+ accelerator.backward(loss)
313
+ # accelerator.wait_for_everyone()
314
+ # print('4')
315
+ if accelerator.sync_gradients:
316
+ accelerator.clip_grad_norm_(agent_model.parameters(), max_norm=args.max_grad_norm)
317
+
318
+ optimizer.step()
319
+ scheduler.step()
320
+ optimizer.zero_grad()
321
+ # accelerator.wait_for_everyone()
322
+ # print('5')
323
+
324
+ if accelerator.sync_gradients:
325
+ progress_bar.update(1)
326
+ global_step += 1
327
+
328
+ if global_step % args.save_steps == 0:
329
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
330
+ accelerator.save_state(save_path)
331
+
332
+ metric = get_metric(output)
333
+ metric['lr'] = optimizer.param_groups[0]['lr']
334
+ accelerator.log(metric, step=global_step)
335
+ metric = {key: (format(value, ".6f") if isinstance(value, float) else value) for key, value in
336
+ metric.items()}
337
+ if accelerator.is_main_process:
338
+ tqdm.write(str(metric))
339
+ # print(metric)
340
+ if global_step >= args.max_steps:
341
+ break
342
+
343
+ accelerator.end_training()
344
+
345
+
346
+ if __name__ == '__main__':
347
+ train()
src/train/train_sdxl_img2img_llm.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ import hydra
3
+
4
+ import pyrootutils
5
+ import os
6
+ import torch
7
+ from accelerate import Accelerator
8
+ from accelerate.logging import get_logger
9
+ from accelerate.utils import ProjectConfiguration
10
+
11
+ from tqdm.auto import tqdm
12
+ from omegaconf import OmegaConf
13
+ from omegaconf.dictconfig import DictConfig
14
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler, \
15
+ Transformer2DModel
16
+
17
+ from transformers import CLIPTextModel, CLIPTokenizer
18
+ import argparse
19
+ from flask import Flask, request
20
+ from typing import List, Union
21
+ import json
22
+ from typing import Optional
23
+ import transformers
24
+ from dataclasses import dataclass, field, asdict, is_dataclass
25
+ from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, DistributedReadingService, \
26
+ SequentialReadingService
27
+ import logging
28
+
29
+ pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
30
+ from src.train.schedular import get_scheduler
31
+ from src.train.dist_utils import all_gather
32
+
33
+ # logger = get_logger(__name__, log_level='info')
34
+ log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
35
+ logging.basicConfig(level=logging.INFO, format=log_format)
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ # os.environ["WANDB_MODE"] = "offline"
41
+
42
+
43
+ @dataclass
44
+ class ConfigPathArguments:
45
+ image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
46
+ sd_image_transform: Optional[str] = field(default=None,
47
+ metadata={"help": "config path of stable diffusion image transform"})
48
+ # tokenizer: Optional[str] = field(default=None, metadata={"help": "config path of tokenizer used to initialize tokenizer"})
49
+ visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
50
+ # text_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
51
+ discrete_model: Optional[str] = field(default=None, metadata={"help": "config path of discrete model"})
52
+ # noise_scheduler: Optional[str] = field(default=None, metadata={"help": "config path of noise scheduler"})
53
+ # vae: Optional[str] = field(default=None, metadata={"help": "config path of vae"})
54
+ adapter: Optional[str] = field(default=None, metadata={"help": "config path of adapter"})
55
+ train_dataset: Optional[str] = field(default=None, metadata={"help": "config path of training dataset"})
56
+ fsdp_plugin: Optional[str] = field(default=None, metadata={"help": "config path of fsdp plugin"})
57
+ deepspeed_plugin: Optional[str] = field(default=None, metadata={"help": "config path of deepspeed plugin"})
58
+ tokenizer: Optional[str] = field(default=None,
59
+ metadata={"help": "config path of tokenizer used to initialize tokenizer"})
60
+ llm_model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
61
+ agent_model: Optional[str] = field(default=None, metadata={"help": "config path of agent"})
62
+
63
+
64
+ @dataclass
65
+ class TrainingArguments:
66
+ output_dir: str = field(
67
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, )
68
+ diffusion_model_path: Optional[str] = field(default=None, metadata={"help": "config path of training dataset"})
69
+ resume_from_checkpoint: Optional[str] = field(
70
+ default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."})
71
+ resume_steps: Optional[int] = field(default=None, metadata={"help": "The training sterps of saved checkpoint"})
72
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
73
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
74
+ # adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
75
+ # adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
76
+ # adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
77
+ max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
78
+ gradient_accumulation_steps: int = field(
79
+ default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."})
80
+ mixed_precision: Optional[str] = field(
81
+ default='no',
82
+ metadata={
83
+ "help":
84
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=1.10.and an Nvidia Ampere GPU."
85
+ })
86
+ num_train_epochs: int = field(default=3, metadata={"help": "Total number of training epochs to perform."})
87
+ max_steps: int = field(default=-1, metadata={"help": "Total number of training steps to perform. "})
88
+ save_steps: int = field(default=10000, metadata={"help": "Number of updates steps before two checkpoint saves."})
89
+ lr_scheduler_type: str = field(default="cosine", metadata={"help": "The scheduler type to use."})
90
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
91
+ min_lr_ratio: float = field(default=0.01, metadata={"help": "Minimal learning rate ratio."})
92
+ dataloader_num_workers: int = field(default=8, metadata={"help": "The number of workers to use for data loading."})
93
+ project_name: str = field(default="IPAdapter", metadata={"help": "The name of experiment"})
94
+ expr_name: str = field(default="", metadata={"help": "The name of experiment"})
95
+
96
+
97
+ def build_dataloader(dataset_cfg, image_transform, sd_image_transform, tokenizer, dataloader_num_workers=4):
98
+ dataset = hydra.utils.instantiate(dataset_cfg,
99
+ image_transform=image_transform,
100
+ sd_image_transform=sd_image_transform,
101
+ tokenizer=tokenizer)
102
+ mp_service = MultiProcessingReadingService(num_workers=dataloader_num_workers)
103
+ dist_service = DistributedReadingService()
104
+ reading_service = SequentialReadingService(dist_service, mp_service)
105
+ dataloader = DataLoader2(dataset, reading_service=reading_service)
106
+ return dataloader
107
+
108
+
109
+ def get_metric(output):
110
+ metric = {}
111
+ for key, value in output.items():
112
+ if 'loss' in key:
113
+ metric[key] = value.item()
114
+ return metric
115
+
116
+
117
+ def merge_config(**kwargs):
118
+ config = {}
119
+ for key, value in kwargs.items():
120
+ if isinstance(value, argparse.Namespace):
121
+ config[key] = vars(value)
122
+ elif isinstance(value, DictConfig):
123
+ config[key] = OmegaConf.to_object(value)
124
+ elif is_dataclass(value):
125
+ config[key] = asdict(value)
126
+ elif isinstance(value, dict):
127
+ config[key] = value
128
+ else:
129
+ logger.error(f'key: {key}, value: {value} will not be merged.')
130
+ return config
131
+
132
+
133
+ def trainable_params(model):
134
+ count = 0
135
+ for name, param in model.named_parameters():
136
+ if param.requires_grad:
137
+ count += param.numel()
138
+ return count
139
+
140
+
141
+ def train():
142
+ parser = transformers.HfArgumentParser((ConfigPathArguments, TrainingArguments))
143
+ cfg_path, args = parser.parse_args_into_dataclasses()
144
+
145
+ project_config = ProjectConfiguration(project_dir=args.output_dir,
146
+ logging_dir=os.path.join(args.output_dir, 'logs'))
147
+
148
+ assert int(cfg_path.fsdp_plugin is not None) + int(cfg_path.deepspeed_plugin is not None) <= 1
149
+ if cfg_path.fsdp_plugin is not None:
150
+ fsdp_plugin_cfg = OmegaConf.load(cfg_path.fsdp_plugin)
151
+ fsdp_plugin = hydra.utils.instantiate(fsdp_plugin_cfg)
152
+ logger.info('Use FSDP plugin')
153
+ else:
154
+ fsdp_plugin = None
155
+
156
+ if cfg_path.deepspeed_plugin is not None:
157
+ deepspeed_plugin_cfg = OmegaConf.load(cfg_path.deepspeed_plugin)
158
+ deepspeed_plugin = hydra.utils.instantiate(deepspeed_plugin_cfg)
159
+ logger.info('Use deepspeed plugin')
160
+ else:
161
+ deepspeed_plugin = None
162
+
163
+ accelerator = Accelerator(
164
+ mixed_precision=args.mixed_precision,
165
+ log_with=['tensorboard', 'wandb'],
166
+ project_config=project_config,
167
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
168
+ step_scheduler_with_optimizer=False,
169
+ fsdp_plugin=fsdp_plugin,
170
+ deepspeed_plugin=deepspeed_plugin,
171
+ )
172
+ logger.info('Init accelerator done.')
173
+
174
+ if cfg_path.deepspeed_plugin is not None:
175
+ accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 100
176
+
177
+ os.makedirs(args.output_dir, exist_ok=True)
178
+
179
+ image_transform_cfg = OmegaConf.load(cfg_path.image_transform)
180
+ image_transform = hydra.utils.instantiate(image_transform_cfg)
181
+ sd_image_transform_cfg = OmegaConf.load(cfg_path.sd_image_transform)
182
+ sd_image_transform = hydra.utils.instantiate(sd_image_transform_cfg)
183
+
184
+ tokenizer_cfg = OmegaConf.load(cfg_path.tokenizer)
185
+ tokenizer = hydra.utils.instantiate(tokenizer_cfg)
186
+
187
+ visual_encoder_cfg = OmegaConf.load(cfg_path.visual_encoder)
188
+ visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
189
+ logger.info('Load visual encoder done.')
190
+
191
+ discrete_model_cfg = OmegaConf.load(cfg_path.discrete_model)
192
+ discrete_model = hydra.utils.instantiate(discrete_model_cfg)
193
+ logger.info('Load discrete model done.')
194
+
195
+ # noise_scheduler_cfg = OmegaConf.load(cfg_path.noise_scheduler)
196
+ # noise_scheduler = hydra.utils.instantiate(noise_scheduler_cfg)
197
+
198
+ # if cfg_path.tokenizer is not None:
199
+ # tokenizer_cfg = OmegaConf.load(cfg_path.tokenizer)
200
+ # tokenizer = hydra.utils.instantiate(tokenizer_cfg)
201
+ # else:
202
+ # tokenizer_cfg = None
203
+ # tokenizer = None
204
+
205
+ # if cfg_path.text_encoder is not None:
206
+ # text_encoder_cfg = OmegaConf.load(cfg_path.text_encoder)
207
+ # text_encoder = hydra.utils.instantiate(text_encoder_cfg)
208
+ # logger.info('Load text encoder done.')
209
+ # else:
210
+ # text_encoder_cfg = None
211
+ # text_encoder = None
212
+
213
+ # vae_cfg = OmegaConf.load(cfg_path.vae)
214
+ # vae = hydra.utils.instantiate(vae_cfg)
215
+ # logger.info('Load vae done.')
216
+
217
+ # noise_scheduler = DDPMScheduler.from_pretrained(args.diffusion_model_path, subfolder="scheduler")
218
+ # tokenizer = CLIPTokenizer.from_pretrained(args.diffusion_model_path, subfolder="tokenizer")
219
+ # text_encoder = CLIPTextModel.from_pretrained(args.diffusion_model_path, subfolder="text_encoder")
220
+ # vae = AutoencoderKL.from_pretrained(args.diffusion_model_path, subfolder="vae")
221
+ # unet = UNet2DConditionModel.from_pretrained(args.diffusion_model_path, subfolder="unet")
222
+ # print('load diffusion model done')
223
+
224
+ # noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(args.diffusion_model_path, subfolder="scheduler")
225
+ noise_scheduler = DDPMScheduler.from_pretrained(args.diffusion_model_path, subfolder="scheduler")
226
+ text_encoder = None
227
+ vae = AutoencoderKL.from_pretrained(args.diffusion_model_path, subfolder="vae")
228
+ unet = UNet2DConditionModel.from_pretrained(args.diffusion_model_path, subfolder="unet")
229
+
230
+ unet.enable_xformers_memory_efficient_attention()
231
+ unet.enable_gradient_checkpointing()
232
+
233
+ vae.requires_grad_(False)
234
+ visual_encoder.requires_grad_(False)
235
+ discrete_model.requires_grad_(False)
236
+
237
+ adapter_cfg = OmegaConf.load(cfg_path.adapter)
238
+ adapter = hydra.utils.instantiate(adapter_cfg, unet=unet)
239
+ logger.info('Load adapter done.')
240
+
241
+ weight_dtype = torch.float32
242
+ if accelerator.mixed_precision == "fp16":
243
+ weight_dtype = torch.float16
244
+ elif accelerator.mixed_precision == "bf16":
245
+ weight_dtype = torch.bfloat16
246
+
247
+ vae.to(accelerator.device, dtype=weight_dtype)
248
+ visual_encoder.to(accelerator.device, dtype=weight_dtype)
249
+ discrete_model.to(accelerator.device, dtype=weight_dtype)
250
+ if text_encoder is not None:
251
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
252
+
253
+ train_dataset_cfg = OmegaConf.load(cfg_path.train_dataset)
254
+ train_dataloader = build_dataloader(dataset_cfg=train_dataset_cfg,
255
+ image_transform=image_transform,
256
+ sd_image_transform=sd_image_transform,
257
+ tokenizer=tokenizer,
258
+ dataloader_num_workers=args.dataloader_num_workers)
259
+
260
+ llm_model_cfg = OmegaConf.load(cfg_path.llm_model)
261
+ llm_model = hydra.utils.instantiate(llm_model_cfg)
262
+ llm_model.gradient_checkpointing_enable()
263
+ llm_model.config.use_cache = False
264
+ logger.info('Load llm model done.')
265
+
266
+ agent_model_cfg = OmegaConf.load(cfg_path.agent_model)
267
+ agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm_model).to(accelerator.device, dtype=weight_dtype)
268
+ agent_model.requires_grad_(False)
269
+ agent_model.llm.base_model.model.use_kv_cache_head = False
270
+ logger.info('Load agent model done.')
271
+
272
+ if cfg_path.fsdp_plugin is not None:
273
+ adapter = accelerator.prepare(adapter)
274
+
275
+ optimizer = torch.optim.AdamW(adapter.params_to_opt(), lr=args.learning_rate, weight_decay=args.weight_decay)
276
+ logger.info('Init optimizer done.')
277
+ scheduler = get_scheduler(name=args.lr_scheduler_type,
278
+ optimizer=optimizer,
279
+ num_warmup_steps=args.warmup_steps,
280
+ num_training_steps=args.max_steps,
281
+ min_lr_ratio=args.min_lr_ratio)
282
+ # accelerator.register_for_checkpointing(scheduler)
283
+
284
+ # adapter.adapter, adapter.resampler, optimizer, scheduler = accelerator.prepare(
285
+ # adapter.adapter,
286
+ # adapter.resampler,
287
+ # optimizer,
288
+ # scheduler,
289
+ # )
290
+
291
+ # adapter, optimizer, scheduler = accelerator.prepare(
292
+ # adapter,
293
+ # optimizer,
294
+ # scheduler,
295
+ # )
296
+ if cfg_path.fsdp_plugin is not None:
297
+ optimizer, scheduler = accelerator.prepare(optimizer, scheduler)
298
+ else:
299
+ adapter, optimizer, scheduler = accelerator.prepare(adapter, optimizer, scheduler)
300
+ logger.info('Prepare accelerator done.')
301
+
302
+ # config_record = merge_config(discrete_model=discrete_model_cfg,
303
+ # visual_encoder=visual_encoder_cfg,
304
+ # text_encoder=text_encoder_cfg,
305
+ # image_transform=image_transform_cfg,
306
+ # sd_image_transform=sd_image_transform_cfg,
307
+ # tokenizer=tokenizer_cfg,
308
+ # train_dataset=train_dataset_cfg,
309
+ # vae=vae_cfg,
310
+ # adapter=adapter_cfg,
311
+ # train_args=args)
312
+ config_record = merge_config(discrete_model=discrete_model_cfg,
313
+ visual_encoder=visual_encoder_cfg,
314
+ image_transform=image_transform_cfg,
315
+ sd_image_transform=sd_image_transform_cfg,
316
+ train_dataset=train_dataset_cfg,
317
+ adapter=adapter_cfg,
318
+ train_args=args,
319
+ agent_model=agent_model_cfg,
320
+ llm_model=llm_model,
321
+ tokenizer=tokenizer_cfg)
322
+ accelerator.init_trackers(project_name=args.project_name,
323
+ init_kwargs={"wandb": {
324
+ "config": config_record,
325
+ "name": args.expr_name,
326
+ "dir": args.output_dir
327
+ }})
328
+ if args.resume_from_checkpoint is not None:
329
+ logger.info(f'Load checkpoint from {args.resume_from_checkpoint}')
330
+ accelerator.load_state(args.resume_from_checkpoint)
331
+
332
+ num_params = trainable_params(adapter)
333
+ logger.info("***** Running training *****")
334
+ logger.info(f" Total optimization steps = {args.max_steps}")
335
+ logger.info(f" Total trainable params = {num_params}")
336
+ for name, param in adapter.named_parameters():
337
+ if param.requires_grad:
338
+ print(name)
339
+ # print(f'adapter: {trainable_params(adapter.adapter)}')
340
+ # print(f'resampler: {trainable_params(adapter.resampler)}')
341
+ # Only show the progress bar once on each machine.
342
+ progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_main_process)
343
+ progress_bar.set_description("Steps")
344
+ global_step = 0
345
+ if args.resume_steps is not None:
346
+ global_step = args.resume_steps
347
+ progress_bar.update(args.resume_steps)
348
+
349
+ for epoch in range(args.num_train_epochs):
350
+ logger.info('Start new epoch')
351
+ for step, batch in enumerate(train_dataloader):
352
+ with accelerator.accumulate(adapter):
353
+ with torch.no_grad():
354
+ image_embeds = visual_encoder(batch['images'].to(accelerator.device, dtype=weight_dtype))
355
+ image_embeds = discrete_model.encode_image_embeds(image_embeds)
356
+ if text_encoder is not None:
357
+ text_embeds = text_encoder(batch['text_input_ids'].to(accelerator.device))[0]
358
+ else:
359
+ text_embeds = None
360
+ latents = vae.encode(
361
+ batch["sd_images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
362
+ latents = latents * vae.config.scaling_factor
363
+ llm_output = agent_model(input_ids=batch['input_ids'].to(accelerator.device),
364
+ attention_mask=batch['attention_mask'].to(accelerator.device),
365
+ labels=batch['labels'].to(accelerator.device),
366
+ image_embeds=image_embeds,
367
+ embeds_gen_mask=batch['embeds_gen_mask'].to(accelerator.device)
368
+ if batch['embeds_gen_mask'] is not None else None,
369
+ embeds_cmp_mask=batch['embeds_cmp_mask'].to(accelerator.device)
370
+ if batch['embeds_cmp_mask'] is not None else None,
371
+ ids_gen_mask=batch['ids_gen_mask'].to(accelerator.device),
372
+ ids_cmp_mask=batch['ids_cmp_mask'].to(accelerator.device),
373
+ return_recon_image_embeds=True)
374
+
375
+ time_ids = batch['time_ids'].to(accelerator.device)
376
+
377
+ # Sample noise that we'll add to the latents
378
+ noise = torch.randn_like(latents)
379
+ bsz = latents.shape[0]
380
+ # Sample a random timestep for each image
381
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
382
+ timesteps = timesteps.long()
383
+
384
+ # Add noise to the latents according to the noise magnitude at each timestep
385
+ # (this is the forward diffusion process)
386
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
387
+
388
+ output = adapter(noisy_latents=noisy_latents,
389
+ timesteps=timesteps,
390
+ image_embeds=llm_output['recon_image_embeds'],
391
+ text_embeds=None,
392
+ noise=noise,
393
+ time_ids=time_ids)
394
+
395
+ loss = output['total_loss']
396
+ accelerator.backward(loss)
397
+ if accelerator.sync_gradients:
398
+ accelerator.clip_grad_norm_(adapter.parameters(), max_norm=args.max_grad_norm)
399
+ optimizer.step()
400
+ scheduler.step()
401
+ optimizer.zero_grad()
402
+
403
+ if accelerator.sync_gradients:
404
+ progress_bar.update(1)
405
+ global_step += 1
406
+
407
+ if global_step % args.save_steps == 0:
408
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
409
+ accelerator.save_state(save_path)
410
+
411
+ metric = get_metric(output)
412
+ metric['lr'] = optimizer.param_groups[0]['lr']
413
+ accelerator.log(metric, step=global_step)
414
+ metric = {key: (format(value, ".6f") if isinstance(value, float) else value) for key, value in
415
+ metric.items()}
416
+
417
+ # if accelerator.is_local_main_process:
418
+ if accelerator.is_main_process:
419
+ tqdm.write(str(metric))
420
+ # print(metric)
421
+ if global_step >= args.max_steps:
422
+ break
423
+
424
+ accelerator.end_training()
425
+
426
+
427
+ if __name__ == '__main__':
428
+ train()
utils.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ handler = None
8
+
9
+
10
+ def build_logger(logger_name, logger_dir):
11
+ global handler
12
+
13
+ formatter = logging.Formatter(
14
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
15
+ datefmt="%Y-%m-%d %H:%M:%S",
16
+ )
17
+
18
+ # Set the format of root handlers
19
+ if not logging.getLogger().handlers:
20
+ logging.basicConfig(level=logging.INFO)
21
+ logging.getLogger().handlers[0].setFormatter(formatter)
22
+
23
+ # Redirect stdout and stderr to loggers
24
+ stdout_logger = logging.getLogger("stdout")
25
+ stdout_logger.setLevel(logging.INFO)
26
+ sl = StreamToLogger(stdout_logger, logging.INFO)
27
+ sys.stdout = sl
28
+
29
+ stderr_logger = logging.getLogger("stderr")
30
+ stderr_logger.setLevel(logging.ERROR)
31
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
32
+ sys.stderr = sl
33
+
34
+ # Get logger
35
+ logger = logging.getLogger(logger_name)
36
+ logger.setLevel(logging.INFO)
37
+
38
+ # Add a file handler for all loggers
39
+ if handler is None:
40
+ os.makedirs(logger_dir, exist_ok=True)
41
+ filename = os.path.join(logger_dir, logger_name + '.log')
42
+ handler = logging.handlers.TimedRotatingFileHandler(filename, when='D', utc=True)
43
+ handler.setFormatter(formatter)
44
+
45
+ for name, item in logging.root.manager.loggerDict.items():
46
+ if isinstance(item, logging.Logger):
47
+ item.addHandler(handler)
48
+
49
+ return logger
50
+
51
+
52
+ class StreamToLogger(object):
53
+ """
54
+ Fake file-like stream object that redirects writes to a logger instance.
55
+ """
56
+
57
+ def __init__(self, logger, log_level=logging.INFO):
58
+ self.terminal = sys.stdout
59
+ self.logger = logger
60
+ self.log_level = log_level
61
+ self.linebuf = ''
62
+
63
+ def __getattr__(self, attr):
64
+ return getattr(self.terminal, attr)
65
+
66
+ def write(self, buf):
67
+ temp_linebuf = self.linebuf + buf
68
+ self.linebuf = ''
69
+ for line in temp_linebuf.splitlines(True):
70
+ # From the io.TextIOWrapper docs:
71
+ # On output, if newline is None, any '\n' characters written
72
+ # are translated to the system default line separator.
73
+ # By default sys.stdout.write() expects '\n' newlines and then
74
+ # translates them so this is still cross platform.
75
+ if line[-1] == '\n':
76
+ self.logger.log(self.log_level, line.rstrip())
77
+ else:
78
+ self.linebuf += line
79
+
80
+ def flush(self):
81
+ if self.linebuf != '':
82
+ self.logger.log(self.log_level, self.linebuf.rstrip())
83
+ self.linebuf = ''