Spaces:
Runtime error
Runtime error
seedx
Browse files- .project-root +0 -0
- README.md +3 -3
- app.py +750 -0
- conversation.py +185 -0
- src/data/__init__.py +1 -0
- src/data/__pycache__/__init__.cpython-38.pyc +0 -0
- src/data/__pycache__/datapipes.cpython-38.pyc +0 -0
- src/data/dataloader_utils.py +163 -0
- src/data/datapipes.py +62 -0
- src/data/story_telling.py +634 -0
- src/eval/gpt_comparative_eval.py +249 -0
- src/eval/gpt_score_eval.py +222 -0
- src/inference/gen_george.py +270 -0
- src/inference/vis_george_sink.py +320 -0
- src/models/__init__.py +0 -0
- src/models/discrete_models.py +454 -0
- src/models/qwen_visual.py +501 -0
- src/models_clm/__init__.py +0 -0
- src/models_clm/generation.py +31 -0
- src/models_clm/modeling_llama_4_35.py +1236 -0
- src/models_clm/modeling_llama_xformer.py +992 -0
- src/models_clm/models.py +336 -0
- src/models_clm/peft_models.py +104 -0
- src/models_ipa/__init__.py +1 -0
- src/models_ipa/adapter_modules.py +920 -0
- src/models_ipa/attention_processor.py +414 -0
- src/models_ipa/ipa_utils.py +5 -0
- src/models_ipa/resampler.py +308 -0
- src/processer/tokenizer.py +8 -0
- src/processer/transforms.py +47 -0
- src/tools/reload_qwen_vit.py +14 -0
- src/train/dist_utils.py +34 -0
- src/train/schedular.py +130 -0
- src/train/train.py +291 -0
- src/train/train_clm_sft.py +347 -0
- src/train/train_sdxl_img2img_llm.py +428 -0
- utils.py +83 -0
.project-root
ADDED
File without changes
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
-
title: SEED Story
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.39.0
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
+
title: SEED Story George
|
3 |
+
emoji: 🌍
|
4 |
colorFrom: blue
|
5 |
+
colorTo: pink
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.39.0
|
8 |
app_file: app.py
|
app.py
ADDED
@@ -0,0 +1,750 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
from typing import Optional
|
6 |
+
import transformers
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
import io
|
9 |
+
import spaces
|
10 |
+
import base64
|
11 |
+
from PIL import Image
|
12 |
+
import gradio as gr
|
13 |
+
import time
|
14 |
+
import hashlib
|
15 |
+
|
16 |
+
from utils import build_logger
|
17 |
+
from conversation import conv_seed_llama2
|
18 |
+
|
19 |
+
import hydra
|
20 |
+
import pyrootutils
|
21 |
+
import torch
|
22 |
+
import re
|
23 |
+
import time
|
24 |
+
from omegaconf import OmegaConf
|
25 |
+
from flask import Flask
|
26 |
+
import json
|
27 |
+
from typing import Optional
|
28 |
+
import cv2
|
29 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler, StableDiffusionImg2ImgPipeline
|
30 |
+
|
31 |
+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
32 |
+
|
33 |
+
from src.data.any_res import process_anyres_image
|
34 |
+
|
35 |
+
BOI_TOKEN = '<img>'
|
36 |
+
BOP_TOKEN = '<patch>'
|
37 |
+
EOI_TOKEN = '</img>'
|
38 |
+
EOP_TOKEN = '</patch>'
|
39 |
+
IMG_TOKEN = '<img_{:05d}>'
|
40 |
+
|
41 |
+
IMG_FLAG = '<image>'
|
42 |
+
num_img_in_tokens = 64
|
43 |
+
num_img_out_tokens = 64
|
44 |
+
|
45 |
+
resolution_grids = ['1x1', '1x2', '1x3', '1x4', '1x5', '1x6', '1x10', '2x1', '3x1', '4x1', '5x1', '6x1', '10x1', '2x2',
|
46 |
+
'2x3', '3x2', '2x4', '4x2']
|
47 |
+
base_resolution = 448
|
48 |
+
|
49 |
+
app = Flask(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
def decode_image(encoded_image: str) -> Image:
|
53 |
+
decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
|
54 |
+
buffer = io.BytesIO(decoded_bytes)
|
55 |
+
image = Image.open(buffer)
|
56 |
+
return image
|
57 |
+
|
58 |
+
|
59 |
+
def encode_image(image: Image.Image, format: str = 'PNG') -> str:
|
60 |
+
with io.BytesIO() as buffer:
|
61 |
+
image.save(buffer, format=format)
|
62 |
+
encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
63 |
+
return encoded_image
|
64 |
+
|
65 |
+
|
66 |
+
@dataclass
|
67 |
+
class Arguments:
|
68 |
+
image_transform: Optional[str] = field(default='configs/processer/qwen_448_transform.yaml',
|
69 |
+
metadata={"help": "config path of image transform"})
|
70 |
+
tokenizer: Optional[str] = field(default='configs/tokenizer/clm_llama_tokenizer.yaml',
|
71 |
+
metadata={"help": "config path of tokenizer used to initialize tokenizer"})
|
72 |
+
llm: Optional[str] = field(default='configs/clm_models/llama2chat7b_lora.yaml', metadata={"help": "config path of llm"})
|
73 |
+
visual_encoder: Optional[str] = field(default='configs/visual_tokenzier/qwen_vitg_448.yaml',
|
74 |
+
metadata={"help": "config path of visual encoder"})
|
75 |
+
sd_adapter: Optional[str] = field(
|
76 |
+
default='configs/detokenizer/detokenizer_sdxl_qwen_vit_adapted.yaml',
|
77 |
+
metadata={"help": "config path of sd adapter"})
|
78 |
+
agent: Optional[str] = field(default='configs/clm_models/agent_7b_sft.yaml',
|
79 |
+
metadata={"help": "config path of agent model"})
|
80 |
+
diffusion_path: Optional[str] = field(default='stabilityai/stable-diffusion-xl-base-1.0',
|
81 |
+
metadata={"help": "diffusion model path"})
|
82 |
+
port: Optional[str] = field(default=80, metadata={"help": "network port"})
|
83 |
+
llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
|
84 |
+
vit_sd_device: Optional[str] = field(default='cuda:0', metadata={"help": "sd and vit device"})
|
85 |
+
dtype: Optional[str] = field(default='fp16', metadata={"help": "mix percision"})
|
86 |
+
|
87 |
+
|
88 |
+
parser = transformers.HfArgumentParser(Arguments)
|
89 |
+
args, = parser.parse_args_into_dataclasses()
|
90 |
+
|
91 |
+
|
92 |
+
class LLMService:
|
93 |
+
|
94 |
+
def __init__(self, args) -> None:
|
95 |
+
|
96 |
+
self.llm_device = args.llm_device
|
97 |
+
self.vit_sd_device = args.vit_sd_device
|
98 |
+
|
99 |
+
dtype = args.dtype
|
100 |
+
if dtype == 'fp16':
|
101 |
+
self.dtype = torch.float16
|
102 |
+
elif dtype == 'bf16':
|
103 |
+
self.dtype = torch.bfloat16
|
104 |
+
else:
|
105 |
+
raise ValueError
|
106 |
+
|
107 |
+
image_transform_cfg = OmegaConf.load(args.image_transform)
|
108 |
+
self.image_transform = hydra.utils.instantiate(image_transform_cfg)
|
109 |
+
|
110 |
+
tokenizer_cfg = OmegaConf.load(args.tokenizer)
|
111 |
+
self.tokenizer = hydra.utils.instantiate(tokenizer_cfg)
|
112 |
+
|
113 |
+
visual_encoder_cfg = OmegaConf.load(args.visual_encoder)
|
114 |
+
self.visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
|
115 |
+
self.visual_encoder.eval().to(self.vit_sd_device, dtype=self.dtype)
|
116 |
+
print('Init visual encoder done')
|
117 |
+
|
118 |
+
llm_cfg = OmegaConf.load(args.llm)
|
119 |
+
llm = hydra.utils.instantiate(llm_cfg, torch_dtype=self.dtype)
|
120 |
+
print('Init llm done.')
|
121 |
+
|
122 |
+
agent_cfg = OmegaConf.load(args.agent)
|
123 |
+
self.agent = hydra.utils.instantiate(agent_cfg, llm=llm)
|
124 |
+
|
125 |
+
self.agent.eval().to(self.llm_device, dtype=self.dtype)
|
126 |
+
print('Init agent mdoel Done')
|
127 |
+
|
128 |
+
noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.diffusion_path, subfolder="scheduler")
|
129 |
+
|
130 |
+
vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(self.vit_sd_device,
|
131 |
+
dtype=self.dtype)
|
132 |
+
|
133 |
+
unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(self.vit_sd_device,
|
134 |
+
dtype=self.dtype)
|
135 |
+
|
136 |
+
sd_adapter_cfg = OmegaConf.load(args.sd_adapter)
|
137 |
+
|
138 |
+
self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(self.vit_sd_device,
|
139 |
+
dtype=self.dtype)
|
140 |
+
|
141 |
+
# self.sd_adapter.init_pipe(vae=vae,
|
142 |
+
# scheduler=noise_scheduler,
|
143 |
+
# visual_encoder=self.visual_encoder.cpu(),
|
144 |
+
# image_transform=self.image_transform,
|
145 |
+
# discrete_model=None,
|
146 |
+
# dtype=self.dtype,
|
147 |
+
# device="cpu")
|
148 |
+
|
149 |
+
self.sd_adapter.init_pipe(vae=vae,
|
150 |
+
scheduler=noise_scheduler,
|
151 |
+
visual_encoder=self.visual_encoder,
|
152 |
+
image_transform=self.image_transform,
|
153 |
+
discrete_model=None,
|
154 |
+
dtype=self.dtype,
|
155 |
+
device=self.vit_sd_device)
|
156 |
+
|
157 |
+
print('Init sd adapter pipe done.')
|
158 |
+
|
159 |
+
self.visual_encoder.to(self.vit_sd_device, dtype=self.dtype)
|
160 |
+
|
161 |
+
model_id_or_path = "stablediffusionapi/realistic-vision-v51"
|
162 |
+
self.vae_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, safety_checker=None,
|
163 |
+
torch_dtype=torch.float16)
|
164 |
+
# self.vae_pipe = self.vae_pipe.to(self.vit_sd_device)
|
165 |
+
|
166 |
+
self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
|
167 |
+
self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
|
168 |
+
|
169 |
+
|
170 |
+
service = LLMService(args)
|
171 |
+
|
172 |
+
|
173 |
+
@spaces.GPU
|
174 |
+
def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox, force_polish):
|
175 |
+
with torch.no_grad():
|
176 |
+
text_list = text_list.split(IMG_FLAG)
|
177 |
+
top_p = 0.5
|
178 |
+
assert len(text_list) == len(image_list) + 1
|
179 |
+
|
180 |
+
image_tokens = BOI_TOKEN + ''.join(
|
181 |
+
[IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
|
182 |
+
|
183 |
+
input_images = []
|
184 |
+
if len(image_list) > 0:
|
185 |
+
image_tensor_list = []
|
186 |
+
embeds_cmp_mask = []
|
187 |
+
embeds_gen_mask = []
|
188 |
+
|
189 |
+
if service.multi_resolution:
|
190 |
+
patch_pos = []
|
191 |
+
image_patch_length = []
|
192 |
+
image_size_list = []
|
193 |
+
|
194 |
+
for idx, image_item in enumerate(image_list):
|
195 |
+
if isinstance(image_item, str):
|
196 |
+
image = decode_image(image_item)
|
197 |
+
print('after decode image size:', image.size)
|
198 |
+
input_images.append(image)
|
199 |
+
|
200 |
+
# if service.multi_resolution:
|
201 |
+
# image_size_list.append(image.size)
|
202 |
+
# print('image size:', image.size)
|
203 |
+
# image_tensor, patch_pos_tensor = process_anyres_image(image, service.image_transform,
|
204 |
+
# service.grid_pinpoints,
|
205 |
+
# service.base_resolution)
|
206 |
+
# image_tensor_list.append(image_tensor)
|
207 |
+
# patch_pos.append(patch_pos_tensor)
|
208 |
+
# image_patch_length.append(image_tensor.shape[0])
|
209 |
+
# print('image_patch_length', image_patch_length)
|
210 |
+
# embeds_cmp_mask.extend([True] * image_tensor.shape[0])
|
211 |
+
# embeds_gen_mask.extend([False] * image_tensor.shape[0])
|
212 |
+
#
|
213 |
+
# else:
|
214 |
+
image_tensor = service.image_transform(image)
|
215 |
+
image_tensor_list.append(image_tensor)
|
216 |
+
embeds_cmp_mask.append(True)
|
217 |
+
embeds_gen_mask.append(False)
|
218 |
+
else:
|
219 |
+
raise ValueError
|
220 |
+
|
221 |
+
if service.multi_resolution:
|
222 |
+
pixel_values = torch.cat(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
|
223 |
+
patch_position = torch.cat(patch_pos, dim=0)
|
224 |
+
|
225 |
+
image_tokens_list = []
|
226 |
+
for patch_length in image_patch_length:
|
227 |
+
image_tokens = ''
|
228 |
+
for _ in range(patch_length - 1):
|
229 |
+
image_tokens += BOP_TOKEN + ''.join(
|
230 |
+
IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
|
231 |
+
image_tokens += BOI_TOKEN + ''.join(
|
232 |
+
IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN
|
233 |
+
image_tokens_list.append(image_tokens)
|
234 |
+
else:
|
235 |
+
pixel_values = torch.stack(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
|
236 |
+
|
237 |
+
image_embeds = service.visual_encoder(pixel_values)
|
238 |
+
image_embeds = image_embeds.to(service.llm_device)
|
239 |
+
|
240 |
+
embeds_cmp_mask = torch.tensor(embeds_cmp_mask, dtype=torch.bool).to(service.llm_device)
|
241 |
+
embeds_gen_mask = torch.tensor(embeds_gen_mask, dtype=torch.bool).to(service.llm_device)
|
242 |
+
|
243 |
+
else:
|
244 |
+
image_embeds = None
|
245 |
+
patch_position = 0
|
246 |
+
embeds_cmp_mask = None
|
247 |
+
embeds_gen_mask = None
|
248 |
+
|
249 |
+
input_text = image_tokens.join(text_list)
|
250 |
+
|
251 |
+
print('input_text:', input_text)
|
252 |
+
input_ids = service.tokenizer.encode(input_text, add_special_tokens=False)
|
253 |
+
input_ids = [service.tokenizer.bos_token_id] + input_ids
|
254 |
+
|
255 |
+
input_ids = torch.tensor(input_ids).to(service.llm_device, dtype=torch.long)
|
256 |
+
ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
|
257 |
+
ids_gen_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
|
258 |
+
|
259 |
+
boi_indices = torch.where(input_ids == service.boi_token_id)[0].tolist()
|
260 |
+
eoi_indices = torch.where(input_ids == service.eoi_token_id)[0].tolist()
|
261 |
+
|
262 |
+
for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
|
263 |
+
ids_cmp_mask[boi_idx + 1:eoi_idx] = True
|
264 |
+
|
265 |
+
input_ids = input_ids.unsqueeze(0)
|
266 |
+
ids_cmp_mask = ids_cmp_mask.unsqueeze(0)
|
267 |
+
ids_gen_mask = ids_gen_mask.unsqueeze(0)
|
268 |
+
|
269 |
+
error_msg = []
|
270 |
+
|
271 |
+
output = service.agent.generate(
|
272 |
+
tokenizer=service.tokenizer,
|
273 |
+
input_ids=input_ids,
|
274 |
+
image_embeds=image_embeds,
|
275 |
+
embeds_cmp_mask=embeds_cmp_mask,
|
276 |
+
ids_cmp_mask=ids_cmp_mask,
|
277 |
+
num_img_gen_tokens=num_img_out_tokens,
|
278 |
+
max_new_tokens=max_new_tokens,
|
279 |
+
dtype=service.dtype,
|
280 |
+
device=service.llm_device,
|
281 |
+
top_p=top_p,
|
282 |
+
)
|
283 |
+
|
284 |
+
gen_imgs_base64_list = []
|
285 |
+
generated_text = output['text']
|
286 |
+
generated_text = generated_text.replace(EOI_TOKEN, IMG_FLAG).replace(service.tokenizer.eos_token, '')
|
287 |
+
|
288 |
+
torch.cuda.empty_cache()
|
289 |
+
|
290 |
+
if output['has_img_output']:
|
291 |
+
# print('loading visual encoder and llm to CPU, and sd to GPU')
|
292 |
+
# a = time.time()
|
293 |
+
# service.agent = service.agent.cpu()
|
294 |
+
# service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
|
295 |
+
# print("Loading finished: ", time.time() - a)
|
296 |
+
|
297 |
+
img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)
|
298 |
+
|
299 |
+
for img_idx in range(output['num_gen_imgs']):
|
300 |
+
img_feat = img_gen_feat[img_idx:img_idx + 1]
|
301 |
+
generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
|
302 |
+
|
303 |
+
if force_polish:
|
304 |
+
# service.sd_adapter = service.sd_adapter.cpu()
|
305 |
+
# service.vae_pipe = service.vae_pipe.to(service.vit_sd_device, dtype=service.dtype)
|
306 |
+
|
307 |
+
torch.cuda.empty_cache()
|
308 |
+
|
309 |
+
service.vae_pipe = service.vae_pipe.to(service.vit_sd_device)
|
310 |
+
|
311 |
+
init_image = generated_image.resize((1024, 1024))
|
312 |
+
prompt = ""
|
313 |
+
images = service.vae_pipe(prompt=prompt, image=init_image,
|
314 |
+
num_inference_steps=50, guidance_scale=8.0, strength=0.38).images
|
315 |
+
generated_image = images[0]
|
316 |
+
|
317 |
+
image_base64 = encode_image(generated_image)
|
318 |
+
gen_imgs_base64_list.append(image_base64)
|
319 |
+
|
320 |
+
# service.vae_pipe = service.vae_pipe.to("cpu")
|
321 |
+
# service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
|
322 |
+
|
323 |
+
torch.cuda.empty_cache()
|
324 |
+
|
325 |
+
# print('loading visual encoder and llm to GPU, and sd to CPU')
|
326 |
+
# a = time.time()
|
327 |
+
# service.sd_adapter = service.sd_adapter.cpu()
|
328 |
+
# service.visual_encoder = service.visual_encoder.to(service.vit_sd_device, dtype=service.dtype)
|
329 |
+
# service.agent = service.agent.to(service.vit_sd_device, dtype=service.dtype)
|
330 |
+
# print("Loading finished: ", time.time() - a)
|
331 |
+
|
332 |
+
if args.has_bbox:
|
333 |
+
bboxes = extract_box(generated_text)
|
334 |
+
if bboxes is not None and len(input_images) > 0:
|
335 |
+
image_viz = visualize_bbox(input_images[-1], bboxes)
|
336 |
+
image_base64 = encode_image(image_viz)
|
337 |
+
gen_imgs_base64_list.append(image_base64)
|
338 |
+
if '<box_start>' in generated_text:
|
339 |
+
generated_text = re.sub(r'\[\[ <box_start>.*?<box_end>.*?\]\]', 'the green bounding box',
|
340 |
+
generated_text)
|
341 |
+
else:
|
342 |
+
generated_text = re.sub(r'<loc-\d+> <loc-\d+> <loc-\d+> <loc-\d+> <box_end> \]\]',
|
343 |
+
'the green bounding box', generated_text)
|
344 |
+
generated_text += IMG_FLAG
|
345 |
+
print(input_text + generated_text)
|
346 |
+
return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg}
|
347 |
+
|
348 |
+
|
349 |
+
def http_bot(dialog_state, input_state, max_new_tokens, max_turns, force_image_gen, force_bbox, force_polish,
|
350 |
+
request: gr.Request):
|
351 |
+
print('input_state:', input_state)
|
352 |
+
|
353 |
+
if len(dialog_state.messages) == 0 or dialog_state.messages[-1]['role'] != dialog_state.roles[0] or len(
|
354 |
+
dialog_state.messages[-1]['message']['text'].strip(' ?.;!/')) == 0:
|
355 |
+
return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
|
356 |
+
|
357 |
+
if len(dialog_state.messages) > max_turns * 2:
|
358 |
+
output_state = init_input_state()
|
359 |
+
output_state['text'] = 'Error: History exceeds maximum rounds, please clear history and restart.'
|
360 |
+
dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
|
361 |
+
input_state = init_input_state()
|
362 |
+
return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 3 + (enable_btn,)
|
363 |
+
|
364 |
+
prompt = dialog_state.get_prompt()
|
365 |
+
text = prompt['text']
|
366 |
+
max_new_tokens = int(max_new_tokens)
|
367 |
+
images = prompt['images']
|
368 |
+
force_boi = force_image_gen
|
369 |
+
force_bbox = force_bbox
|
370 |
+
|
371 |
+
results = generate(text, images, max_new_tokens, force_boi, force_bbox, force_polish)
|
372 |
+
print('response: ', {'text': results['text'], 'error_msg': results['error_msg']})
|
373 |
+
|
374 |
+
output_state = init_input_state()
|
375 |
+
image_dir = get_conv_image_dir()
|
376 |
+
output_state['text'] = results['text']
|
377 |
+
|
378 |
+
for image_base64 in results['images']:
|
379 |
+
if image_base64 == '':
|
380 |
+
image_path = ''
|
381 |
+
else:
|
382 |
+
image = decode_image(image_base64)
|
383 |
+
image = image.convert('RGB')
|
384 |
+
image_path = get_image_name(image=image, image_dir=image_dir)
|
385 |
+
if not os.path.exists(image_path):
|
386 |
+
image.save(image_path)
|
387 |
+
output_state['images'].append(image_path)
|
388 |
+
|
389 |
+
dialog_state.messages.append({'role': dialog_state.roles[1], 'message': output_state})
|
390 |
+
|
391 |
+
vote_last_response(dialog_state, 'common', request)
|
392 |
+
input_state = init_input_state()
|
393 |
+
chatbot = update_error_msg(dialog_state.to_gradio_chatbot(), results['error_msg'])
|
394 |
+
return (dialog_state, input_state, chatbot) + (enable_btn,) * 4
|
395 |
+
|
396 |
+
|
397 |
+
IMG_FLAG = '<image>'
|
398 |
+
LOGDIR = 'log'
|
399 |
+
|
400 |
+
logger = build_logger("gradio_seed_x", LOGDIR)
|
401 |
+
headers = {"User-Agent": "SEED-X Client"}
|
402 |
+
|
403 |
+
no_change_btn = gr.Button()
|
404 |
+
enable_btn = gr.Button(interactive=True)
|
405 |
+
disable_btn = gr.Button(interactive=False)
|
406 |
+
|
407 |
+
conv_seed_llama = conv_seed_llama2
|
408 |
+
|
409 |
+
|
410 |
+
def get_conv_log_filename():
|
411 |
+
t = datetime.datetime.now()
|
412 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
413 |
+
return name
|
414 |
+
|
415 |
+
|
416 |
+
def get_conv_image_dir():
|
417 |
+
name = os.path.join(LOGDIR, 'images')
|
418 |
+
os.makedirs(name, exist_ok=True)
|
419 |
+
return name
|
420 |
+
|
421 |
+
|
422 |
+
def get_image_name(image, image_dir=None):
|
423 |
+
buffer = io.BytesIO()
|
424 |
+
image.save(buffer, format='PNG')
|
425 |
+
image_bytes = buffer.getvalue()
|
426 |
+
md5 = hashlib.md5(image_bytes).hexdigest()
|
427 |
+
|
428 |
+
if image_dir is not None:
|
429 |
+
image_name = os.path.join(image_dir, md5 + '.png')
|
430 |
+
else:
|
431 |
+
image_name = md5 + '.png'
|
432 |
+
|
433 |
+
return image_name
|
434 |
+
|
435 |
+
|
436 |
+
def resize_image_square(image, target_size=448):
|
437 |
+
resized_image = image.resize((target_size, target_size))
|
438 |
+
return resized_image
|
439 |
+
|
440 |
+
|
441 |
+
def resize_image(image, max_size=512):
|
442 |
+
width, height = image.size
|
443 |
+
aspect_ratio = float(width) / float(height)
|
444 |
+
|
445 |
+
if width > height:
|
446 |
+
new_width = max_size
|
447 |
+
new_height = int(new_width / aspect_ratio)
|
448 |
+
else:
|
449 |
+
new_height = max_size
|
450 |
+
new_width = int(new_height * aspect_ratio)
|
451 |
+
|
452 |
+
resized_image = image.resize((new_width, new_height))
|
453 |
+
return resized_image
|
454 |
+
|
455 |
+
|
456 |
+
def center_crop_image(image, max_aspect_ratio=1.5):
|
457 |
+
width, height = image.size
|
458 |
+
aspect_ratio = max(width, height) / min(width, height)
|
459 |
+
|
460 |
+
if aspect_ratio >= max_aspect_ratio:
|
461 |
+
if width > height:
|
462 |
+
new_width = int(height * max_aspect_ratio)
|
463 |
+
left = (width - new_width) // 2
|
464 |
+
right = (width + new_width) // 2
|
465 |
+
top = 0
|
466 |
+
bottom = height
|
467 |
+
else:
|
468 |
+
new_height = int(width * max_aspect_ratio)
|
469 |
+
left = 0
|
470 |
+
right = width
|
471 |
+
top = (height - new_height) // 2
|
472 |
+
bottom = (height + new_height) // 2
|
473 |
+
|
474 |
+
cropped_image = image.crop((left, top, right, bottom))
|
475 |
+
return cropped_image
|
476 |
+
else:
|
477 |
+
return image
|
478 |
+
|
479 |
+
|
480 |
+
def vote_last_response(state, vote_type, request: gr.Request):
|
481 |
+
with open(get_conv_log_filename(), "a") as fout:
|
482 |
+
data = {
|
483 |
+
"tstamp": round(time.time(), 4),
|
484 |
+
"type": vote_type,
|
485 |
+
"state": state.dict(),
|
486 |
+
"ip": request.client.host,
|
487 |
+
}
|
488 |
+
fout.write(json.dumps(data) + "\n")
|
489 |
+
|
490 |
+
|
491 |
+
def upvote_last_response(state, request: gr.Request):
|
492 |
+
logger.info(f"upvote. ip: {request.client.host}")
|
493 |
+
vote_last_response(state, "upvote", request)
|
494 |
+
return (disable_btn,) * 2
|
495 |
+
|
496 |
+
|
497 |
+
def downvote_last_response(state, request: gr.Request):
|
498 |
+
logger.info(f"downvote. ip: {request.client.host}")
|
499 |
+
vote_last_response(state, "downvote", request)
|
500 |
+
return (disable_btn,) * 2
|
501 |
+
|
502 |
+
|
503 |
+
def regenerate(dialog_state, request: gr.Request):
|
504 |
+
logger.info(f"regenerate. ip: {request.client.host}")
|
505 |
+
if dialog_state.messages[-1]['role'] == dialog_state.roles[1]:
|
506 |
+
dialog_state.messages.pop()
|
507 |
+
return (
|
508 |
+
dialog_state,
|
509 |
+
dialog_state.to_gradio_chatbot(),
|
510 |
+
) + (disable_btn,) * 4
|
511 |
+
|
512 |
+
|
513 |
+
def clear_history(request: gr.Request):
|
514 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
515 |
+
dialog_state = conv_seed_llama.copy()
|
516 |
+
input_state = init_input_state()
|
517 |
+
return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
|
518 |
+
|
519 |
+
|
520 |
+
def init_input_state():
|
521 |
+
return {'images': [], 'text': ''}
|
522 |
+
|
523 |
+
|
524 |
+
def add_text(dialog_state, input_state, text, request: gr.Request):
|
525 |
+
logger.info(f"add_text. ip: {request.client.host}.")
|
526 |
+
if text is None or len(text) == 0:
|
527 |
+
return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
|
528 |
+
input_state['text'] += text
|
529 |
+
|
530 |
+
if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
|
531 |
+
dialog_state.messages[-1]['message'] = input_state
|
532 |
+
else:
|
533 |
+
dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
|
534 |
+
print('add_text: ', dialog_state.to_gradio_chatbot())
|
535 |
+
|
536 |
+
return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
|
537 |
+
|
538 |
+
|
539 |
+
def is_blank(image):
|
540 |
+
image_array = np.array(image)
|
541 |
+
unique_colors = np.unique(image_array)
|
542 |
+
print('unique_colors', len(unique_colors))
|
543 |
+
return len(unique_colors) == 1
|
544 |
+
|
545 |
+
|
546 |
+
def add_image(dialog_state, input_state, image, request: gr.Request):
|
547 |
+
logger.info(f"add_image. ip: {request.client.host}.")
|
548 |
+
if image is None:
|
549 |
+
return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4
|
550 |
+
|
551 |
+
image = image.convert('RGB')
|
552 |
+
|
553 |
+
print('image size:', image.size)
|
554 |
+
|
555 |
+
image = center_crop_image(image, max_aspect_ratio=10)
|
556 |
+
|
557 |
+
image_dir = get_conv_image_dir()
|
558 |
+
image_path = get_image_name(image=image, image_dir=image_dir)
|
559 |
+
if not os.path.exists(image_path):
|
560 |
+
image.save(image_path)
|
561 |
+
input_state['images'].append(image_path)
|
562 |
+
input_state['text'] += IMG_FLAG
|
563 |
+
|
564 |
+
if len(dialog_state.messages) > 0 and dialog_state.messages[-1]['role'] == dialog_state.roles[0]:
|
565 |
+
dialog_state.messages[-1]['message'] = input_state
|
566 |
+
else:
|
567 |
+
dialog_state.messages.append({'role': dialog_state.roles[0], 'message': input_state})
|
568 |
+
|
569 |
+
print('add_image:', dialog_state)
|
570 |
+
|
571 |
+
return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
|
572 |
+
|
573 |
+
|
574 |
+
def update_error_msg(chatbot, error_msg):
|
575 |
+
if len(error_msg) > 0:
|
576 |
+
info = '\n-------------\nSome errors occurred during response, please clear history and restart.\n' + '\n'.join(
|
577 |
+
error_msg)
|
578 |
+
chatbot[-1][-1] = chatbot[-1][-1] + info
|
579 |
+
|
580 |
+
return chatbot
|
581 |
+
|
582 |
+
|
583 |
+
def load_demo(request: gr.Request):
|
584 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
585 |
+
dialog_state = conv_seed_llama.copy()
|
586 |
+
input_state = init_input_state()
|
587 |
+
return dialog_state, input_state
|
588 |
+
|
589 |
+
|
590 |
+
title = ("""
|
591 |
+
# SEED-X-I
|
592 |
+
[[Paper]](https://arxiv.org/abs/2404.14396) [[Code]](https://github.com/AILab-CVC/SEED-X) [[Faster Demo]](https://arc.tencent.com/en/ai-demos/multimodal)
|
593 |
+
|
594 |
+
Demo of a general instruction-tuned model SEED-X-I (17B) from the foundation model SEED-X.
|
595 |
+
SEED-X-I can follow multimodal instruction (including images with **dynamic resolutions**) and make responses with **images, texts and bounding boxes** in multi-turn conversation.
|
596 |
+
|
597 |
+
SEED-X-I **does not support image manipulation**. If you want to experience **SEED-X-Edit** for high-precision image editing, please refer to [[Inference Code]](https://github.com/AILab-CVC/SEED-X).
|
598 |
+
|
599 |
+
If you want to experience the normal model inference speed, you can use [[Faster Demo]](https://arc.tencent.com/en/ai-demos/multimodal) or run [[Inference Code]](https://github.com/AILab-CVC/SEED-X) locally.
|
600 |
+
|
601 |
+
## Tips:
|
602 |
+
* Check out the conversation examples (at the bottom) for inspiration.
|
603 |
+
* You can adjust "Max History Rounds" to try a conversation with up to **three rounds due to insufficient GPU memory**. For more turns, you can download our checkpoints from GitHub and deploy them locally for inference.
|
604 |
+
* Our demo supports a mix of images and texts as input. You can freely upload an image or enter text, and then click on "Add Image/Text". You can repeat the former step multiple times, and click on "Submit" for model inference at last.
|
605 |
+
* You can click "Force Image Generation" to compel the model to produce images when necessary. For example, our model might struggle to generate images when there is an excessive amount of text-only context.
|
606 |
+
* You can click "Force Bounding Box" to compel the model to produce bounding box for object detection.
|
607 |
+
* You can click "Force Polishing Generated Image" to compel the model to polish the generated image with image post-processing.
|
608 |
+
|
609 |
+
* SEED-X was trained with English-only data. It may process with other languages due to the inherent capabilities from LLaMA, but might not stable.
|
610 |
+
""")
|
611 |
+
|
612 |
+
css = """
|
613 |
+
img {
|
614 |
+
font-family: 'Helvetica';
|
615 |
+
font-weight: 300;
|
616 |
+
line-height: 2;
|
617 |
+
text-align: center;
|
618 |
+
|
619 |
+
width: auto;
|
620 |
+
height: auto;
|
621 |
+
display: block;
|
622 |
+
position: relative;
|
623 |
+
}
|
624 |
+
img:before {
|
625 |
+
content: " ";
|
626 |
+
display: block;
|
627 |
+
position: absolute;
|
628 |
+
top: -10px;
|
629 |
+
left: 0;
|
630 |
+
height: calc(100% + 10px);
|
631 |
+
width: 100%;
|
632 |
+
background-color: rgb(230, 230, 230);
|
633 |
+
border: 2px dotted rgb(200, 200, 200);
|
634 |
+
border-radius: 5px;
|
635 |
+
}
|
636 |
+
img:after {
|
637 |
+
content: " ";
|
638 |
+
display: block;
|
639 |
+
font-size: 16px;
|
640 |
+
font-style: normal;
|
641 |
+
font-family: FontAwesome;
|
642 |
+
color: rgb(100, 100, 100);
|
643 |
+
|
644 |
+
position: absolute;
|
645 |
+
top: 5px;
|
646 |
+
left: 0;
|
647 |
+
width: 100%;
|
648 |
+
text-align: center;
|
649 |
+
}
|
650 |
+
"""
|
651 |
+
|
652 |
+
if __name__ == '__main__':
|
653 |
+
examples_mix = [
|
654 |
+
['https://github.com/AILab-CVC/SEED-X/blob/main/demos/bank.png?raw=true',
|
655 |
+
'Can I conntect with an advisor on Sunday?'],
|
656 |
+
['https://github.com/AILab-CVC/SEED-X/blob/main/demos/ground.png?raw=true',
|
657 |
+
'Is there anything in the image that can protect me from catching the flu virus when I go out? Show me the location.'],
|
658 |
+
['https://github.com/AILab-CVC/SEED-X/blob/main/demos/arrow.jpg?raw=true',
|
659 |
+
'What is the object pointed by the red arrow?'],
|
660 |
+
['https://github.com/AILab-CVC/SEED-X/blob/main/demos/shanghai.png?raw=true',
|
661 |
+
'Where was this image taken? Explain your answer.'],
|
662 |
+
['https://github.com/AILab-CVC/SEED-X/blob/main/demos/GPT4.png?raw=true',
|
663 |
+
'How long does it take to make GPT-4 safer?'],
|
664 |
+
['https://github.com/AILab-CVC/SEED-X/blob/main/demos/twitter.png?raw=true',
|
665 |
+
'Please provide a comprehensive description of this image.'],
|
666 |
+
]
|
667 |
+
examples_text = [
|
668 |
+
['I want to build a two story cabin in the woods, with many commanding windows. Can you show me a picture?'],
|
669 |
+
['Use your imagination to design a concept image for Artificial General Intelligence (AGI). Show me an image.'],
|
670 |
+
[
|
671 |
+
'Can you design an illustration for “The Three-Body Problem” to depict a scene from the novel? Show me a picture.'],
|
672 |
+
[
|
673 |
+
'My four year old son loves toy trains. Can you design a fancy birthday cake for him? Please generate a picture.'],
|
674 |
+
[
|
675 |
+
'Generate an image of a portrait of young nordic girl, age 25, freckled skin, neck tatoo, blue eyes 35mm lens, photography, ultra details.'],
|
676 |
+
['Generate an impressionist painting of an astronaut in a jungle.']
|
677 |
+
]
|
678 |
+
with gr.Blocks(css=css) as demo:
|
679 |
+
gr.Markdown(title)
|
680 |
+
dialog_state = gr.State()
|
681 |
+
input_state = gr.State()
|
682 |
+
with gr.Row():
|
683 |
+
with gr.Column(scale=3):
|
684 |
+
with gr.Row():
|
685 |
+
image = gr.Image(type='pil', label='input_image')
|
686 |
+
with gr.Row():
|
687 |
+
text = gr.Textbox(lines=5,
|
688 |
+
show_label=False,
|
689 |
+
label='input_text',
|
690 |
+
elem_id='textbox',
|
691 |
+
placeholder="Enter text and image, and press submit,", container=False)
|
692 |
+
with gr.Row():
|
693 |
+
add_image_btn = gr.Button("Add Image")
|
694 |
+
add_text_btn = gr.Button("Add Text")
|
695 |
+
|
696 |
+
submit_btn = gr.Button("Submit")
|
697 |
+
|
698 |
+
with gr.Row():
|
699 |
+
max_new_tokens = gr.Slider(minimum=64,
|
700 |
+
maximum=1024,
|
701 |
+
value=768,
|
702 |
+
step=64,
|
703 |
+
interactive=True,
|
704 |
+
label="Max Output Tokens")
|
705 |
+
max_turns = gr.Slider(minimum=1, maximum=3, value=3, step=1, interactive=True,
|
706 |
+
label="Max History Rounds")
|
707 |
+
force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
|
708 |
+
force_bbox = gr.Radio(choices=[True, False], value=False, label='Force Bounding Box')
|
709 |
+
force_polish = gr.Radio(choices=[True, False], value=True, label='Force Polishing Generated Image')
|
710 |
+
|
711 |
+
with gr.Column(scale=7):
|
712 |
+
chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-X-I", height=700)
|
713 |
+
with gr.Row():
|
714 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
715 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
716 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
717 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
|
718 |
+
|
719 |
+
with gr.Row():
|
720 |
+
with gr.Column(scale=0.7):
|
721 |
+
gr.Examples(examples=examples_mix, label='Input examples', inputs=[image, text], cache_examples=False)
|
722 |
+
with gr.Column(scale=0.3):
|
723 |
+
gr.Examples(examples=examples_text, label='Input examples', inputs=[text], cache_examples=False)
|
724 |
+
|
725 |
+
# Register listeners
|
726 |
+
btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn]
|
727 |
+
upvote_btn.click(upvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
|
728 |
+
downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
|
729 |
+
|
730 |
+
regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
|
731 |
+
http_bot, [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox, force_polish],
|
732 |
+
[dialog_state, input_state, chatbot] + btn_list)
|
733 |
+
add_image_btn.click(add_image, [dialog_state, input_state, image],
|
734 |
+
[dialog_state, input_state, image, chatbot] + btn_list)
|
735 |
+
|
736 |
+
add_text_btn.click(add_text, [dialog_state, input_state, text],
|
737 |
+
[dialog_state, input_state, text, chatbot] + btn_list)
|
738 |
+
|
739 |
+
submit_btn.click(
|
740 |
+
add_image, [dialog_state, input_state, image], [dialog_state, input_state, image, chatbot] + btn_list).then(
|
741 |
+
add_text, [dialog_state, input_state, text],
|
742 |
+
[dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
|
743 |
+
http_bot,
|
744 |
+
[dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox, force_polish],
|
745 |
+
[dialog_state, input_state, chatbot] + btn_list)
|
746 |
+
clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)
|
747 |
+
|
748 |
+
demo.load(load_demo, None, [dialog_state, input_state])
|
749 |
+
|
750 |
+
demo.launch(debug=True)
|
conversation.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
import io
|
6 |
+
import base64
|
7 |
+
import os
|
8 |
+
from PIL import Image
|
9 |
+
import copy
|
10 |
+
|
11 |
+
IMG_FLAG = '<image>'
|
12 |
+
|
13 |
+
|
14 |
+
class SeparatorStyle(Enum):
|
15 |
+
"""Different separator style."""
|
16 |
+
SINGLE = auto()
|
17 |
+
TWO = auto()
|
18 |
+
MPT = auto()
|
19 |
+
PLAIN = auto()
|
20 |
+
LLAMA_2 = auto()
|
21 |
+
|
22 |
+
|
23 |
+
def decode_image(encoded_image: str) -> Image:
|
24 |
+
decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
|
25 |
+
buffer = io.BytesIO(decoded_bytes)
|
26 |
+
image = Image.open(buffer)
|
27 |
+
return image
|
28 |
+
|
29 |
+
|
30 |
+
def encode_image(image: Image.Image, format: str = 'PNG') -> str:
|
31 |
+
with io.BytesIO() as buffer:
|
32 |
+
image.save(buffer, format=format)
|
33 |
+
encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
34 |
+
return encoded_image
|
35 |
+
|
36 |
+
|
37 |
+
@dataclasses.dataclass
|
38 |
+
class Conversation:
|
39 |
+
"""A class that keeps all conversation history."""
|
40 |
+
system: str
|
41 |
+
roles: List[str]
|
42 |
+
messages: List[dict] # multi-turn -> user & assistant -> {'images': [PIL.Image,], 'text': str}
|
43 |
+
offset: int
|
44 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
45 |
+
sep: str = "###"
|
46 |
+
sep2: str = None
|
47 |
+
version: str = "Unknown"
|
48 |
+
|
49 |
+
skip_next: bool = False
|
50 |
+
|
51 |
+
def get_prompt(self):
|
52 |
+
messages = copy.deepcopy(self.messages)
|
53 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
54 |
+
if self.system is None or self.system == '':
|
55 |
+
text = ''
|
56 |
+
else:
|
57 |
+
text = self.system + self.sep
|
58 |
+
images = []
|
59 |
+
for message in messages:
|
60 |
+
text += message['role'] + ": " + message['message']['text'] + self.sep
|
61 |
+
for image_path in message['message']['images']:
|
62 |
+
image = Image.open(image_path).resize((256, 256))
|
63 |
+
image_base64 = encode_image(image)
|
64 |
+
images.append(image_base64)
|
65 |
+
|
66 |
+
text += self.roles[1] + ":"
|
67 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
68 |
+
b_token = "[INST] "
|
69 |
+
e_token = " [/INST]"
|
70 |
+
if self.system is None or self.system == '':
|
71 |
+
text = ''
|
72 |
+
else:
|
73 |
+
text = f"<<SYS>>\n{self.system}\n<</SYS>>\n\n"
|
74 |
+
images = []
|
75 |
+
for idx, message in enumerate(messages):
|
76 |
+
# text += message['role'] + ": " + message['message']['text'] + self.sep
|
77 |
+
if idx % 2 == 0:
|
78 |
+
text += b_token + message['message']['text'] + e_token + self.sep
|
79 |
+
else:
|
80 |
+
text += message['message']['text'] + self.sep
|
81 |
+
|
82 |
+
for image_path in message['message']['images']:
|
83 |
+
image = Image.open(image_path)
|
84 |
+
image_base64 = encode_image(image)
|
85 |
+
images.append(image_base64)
|
86 |
+
else:
|
87 |
+
raise NotImplementedError
|
88 |
+
|
89 |
+
return {'text': text, 'images': images}
|
90 |
+
|
91 |
+
# def update_image_ids(self, images_ids):
|
92 |
+
# image_count = 0
|
93 |
+
# for message in self.messages:
|
94 |
+
# for idx in range(len(message['message']['images_ids'])):
|
95 |
+
# if message['message']["images_ids"][idx] is None:
|
96 |
+
# message['message']["images_ids"][idx] = images_ids[image_count]
|
97 |
+
# image_count += 1
|
98 |
+
|
99 |
+
# assert len(images_ids) == image_count, print(len(images_ids), image_count)
|
100 |
+
|
101 |
+
def append_message(self, role, message):
|
102 |
+
self.messages.append([role, message])
|
103 |
+
|
104 |
+
def to_gradio_chatbot(self):
|
105 |
+
dialog = []
|
106 |
+
for i, single_turn in enumerate(self.messages[self.offset:]):
|
107 |
+
single_turn = single_turn['message']
|
108 |
+
text_list = single_turn['text'].split(IMG_FLAG)
|
109 |
+
assert len(text_list) == len(single_turn['images']) + 1, print(text_list, len(single_turn['images']))
|
110 |
+
message = ''
|
111 |
+
for image_idx in range(len(single_turn['images'])):
|
112 |
+
image_path = single_turn['images'][image_idx]
|
113 |
+
image = Image.open(image_path)
|
114 |
+
image_base64 = encode_image(image)
|
115 |
+
image_str = f'<img src="data:image/png;base64,{image_base64}" alt="user upload image" />'
|
116 |
+
message += text_list[image_idx] + image_str
|
117 |
+
|
118 |
+
# image_path = single_turn['images'][image_idx]
|
119 |
+
# if image_path == '':
|
120 |
+
# message += text_list[image_idx] + '<corrupt_image>'
|
121 |
+
# else:
|
122 |
+
# message += text_list[image_idx] + f'![](file={image_path})'
|
123 |
+
message += text_list[-1]
|
124 |
+
|
125 |
+
if i % 2 == 0:
|
126 |
+
dialog.append([message, None])
|
127 |
+
else:
|
128 |
+
dialog[-1][-1] = message
|
129 |
+
|
130 |
+
return dialog
|
131 |
+
|
132 |
+
def copy(self):
|
133 |
+
return Conversation(system=self.system,
|
134 |
+
roles=self.roles,
|
135 |
+
messages=copy.deepcopy(self.messages),
|
136 |
+
offset=self.offset,
|
137 |
+
sep_style=self.sep_style,
|
138 |
+
sep=self.sep,
|
139 |
+
sep2=self.sep2,
|
140 |
+
version=self.version)
|
141 |
+
|
142 |
+
def dict(self):
|
143 |
+
messages = copy.deepcopy(self.messages)
|
144 |
+
for message in messages:
|
145 |
+
for i in range(len(message['message']['images'])):
|
146 |
+
message['message']['images'][i] = os.path.basename(message['message']['images'][i])
|
147 |
+
return {
|
148 |
+
"system": self.system,
|
149 |
+
"roles": self.roles,
|
150 |
+
"messages": messages,
|
151 |
+
"offset": self.offset,
|
152 |
+
"sep": self.sep,
|
153 |
+
"sep2": self.sep2,
|
154 |
+
}
|
155 |
+
|
156 |
+
|
157 |
+
conv_seed_vicuna = Conversation(
|
158 |
+
system="",
|
159 |
+
roles=("USER", "ASSISTANT"),
|
160 |
+
version="v2",
|
161 |
+
messages=[],
|
162 |
+
offset=0,
|
163 |
+
sep_style=SeparatorStyle.SINGLE,
|
164 |
+
sep='\n',
|
165 |
+
)
|
166 |
+
|
167 |
+
conv_seed_vicuna_system = Conversation(
|
168 |
+
system="A chat between a curious user and an artificial intelligence assistant. ",
|
169 |
+
roles=("USER", "ASSISTANT"),
|
170 |
+
version="v2",
|
171 |
+
messages=[],
|
172 |
+
offset=0,
|
173 |
+
sep_style=SeparatorStyle.SINGLE,
|
174 |
+
sep='\n',
|
175 |
+
)
|
176 |
+
|
177 |
+
conv_seed_llama2 = Conversation(
|
178 |
+
system="",
|
179 |
+
roles=("[INST]", "[/INST]"),
|
180 |
+
version="v2",
|
181 |
+
messages=[],
|
182 |
+
offset=0,
|
183 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
184 |
+
sep='\n',
|
185 |
+
)
|
src/data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .datapipes import TarArchiveLoader, JsonlParserIterDataPipe
|
src/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (226 Bytes). View file
|
|
src/data/__pycache__/datapipes.cpython-38.pyc
ADDED
Binary file (2.86 kB). View file
|
|
src/data/dataloader_utils.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import time
|
9 |
+
import random
|
10 |
+
import torch
|
11 |
+
# from lavis.datasets.data_utils import move_to_cuda
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
|
14 |
+
|
15 |
+
class MultiIterLoader:
|
16 |
+
"""
|
17 |
+
A simple wrapper for iterating over multiple iterators.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
loaders (List[Loader]): List of Iterator loaders.
|
21 |
+
ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, loaders, ratios=None):
|
25 |
+
# assert all loaders has __next__ method
|
26 |
+
for loader in loaders:
|
27 |
+
assert hasattr(loader, "__next__"), "Loader {} has no __next__ method.".format(loader)
|
28 |
+
|
29 |
+
if ratios is None:
|
30 |
+
ratios = [1.0] * len(loaders)
|
31 |
+
else:
|
32 |
+
assert len(ratios) == len(loaders)
|
33 |
+
ratios = [float(ratio) / sum(ratios) for ratio in ratios]
|
34 |
+
|
35 |
+
self.loaders = loaders
|
36 |
+
self.ratios = ratios
|
37 |
+
|
38 |
+
def __next__(self):
|
39 |
+
# random sample from each loader by ratio
|
40 |
+
loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
|
41 |
+
return next(self.loaders[loader_idx])
|
42 |
+
|
43 |
+
def __iter__(self):
|
44 |
+
return self
|
45 |
+
|
46 |
+
|
47 |
+
class PrefetchLoader(object):
|
48 |
+
"""
|
49 |
+
Modified from https://github.com/ChenRocks/UNITER.
|
50 |
+
|
51 |
+
overlap compute and cuda data transfer
|
52 |
+
(copied and then modified from nvidia apex)
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self, loader):
|
56 |
+
self.loader = loader
|
57 |
+
self.stream = torch.cuda.Stream()
|
58 |
+
|
59 |
+
def __iter__(self):
|
60 |
+
loader_it = iter(self.loader)
|
61 |
+
self.preload(loader_it)
|
62 |
+
batch = self.next(loader_it)
|
63 |
+
while batch is not None:
|
64 |
+
is_tuple = isinstance(batch, tuple)
|
65 |
+
if is_tuple:
|
66 |
+
task, batch = batch
|
67 |
+
|
68 |
+
if is_tuple:
|
69 |
+
yield task, batch
|
70 |
+
else:
|
71 |
+
yield batch
|
72 |
+
batch = self.next(loader_it)
|
73 |
+
|
74 |
+
def __len__(self):
|
75 |
+
return len(self.loader)
|
76 |
+
|
77 |
+
def preload(self, it):
|
78 |
+
try:
|
79 |
+
self.batch = next(it)
|
80 |
+
except StopIteration:
|
81 |
+
self.batch = None
|
82 |
+
return
|
83 |
+
# if record_stream() doesn't work, another option is to make sure
|
84 |
+
# device inputs are created on the main stream.
|
85 |
+
# self.next_input_gpu = torch.empty_like(self.next_input,
|
86 |
+
# device='cuda')
|
87 |
+
# self.next_target_gpu = torch.empty_like(self.next_target,
|
88 |
+
# device='cuda')
|
89 |
+
# Need to make sure the memory allocated for next_* is not still in use
|
90 |
+
# by the main stream at the time we start copying to next_*:
|
91 |
+
# self.stream.wait_stream(torch.cuda.current_stream())
|
92 |
+
# with torch.cuda.stream(self.stream):
|
93 |
+
# self.batch = move_to_cuda(self.batch)
|
94 |
+
# more code for the alternative if record_stream() doesn't work:
|
95 |
+
# copy_ will record the use of the pinned source tensor in this
|
96 |
+
# side stream.
|
97 |
+
# self.next_input_gpu.copy_(self.next_input, non_blocking=True)
|
98 |
+
# self.next_target_gpu.copy_(self.next_target, non_blocking=True)
|
99 |
+
# self.next_input = self.next_input_gpu
|
100 |
+
# self.next_target = self.next_target_gpu
|
101 |
+
|
102 |
+
def next(self, it):
|
103 |
+
torch.cuda.current_stream().wait_stream(self.stream)
|
104 |
+
batch = self.batch
|
105 |
+
if batch is not None:
|
106 |
+
record_cuda_stream(batch)
|
107 |
+
self.preload(it)
|
108 |
+
return batch
|
109 |
+
|
110 |
+
def __getattr__(self, name):
|
111 |
+
method = self.loader.__getattribute__(name)
|
112 |
+
return method
|
113 |
+
|
114 |
+
|
115 |
+
def record_cuda_stream(batch):
|
116 |
+
if isinstance(batch, torch.Tensor):
|
117 |
+
batch.record_stream(torch.cuda.current_stream())
|
118 |
+
elif isinstance(batch, list) or isinstance(batch, tuple):
|
119 |
+
for t in batch:
|
120 |
+
record_cuda_stream(t)
|
121 |
+
elif isinstance(batch, dict):
|
122 |
+
for t in batch.values():
|
123 |
+
record_cuda_stream(t)
|
124 |
+
else:
|
125 |
+
pass
|
126 |
+
|
127 |
+
|
128 |
+
class IterLoader:
|
129 |
+
"""
|
130 |
+
A wrapper to convert DataLoader as an infinite iterator.
|
131 |
+
|
132 |
+
Modified from:
|
133 |
+
https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
|
134 |
+
"""
|
135 |
+
|
136 |
+
def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
|
137 |
+
self._dataloader = dataloader
|
138 |
+
self.iter_loader = iter(self._dataloader)
|
139 |
+
self._use_distributed = use_distributed
|
140 |
+
self._epoch = 0
|
141 |
+
|
142 |
+
@property
|
143 |
+
def epoch(self) -> int:
|
144 |
+
return self._epoch
|
145 |
+
|
146 |
+
def __next__(self):
|
147 |
+
try:
|
148 |
+
data = next(self.iter_loader)
|
149 |
+
except StopIteration:
|
150 |
+
self._epoch += 1
|
151 |
+
if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
|
152 |
+
self._dataloader.sampler.set_epoch(self._epoch)
|
153 |
+
time.sleep(2) # Prevent possible deadlock during epoch transition
|
154 |
+
self.iter_loader = iter(self._dataloader)
|
155 |
+
data = next(self.iter_loader)
|
156 |
+
|
157 |
+
return data
|
158 |
+
|
159 |
+
def __iter__(self):
|
160 |
+
return self
|
161 |
+
|
162 |
+
def __len__(self):
|
163 |
+
return len(self._dataloader)
|
src/data/datapipes.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchdata.datapipes as dp
|
2 |
+
import os
|
3 |
+
import tarfile
|
4 |
+
from torchdata.datapipes.iter import TarArchiveLoader
|
5 |
+
from typing import cast, IO, Iterable, Iterator, Optional, Tuple, Dict
|
6 |
+
from torchdata.datapipes import functional_datapipe
|
7 |
+
from io import BufferedIOBase
|
8 |
+
from torchdata.datapipes.utils import StreamWrapper
|
9 |
+
from torchdata.datapipes.utils.common import validate_pathname_binary_tuple
|
10 |
+
import warnings
|
11 |
+
from torchdata.datapipes.iter import IterDataPipe
|
12 |
+
import json
|
13 |
+
|
14 |
+
|
15 |
+
@functional_datapipe("load_from_tar_wo_exception")
|
16 |
+
class TarArchiveLoaderWoException(TarArchiveLoader):
|
17 |
+
|
18 |
+
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
|
19 |
+
for data in self.datapipe:
|
20 |
+
validate_pathname_binary_tuple(data)
|
21 |
+
pathname, data_stream = data
|
22 |
+
try:
|
23 |
+
if isinstance(data_stream, StreamWrapper) and isinstance(data_stream.file_obj, tarfile.TarFile):
|
24 |
+
tar = data_stream.file_obj
|
25 |
+
else:
|
26 |
+
reading_mode = (self.mode if hasattr(data_stream, "seekable") and data_stream.seekable() else
|
27 |
+
self.mode.replace(":", "|"))
|
28 |
+
# typing.cast is used here to silence mypy's type checker
|
29 |
+
tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=reading_mode)
|
30 |
+
for tarinfo in tar:
|
31 |
+
if not tarinfo.isfile():
|
32 |
+
continue
|
33 |
+
extracted_fobj = tar.extractfile(tarinfo)
|
34 |
+
if extracted_fobj is None:
|
35 |
+
warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}")
|
36 |
+
raise tarfile.ExtractError
|
37 |
+
inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
|
38 |
+
yield inner_pathname, StreamWrapper(extracted_fobj, data_stream,
|
39 |
+
name=inner_pathname) # type: ignore[misc]
|
40 |
+
except Exception as e:
|
41 |
+
warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!")
|
42 |
+
# raise e
|
43 |
+
finally:
|
44 |
+
if isinstance(data_stream, StreamWrapper):
|
45 |
+
data_stream.autoclose()
|
46 |
+
|
47 |
+
|
48 |
+
@functional_datapipe("parse_jsonl_files")
|
49 |
+
class JsonlParserIterDataPipe(IterDataPipe[Tuple[str, Dict]]):
|
50 |
+
|
51 |
+
def __init__(self, source_datapipe: IterDataPipe[Tuple[str, IO]], **kwargs) -> None:
|
52 |
+
self.source_datapipe: IterDataPipe[Tuple[str, IO]] = source_datapipe
|
53 |
+
self.kwargs = kwargs
|
54 |
+
|
55 |
+
def __iter__(self) -> Iterator[Tuple[str, Dict]]:
|
56 |
+
for file_name, stream in self.source_datapipe:
|
57 |
+
for idx, line in enumerate(stream):
|
58 |
+
if line.strip() != '':
|
59 |
+
try:
|
60 |
+
yield f'{file_name}_line{idx}', json.loads(line)
|
61 |
+
except Exception as e:
|
62 |
+
warnings.warn(f"Error occured when parsing string to json due to: {e} abort!")
|
src/data/story_telling.py
ADDED
@@ -0,0 +1,634 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torchdata.datapipes as dp
|
2 |
+
import json
|
3 |
+
from PIL import Image
|
4 |
+
import functools
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import pickle
|
8 |
+
import os
|
9 |
+
import cv2
|
10 |
+
import random
|
11 |
+
from torchvision import transforms
|
12 |
+
from braceexpand import braceexpand
|
13 |
+
import hydra
|
14 |
+
from random import choice
|
15 |
+
import tarfile
|
16 |
+
from torchdata.datapipes.iter import TarArchiveLoader
|
17 |
+
from typing import cast, IO, Iterable, Iterator, Optional, Tuple, Dict
|
18 |
+
from torchdata.datapipes import functional_datapipe
|
19 |
+
from io import BufferedIOBase
|
20 |
+
from torchdata.datapipes.utils import StreamWrapper
|
21 |
+
from torchdata.datapipes.utils.common import validate_pathname_binary_tuple
|
22 |
+
import warnings
|
23 |
+
from torchdata.datapipes.iter import IterDataPipe
|
24 |
+
|
25 |
+
import pyrootutils
|
26 |
+
|
27 |
+
pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
|
28 |
+
|
29 |
+
BOI_TOKEN = '<img>'
|
30 |
+
EOI_TOKEN = '</img>'
|
31 |
+
IMG_TOKEN = '<img_{:05d}>'
|
32 |
+
|
33 |
+
gen_prompt = [
|
34 |
+
"Please show me a picture of ",
|
35 |
+
"Please design an image of ",
|
36 |
+
"Please produce a photo of ",
|
37 |
+
"Please generate an image of ",
|
38 |
+
"Please draw a painting of ",
|
39 |
+
"I'd like to see a drawing of ",
|
40 |
+
"I'd love to see an illustration of ",
|
41 |
+
"I'd like to view an image of ",
|
42 |
+
"I want to see a picture of ",
|
43 |
+
"I would like to see a photo of ",
|
44 |
+
"Show me a photo of ",
|
45 |
+
"Generate a picture of ",
|
46 |
+
"Show me a photograph of ",
|
47 |
+
"Generate an image of ",
|
48 |
+
"Generate an image: ",
|
49 |
+
"Generate a picture: ",
|
50 |
+
"Generate a painting: ",
|
51 |
+
"Generate a photograph: ",
|
52 |
+
"Show me a photograph: ",
|
53 |
+
"Draw a picture: ",
|
54 |
+
"Draw a painting: ",
|
55 |
+
"Draw an image: ",
|
56 |
+
"Can you make an image of ",
|
57 |
+
"Can you draw a painting of ",
|
58 |
+
"Can you produce a picture of ",
|
59 |
+
"Can you generate a photo of ",
|
60 |
+
"Can you depict a picture of ",
|
61 |
+
"Can you show me an illustration of ",
|
62 |
+
]
|
63 |
+
|
64 |
+
gen_prompt_response = [
|
65 |
+
"Here is a picture.",
|
66 |
+
"I have designed an image.",
|
67 |
+
"Here is a photo.",
|
68 |
+
"I have generated an image.",
|
69 |
+
"Here's a painting.",
|
70 |
+
"Here's a drawing.",
|
71 |
+
"Enjoy this illustration.",
|
72 |
+
"Take a look at this image.",
|
73 |
+
"Here is a picture.",
|
74 |
+
"I have created a photo.",
|
75 |
+
"Enjoy this photo.",
|
76 |
+
"I have generated a picture.",
|
77 |
+
"Here is a photograph.",
|
78 |
+
"Here's an image.",
|
79 |
+
"Certainly, here's an image.",
|
80 |
+
"Absolutely, here is a painting.",
|
81 |
+
"Sure, here is a picture.",
|
82 |
+
"Of course, here is a photo.",
|
83 |
+
"Certainly, please enjoy this picture.",
|
84 |
+
"Sure, please enjoy this illustration.",
|
85 |
+
"",
|
86 |
+
]
|
87 |
+
|
88 |
+
jdb_filter_vocab = ['watermark', 'watermark,', 'chaos 100', 'chaos 100,']
|
89 |
+
|
90 |
+
|
91 |
+
def filter_data_with_image_ids(item):
|
92 |
+
if ('images' not in item):
|
93 |
+
# print(item['__key__'])
|
94 |
+
# print('filtered because no images')
|
95 |
+
return False
|
96 |
+
elif 'input_ids' not in item:
|
97 |
+
return False
|
98 |
+
else:
|
99 |
+
return True
|
100 |
+
|
101 |
+
|
102 |
+
def calculate_new_dimensions(height, width, target_size):
|
103 |
+
if height < width:
|
104 |
+
new_height = target_size
|
105 |
+
new_width = int(width * (target_size / height))
|
106 |
+
else:
|
107 |
+
new_width = target_size
|
108 |
+
new_height = int(height * (target_size / width))
|
109 |
+
return new_height, new_width
|
110 |
+
|
111 |
+
|
112 |
+
def unwarp_data(item):
|
113 |
+
unwarpped = {}
|
114 |
+
for key, value in item.items():
|
115 |
+
if isinstance(value, dict):
|
116 |
+
unwarpped.update(value)
|
117 |
+
elif value is not None:
|
118 |
+
unwarpped[key] = value
|
119 |
+
if 'metadata' not in unwarpped:
|
120 |
+
unwarpped['metadata'] = '{}'
|
121 |
+
# if '__key__' in unwarpped:
|
122 |
+
# unwarpped['__key__'] = unwarpped['__key__'].split('/')[-1]
|
123 |
+
return unwarpped
|
124 |
+
|
125 |
+
|
126 |
+
# def filter_data_with_similarity(item, similarity_thr=0.2, min_resolution=180, min_aspect_ratio=0.666):
|
127 |
+
def filter_data_with_similarity(item, similarity_thr=0.2, assure_text=True):
|
128 |
+
if ('images' not in item):
|
129 |
+
# print(item['__key__'])
|
130 |
+
# print('filtered because no images')
|
131 |
+
return False
|
132 |
+
elif (not item.get('filter_flag', True)):
|
133 |
+
# print(item['__key__'])
|
134 |
+
# print('filtered because filter flag.')
|
135 |
+
return False
|
136 |
+
elif assure_text and ('text' not in item):
|
137 |
+
# print(item['__key__'])
|
138 |
+
# print('filtered because assure_text')
|
139 |
+
return False
|
140 |
+
else:
|
141 |
+
metadata = json.loads(item['metadata'])
|
142 |
+
|
143 |
+
if 'all_similarities' in metadata:
|
144 |
+
similarity = max(metadata['all_similarities'])
|
145 |
+
elif 'similarity' in metadata:
|
146 |
+
similarity = metadata['similarity']
|
147 |
+
elif 'score' in metadata:
|
148 |
+
similarity = metadata['score']
|
149 |
+
elif 'SCORE' in metadata:
|
150 |
+
similarity = metadata['SCORE']
|
151 |
+
else:
|
152 |
+
similarity = None
|
153 |
+
|
154 |
+
if similarity is not None:
|
155 |
+
if similarity < similarity_thr:
|
156 |
+
# print(item['__key__'])
|
157 |
+
# print('filtered because similarity')
|
158 |
+
return False
|
159 |
+
|
160 |
+
return True
|
161 |
+
|
162 |
+
|
163 |
+
def single_turn_edit_collate(batch):
|
164 |
+
results = {}
|
165 |
+
keys = batch[0].keys()
|
166 |
+
|
167 |
+
for key in keys:
|
168 |
+
cur = [batch[i][key] for i in range(len(batch)) if batch[i][key] is not None]
|
169 |
+
if len(cur) == 0:
|
170 |
+
results[key] = None
|
171 |
+
elif isinstance(cur[0], torch.Tensor):
|
172 |
+
if key in ['embeds_gen_mask', 'embeds_cmp_mask', 'images']:
|
173 |
+
results[key] = torch.cat(cur, dim=0)
|
174 |
+
else:
|
175 |
+
results[key] = torch.stack(cur, dim=0)
|
176 |
+
else:
|
177 |
+
results[key] = cur
|
178 |
+
|
179 |
+
return results
|
180 |
+
|
181 |
+
|
182 |
+
def decode_t2i_data(item,
|
183 |
+
image_dir,
|
184 |
+
tokenizer,
|
185 |
+
image_transform=None,
|
186 |
+
sd_image_transform=None,
|
187 |
+
max_length=128,
|
188 |
+
min_resolution=400,
|
189 |
+
instruction_prompt='[INST] {instruction} [/INST]\n',
|
190 |
+
turn_sep='\n',
|
191 |
+
system_message='',
|
192 |
+
min_aspect_ratio=0.666,
|
193 |
+
num_img_in_tokens=64,
|
194 |
+
num_img_out_tokens=64):
|
195 |
+
key, value = item
|
196 |
+
|
197 |
+
if 'image' not in value or 'caption' not in value:
|
198 |
+
return {}
|
199 |
+
|
200 |
+
image_path = os.path.join(image_dir, value["image"])
|
201 |
+
|
202 |
+
try:
|
203 |
+
image = Image.open(image_path).convert('RGB')
|
204 |
+
|
205 |
+
width, height = image.size
|
206 |
+
|
207 |
+
aspect_ratio = height / width
|
208 |
+
if height < min_resolution or width < min_resolution:
|
209 |
+
print(f'filtered because resolution: ({width},{height})')
|
210 |
+
return {}
|
211 |
+
if aspect_ratio < min_aspect_ratio or aspect_ratio > 1 / min_aspect_ratio:
|
212 |
+
print(f'filtered because aspect ratio: ({width},{height})')
|
213 |
+
return {}
|
214 |
+
### SD related
|
215 |
+
|
216 |
+
image_data = {}
|
217 |
+
|
218 |
+
if sd_image_transform is not None:
|
219 |
+
# image_data['original_sizes'] = torch.tensor([height, width])
|
220 |
+
sd_image_tensor = sd_image_transform(image)
|
221 |
+
target_size = sd_image_tensor.shape[-2]
|
222 |
+
target_width, target_height = calculate_new_dimensions(height=height, width=width, target_size=target_size)
|
223 |
+
y1 = max(0, int(round((target_height - target_size) / 2.0)))
|
224 |
+
x1 = max(0, int(round((target_width - target_size) / 2.0)))
|
225 |
+
# image_data['crop_top_lefts'] = torch.tensor([y1, x1])
|
226 |
+
image_data['time_ids'] = torch.tensor([height, width, y1, x1, target_size, target_size])
|
227 |
+
|
228 |
+
image_data['sd_images'] = sd_image_tensor
|
229 |
+
|
230 |
+
if image_transform is not None:
|
231 |
+
image = image_transform(image)
|
232 |
+
|
233 |
+
except Exception as e:
|
234 |
+
print('Error while decode image: ', e)
|
235 |
+
return {}
|
236 |
+
|
237 |
+
input_ids = []
|
238 |
+
labels = []
|
239 |
+
input_text = ''
|
240 |
+
|
241 |
+
if system_message != '':
|
242 |
+
if not system_message.endswith('\n'):
|
243 |
+
system_message += '\n'
|
244 |
+
input_text += system_message
|
245 |
+
item_ids = tokenizer.encode(system_message, add_special_tokens=False)
|
246 |
+
item_labels = [-100] * len(item_ids)
|
247 |
+
input_ids.extend(item_ids)
|
248 |
+
labels.extend(item_labels)
|
249 |
+
|
250 |
+
caption = value["caption"]
|
251 |
+
|
252 |
+
image_cmp_tokens = BOI_TOKEN + ''.join(
|
253 |
+
[IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
|
254 |
+
|
255 |
+
image_gen_tokens = BOI_TOKEN + ''.join(
|
256 |
+
[IMG_TOKEN.format(int(item)) for item in range(num_img_out_tokens)]) + EOI_TOKEN
|
257 |
+
|
258 |
+
instruction = instruction_prompt.format_map({'instruction': caption})
|
259 |
+
|
260 |
+
response = image_gen_tokens
|
261 |
+
images = torch.stack([image], dim=0)
|
262 |
+
# print(instruction)
|
263 |
+
|
264 |
+
item_ids = tokenizer.encode(instruction, add_special_tokens=False)
|
265 |
+
item_labels = [-100] * len(item_ids)
|
266 |
+
input_text += instruction
|
267 |
+
input_ids.extend(item_ids)
|
268 |
+
labels.extend(item_labels)
|
269 |
+
|
270 |
+
item_ids = tokenizer.encode(response, add_special_tokens=False)
|
271 |
+
item_labels = item_ids
|
272 |
+
input_text += response
|
273 |
+
input_ids.extend(item_ids)
|
274 |
+
labels.extend(item_labels)
|
275 |
+
|
276 |
+
input_ids = [tokenizer.bos_token_id] + input_ids + [tokenizer.eos_token_id]
|
277 |
+
attention_mask = [1] * len(input_ids)
|
278 |
+
labels = [-100] + labels + [tokenizer.eos_token_id]
|
279 |
+
|
280 |
+
boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
|
281 |
+
eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
|
282 |
+
ids_cmp_mask = [False] * len(input_ids)
|
283 |
+
ids_gen_mask = [False] * len(input_ids)
|
284 |
+
|
285 |
+
embeds_cmp_mask = [False]
|
286 |
+
embeds_gen_mask = [True]
|
287 |
+
|
288 |
+
# print(len(input_ids))
|
289 |
+
if len(input_ids) >= max_length:
|
290 |
+
# input_ids = input_ids[:max_length]
|
291 |
+
# attention_mask = attention_mask[:max_length]
|
292 |
+
# labels = labels[:max_length]
|
293 |
+
# ids_cmp_mask = ids_cmp_mask[:max_length]
|
294 |
+
# ids_gen_mask = ids_gen_mask[:max_length]
|
295 |
+
# print('An edit sample has been removed because of max length. input_text: ', input_text)
|
296 |
+
return {}
|
297 |
+
else:
|
298 |
+
padding_length = max_length - len(input_ids)
|
299 |
+
input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
|
300 |
+
attention_mask = attention_mask + [0] * padding_length
|
301 |
+
labels = labels + [-100] * padding_length
|
302 |
+
ids_cmp_mask = ids_cmp_mask + [False] * padding_length
|
303 |
+
ids_gen_mask = ids_gen_mask + [False] * padding_length
|
304 |
+
|
305 |
+
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
306 |
+
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
|
307 |
+
labels = torch.tensor(labels, dtype=torch.long)
|
308 |
+
ids_cmp_mask = torch.tensor(ids_cmp_mask, dtype=torch.bool)
|
309 |
+
ids_gen_mask = torch.tensor(ids_gen_mask, dtype=torch.bool)
|
310 |
+
embeds_cmp_mask = torch.tensor(embeds_cmp_mask) if embeds_cmp_mask is not None else None
|
311 |
+
embeds_gen_mask = torch.tensor(embeds_gen_mask) if embeds_gen_mask is not None else None
|
312 |
+
|
313 |
+
boi_idx = torch.where(input_ids == boi_token_id)[0].tolist()
|
314 |
+
eoi_idx = torch.where(input_ids == eoi_token_id)[0].tolist()
|
315 |
+
|
316 |
+
ids_gen_mask[boi_idx[0] + 1:eoi_idx[0]] = True
|
317 |
+
labels[boi_idx[0] + 1:eoi_idx[0] + 1] = -100
|
318 |
+
|
319 |
+
ret = {
|
320 |
+
'input_ids': input_ids,
|
321 |
+
'attention_mask': attention_mask,
|
322 |
+
'labels': labels,
|
323 |
+
'ids_gen_mask': ids_gen_mask,
|
324 |
+
'ids_cmp_mask': ids_cmp_mask,
|
325 |
+
'embeds_gen_mask': embeds_gen_mask,
|
326 |
+
'embeds_cmp_mask': embeds_cmp_mask,
|
327 |
+
'images': images,
|
328 |
+
'text': input_text,
|
329 |
+
}
|
330 |
+
|
331 |
+
ret.update(image_data)
|
332 |
+
|
333 |
+
return ret
|
334 |
+
|
335 |
+
|
336 |
+
def build_t2i_datapipe(data_dir,
|
337 |
+
image_dir,
|
338 |
+
tokenizer=None,
|
339 |
+
max_length=77,
|
340 |
+
batch_size=None,
|
341 |
+
min_resolution=180,
|
342 |
+
image_transform=None,
|
343 |
+
sd_image_transform=None,
|
344 |
+
instruction_prompt='[INST] {instruction} [INST]\n',
|
345 |
+
turn_sep='\n',
|
346 |
+
system_message='',
|
347 |
+
min_aspect_ratio=0.666,
|
348 |
+
num_img_in_tokens=64,
|
349 |
+
num_img_out_tokens=64,
|
350 |
+
cycle_count=None):
|
351 |
+
decode_partial = functools.partial(decode_t2i_data,
|
352 |
+
image_dir=image_dir,
|
353 |
+
tokenizer=tokenizer,
|
354 |
+
image_transform=image_transform,
|
355 |
+
sd_image_transform=sd_image_transform,
|
356 |
+
max_length=max_length,
|
357 |
+
instruction_prompt=instruction_prompt,
|
358 |
+
turn_sep=turn_sep,
|
359 |
+
system_message=system_message,
|
360 |
+
min_resolution=min_resolution,
|
361 |
+
min_aspect_ratio=min_aspect_ratio,
|
362 |
+
num_img_in_tokens=num_img_in_tokens,
|
363 |
+
num_img_out_tokens=num_img_out_tokens)
|
364 |
+
|
365 |
+
filter_partial = functools.partial(filter_data_with_image_ids)
|
366 |
+
|
367 |
+
if isinstance(data_dir, str):
|
368 |
+
data_dir = list(braceexpand(data_dir))
|
369 |
+
|
370 |
+
datapipe = dp.iter.FileLister(root=data_dir, masks='*.jsonl', recursive=True)
|
371 |
+
datapipe = datapipe.shuffle()
|
372 |
+
datapipe = datapipe.cycle(count=cycle_count)
|
373 |
+
datapipe = datapipe.shuffle()
|
374 |
+
# datapipe = dp.iter.FileLister(root=data_dir, masks='0000000.tar', recursive=True)
|
375 |
+
datapipe = datapipe.sharding_filter()
|
376 |
+
# datapipe = datapipe.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
|
377 |
+
|
378 |
+
datapipe = datapipe.open_files(mode='r')
|
379 |
+
datapipe = datapipe.parse_jsonl_files()
|
380 |
+
datapipe = datapipe.map(decode_partial)
|
381 |
+
datapipe = datapipe.filter(filter_partial)
|
382 |
+
|
383 |
+
# datapipe = datapipe.shuffle(buffer_size=1024)
|
384 |
+
if batch_size is not None:
|
385 |
+
datapipe = datapipe.batch(batch_size)
|
386 |
+
datapipe = datapipe.collate(single_turn_edit_collate)
|
387 |
+
return datapipe
|
388 |
+
|
389 |
+
|
390 |
+
def decode_long_story_data(item,
|
391 |
+
image_dir,
|
392 |
+
tokenizer,
|
393 |
+
story_len,
|
394 |
+
image_transform=None,
|
395 |
+
sd_image_transform=None,
|
396 |
+
max_length=128,
|
397 |
+
min_resolution=400,
|
398 |
+
instruction_prompt='{instruction}',
|
399 |
+
turn_sep='\n',
|
400 |
+
system_message='',
|
401 |
+
min_aspect_ratio=0.666,
|
402 |
+
num_img_in_tokens=64,
|
403 |
+
num_img_out_tokens=64, ):
|
404 |
+
key, value = item
|
405 |
+
if 'images' not in value or 'captions' not in value:
|
406 |
+
return {}
|
407 |
+
|
408 |
+
image_paths = [os.path.join(image_dir, image_path) for image_path in value["images"]]
|
409 |
+
# assert len(image_paths) == story_len
|
410 |
+
story_len = len(image_paths)
|
411 |
+
num_image_given = random.randint(0, story_len - 2)
|
412 |
+
|
413 |
+
try:
|
414 |
+
images = []
|
415 |
+
for image_path in image_paths:
|
416 |
+
image = Image.open(image_path).convert('RGB')
|
417 |
+
images.append(image)
|
418 |
+
width, height = image.size
|
419 |
+
|
420 |
+
aspect_ratio = height / width
|
421 |
+
if height < min_resolution or width < min_resolution:
|
422 |
+
print(f'filtered because resolution: ({width},{height})')
|
423 |
+
return {}
|
424 |
+
if aspect_ratio < min_aspect_ratio or aspect_ratio > 1 / min_aspect_ratio:
|
425 |
+
print(f'filtered because aspect ratio: ({width},{height})')
|
426 |
+
return {}
|
427 |
+
|
428 |
+
image_data = {}
|
429 |
+
sd_image = images[num_image_given + 1]
|
430 |
+
if sd_image_transform is not None:
|
431 |
+
# image_data['original_sizes'] = torch.tensor([height, width])
|
432 |
+
sd_image_tensor = sd_image_transform(sd_image)
|
433 |
+
target_size = sd_image_tensor.shape[-2]
|
434 |
+
target_width, target_height = calculate_new_dimensions(height=height, width=width, target_size=target_size)
|
435 |
+
y1 = max(0, int(round((target_height - target_size) / 2.0)))
|
436 |
+
x1 = max(0, int(round((target_width - target_size) / 2.0)))
|
437 |
+
# image_data['crop_top_lefts'] = torch.tensor([y1, x1])
|
438 |
+
image_data['time_ids'] = torch.tensor([height, width, y1, x1, target_size, target_size])
|
439 |
+
|
440 |
+
image_data['sd_images'] = sd_image_tensor
|
441 |
+
|
442 |
+
if image_transform is not None:
|
443 |
+
for i in range(len(images)):
|
444 |
+
images[i] = image_transform(images[i])
|
445 |
+
images = torch.stack(images, dim=0)
|
446 |
+
|
447 |
+
except Exception as e:
|
448 |
+
print('Error while decode image: ', e)
|
449 |
+
return {}
|
450 |
+
|
451 |
+
input_ids = []
|
452 |
+
labels = []
|
453 |
+
input_text = ''
|
454 |
+
|
455 |
+
if system_message != '':
|
456 |
+
if not system_message.endswith('\n'):
|
457 |
+
system_message += '\n'
|
458 |
+
input_text += system_message
|
459 |
+
item_ids = tokenizer.encode(system_message, add_special_tokens=False)
|
460 |
+
item_labels = [-100] * len(item_ids)
|
461 |
+
input_ids.extend(item_ids)
|
462 |
+
labels.extend(item_labels)
|
463 |
+
|
464 |
+
captions_all = []
|
465 |
+
for i in range(story_len):
|
466 |
+
caption = value["captions"][i]
|
467 |
+
captions_all.append(caption)
|
468 |
+
|
469 |
+
image_cmp_tokens = BOI_TOKEN + ''.join(
|
470 |
+
[IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
|
471 |
+
|
472 |
+
image_gen_tokens = BOI_TOKEN + ''.join(
|
473 |
+
[IMG_TOKEN.format(int(item)) for item in range(num_img_out_tokens)]) + EOI_TOKEN
|
474 |
+
|
475 |
+
instruction = instruction_prompt.format_map({'instruction': captions_all[0] + image_cmp_tokens})
|
476 |
+
for i in range(num_image_given):
|
477 |
+
instruction = instruction + "[INST]" + captions_all[i + 1] + image_cmp_tokens
|
478 |
+
|
479 |
+
response = "[INST]" + captions_all[num_image_given + 1] + image_gen_tokens
|
480 |
+
|
481 |
+
images = images[:num_image_given + 2]
|
482 |
+
# print(instruction)
|
483 |
+
|
484 |
+
item_ids = tokenizer.encode(instruction, add_special_tokens=False)
|
485 |
+
item_labels = [-100] * len(item_ids)
|
486 |
+
input_text += instruction
|
487 |
+
input_ids.extend(item_ids)
|
488 |
+
labels.extend(item_labels)
|
489 |
+
|
490 |
+
item_ids = tokenizer.encode(response, add_special_tokens=False)
|
491 |
+
item_labels = item_ids
|
492 |
+
input_text += response
|
493 |
+
input_ids.extend(item_ids)
|
494 |
+
labels.extend(item_labels)
|
495 |
+
|
496 |
+
input_ids = [tokenizer.bos_token_id] + input_ids + [tokenizer.eos_token_id]
|
497 |
+
attention_mask = [1] * len(input_ids)
|
498 |
+
labels = [-100] + labels + [tokenizer.eos_token_id]
|
499 |
+
|
500 |
+
boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
|
501 |
+
eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
|
502 |
+
ids_cmp_mask = [False] * len(input_ids)
|
503 |
+
ids_gen_mask = [False] * len(input_ids)
|
504 |
+
|
505 |
+
embeds_cmp_mask = [True] + [True] * num_image_given + [False]
|
506 |
+
embeds_gen_mask = [False] + [False] * num_image_given + [True]
|
507 |
+
|
508 |
+
# print(len(input_ids))
|
509 |
+
if len(input_ids) >= max_length:
|
510 |
+
# input_ids = input_ids[:max_length]
|
511 |
+
# attention_mask = attention_mask[:max_length]
|
512 |
+
# labels = labels[:max_length]
|
513 |
+
# ids_cmp_mask = ids_cmp_mask[:max_length]
|
514 |
+
# ids_gen_mask = ids_gen_mask[:max_length]
|
515 |
+
# print('An edit sample has been removed because of max length. input_text: ', input_text)
|
516 |
+
return {}
|
517 |
+
else:
|
518 |
+
padding_length = max_length - len(input_ids)
|
519 |
+
input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
|
520 |
+
attention_mask = attention_mask + [0] * padding_length
|
521 |
+
labels = labels + [-100] * padding_length
|
522 |
+
ids_cmp_mask = ids_cmp_mask + [False] * padding_length
|
523 |
+
ids_gen_mask = ids_gen_mask + [False] * padding_length
|
524 |
+
|
525 |
+
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
526 |
+
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
|
527 |
+
labels = torch.tensor(labels, dtype=torch.long)
|
528 |
+
ids_cmp_mask = torch.tensor(ids_cmp_mask, dtype=torch.bool)
|
529 |
+
ids_gen_mask = torch.tensor(ids_gen_mask, dtype=torch.bool)
|
530 |
+
embeds_cmp_mask = torch.tensor(embeds_cmp_mask) if embeds_cmp_mask is not None else None
|
531 |
+
embeds_gen_mask = torch.tensor(embeds_gen_mask) if embeds_gen_mask is not None else None
|
532 |
+
|
533 |
+
boi_idx = torch.where(input_ids == boi_token_id)[0].tolist()
|
534 |
+
eoi_idx = torch.where(input_ids == eoi_token_id)[0].tolist()
|
535 |
+
|
536 |
+
ids_cmp_mask[boi_idx[0] + 1:eoi_idx[0]] = True
|
537 |
+
for i in range(num_image_given):
|
538 |
+
ids_cmp_mask[boi_idx[i + 1] + 1:eoi_idx[i + 1]] = True
|
539 |
+
|
540 |
+
ids_gen_mask[boi_idx[-1] + 1:eoi_idx[-1]] = True
|
541 |
+
labels[boi_idx[-1] + 1:eoi_idx[-1] + 1] = -100
|
542 |
+
|
543 |
+
ret = {
|
544 |
+
'input_ids': input_ids,
|
545 |
+
'attention_mask': attention_mask,
|
546 |
+
'labels': labels,
|
547 |
+
'ids_gen_mask': ids_gen_mask,
|
548 |
+
'ids_cmp_mask': ids_cmp_mask,
|
549 |
+
'embeds_gen_mask': embeds_gen_mask,
|
550 |
+
'embeds_cmp_mask': embeds_cmp_mask,
|
551 |
+
'images': images,
|
552 |
+
'text': input_text,
|
553 |
+
}
|
554 |
+
|
555 |
+
ret.update(image_data)
|
556 |
+
|
557 |
+
return ret
|
558 |
+
|
559 |
+
|
560 |
+
def build_long_story_datapipe(data_dir,
|
561 |
+
image_dir,
|
562 |
+
tokenizer=None,
|
563 |
+
story_len=30,
|
564 |
+
max_length=77,
|
565 |
+
batch_size=None,
|
566 |
+
min_resolution=180,
|
567 |
+
image_transform=None,
|
568 |
+
sd_image_transform=None,
|
569 |
+
instruction_prompt='{instruction}',
|
570 |
+
turn_sep='\n',
|
571 |
+
system_message='',
|
572 |
+
min_aspect_ratio=0.666,
|
573 |
+
num_img_in_tokens=64,
|
574 |
+
num_img_out_tokens=64,
|
575 |
+
cycle_count=None):
|
576 |
+
decode_partial = functools.partial(decode_long_story_data,
|
577 |
+
image_dir=image_dir,
|
578 |
+
tokenizer=tokenizer,
|
579 |
+
story_len=story_len,
|
580 |
+
image_transform=image_transform,
|
581 |
+
sd_image_transform=sd_image_transform,
|
582 |
+
max_length=max_length,
|
583 |
+
instruction_prompt=instruction_prompt,
|
584 |
+
turn_sep=turn_sep,
|
585 |
+
system_message=system_message,
|
586 |
+
min_resolution=min_resolution,
|
587 |
+
min_aspect_ratio=min_aspect_ratio,
|
588 |
+
num_img_in_tokens=num_img_in_tokens,
|
589 |
+
num_img_out_tokens=num_img_out_tokens)
|
590 |
+
|
591 |
+
filter_partial = functools.partial(filter_data_with_image_ids)
|
592 |
+
|
593 |
+
if isinstance(data_dir, str):
|
594 |
+
data_dir = list(braceexpand(data_dir))
|
595 |
+
|
596 |
+
datapipe = dp.iter.FileLister(root=data_dir, masks='*.jsonl', recursive=True)
|
597 |
+
datapipe = datapipe.shuffle()
|
598 |
+
datapipe = datapipe.cycle(count=cycle_count)
|
599 |
+
datapipe = datapipe.shuffle()
|
600 |
+
# datapipe = dp.iter.FileLister(root=data_dir, masks='0000000.tar', recursive=True)
|
601 |
+
datapipe = datapipe.sharding_filter()
|
602 |
+
# datapipe = datapipe.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
|
603 |
+
|
604 |
+
datapipe = datapipe.open_files(mode='r')
|
605 |
+
datapipe = datapipe.parse_jsonl_files()
|
606 |
+
datapipe = datapipe.map(decode_partial)
|
607 |
+
datapipe = datapipe.filter(filter_partial)
|
608 |
+
|
609 |
+
# datapipe = datapipe.shuffle(buffer_size=1024)
|
610 |
+
if batch_size is not None:
|
611 |
+
datapipe = datapipe.batch(batch_size)
|
612 |
+
datapipe = datapipe.collate(single_turn_edit_collate)
|
613 |
+
return datapipe
|
614 |
+
|
615 |
+
|
616 |
+
def build_multi_datapipes(datapipes, tokenizer=None, image_transform=None, sd_image_transform=None,
|
617 |
+
sample_weights=None):
|
618 |
+
# assert concat_type in ['concat', 'mux_longest', 'sample']
|
619 |
+
if sample_weights is None:
|
620 |
+
sample_weights = [1] * len(datapipes)
|
621 |
+
else:
|
622 |
+
assert len(sample_weights) == len(datapipes)
|
623 |
+
|
624 |
+
datapipes = [
|
625 |
+
hydra.utils.instantiate(datapipe, tokenizer=tokenizer, image_transform=image_transform,
|
626 |
+
sd_image_transform=sd_image_transform) for datapipe in datapipes
|
627 |
+
]
|
628 |
+
|
629 |
+
datasets_to_weights_dict = {}
|
630 |
+
for dataset, sample_weight in zip(datapipes, sample_weights):
|
631 |
+
datasets_to_weights_dict[dataset] = sample_weight
|
632 |
+
datapipe = dp.iter.SampleMultiplexer(datasets_to_weights_dict)
|
633 |
+
|
634 |
+
return datapipe
|
src/eval/gpt_comparative_eval.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from openai import OpenAI
|
3 |
+
import ast
|
4 |
+
import time
|
5 |
+
import os
|
6 |
+
import base64
|
7 |
+
# from PIL import Image
|
8 |
+
import io
|
9 |
+
|
10 |
+
client = OpenAI(
|
11 |
+
base_url="YOUR_URL",
|
12 |
+
api_key="YOUR_KEY",
|
13 |
+
)
|
14 |
+
|
15 |
+
instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by two AI assistants. Your job is to evaluate which assistant's generation is better. Your evaluation should consider the coherence of the generated story images and text. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie."
|
16 |
+
|
17 |
+
# style
|
18 |
+
# instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by two AI assistants. Your job is to evaluate which assistant's generation is better. Your evaluation should consider the style consistency of the story images. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie."
|
19 |
+
|
20 |
+
# text engaging level
|
21 |
+
# instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by two AI assistants. Your job is to evaluate which assistant's generation is better. Your evaluation should consider the engaging level of the story. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie."
|
22 |
+
|
23 |
+
def api_call(messages):
|
24 |
+
try_times = 0
|
25 |
+
while try_times < 3:
|
26 |
+
try:
|
27 |
+
chat_completion = client.chat.completions.create(
|
28 |
+
messages=messages,
|
29 |
+
model="gpt-4-turbo-2024-04-09", #"gpt-4-0125-preview", #"claude-3-opus-20240229", #"gpt-4-1106-preview",
|
30 |
+
max_tokens=4096,
|
31 |
+
temperature=0.3,
|
32 |
+
# stop=['<wait to execute>']
|
33 |
+
)
|
34 |
+
success = True
|
35 |
+
break
|
36 |
+
except Exception as e:
|
37 |
+
print(f"Error during API call: {e}")
|
38 |
+
time.sleep(15)
|
39 |
+
try_times += 1
|
40 |
+
success = False
|
41 |
+
if success:
|
42 |
+
cleaned_string = chat_completion.choices[0].message.content.strip()
|
43 |
+
return cleaned_string
|
44 |
+
else:
|
45 |
+
return None
|
46 |
+
|
47 |
+
|
48 |
+
def encode_image(image_path):
|
49 |
+
with open(image_path, "rb") as image_file:
|
50 |
+
return base64.b64encode(image_file.read()).decode("utf-8")
|
51 |
+
|
52 |
+
|
53 |
+
def read_json_and_extract_content(filepath):
|
54 |
+
"""
|
55 |
+
Reads a JSON file and extracts sentences and images.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
filepath (str): The path to the JSON file.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
dict: A dictionary with two keys 'sentences' and 'images', containing the respective content.
|
62 |
+
"""
|
63 |
+
with open(filepath, 'r') as file:
|
64 |
+
data = json.load(file)
|
65 |
+
|
66 |
+
all_content = []
|
67 |
+
for line in data:
|
68 |
+
extracted_content = {
|
69 |
+
"sentences": [],
|
70 |
+
"images": []
|
71 |
+
}
|
72 |
+
# Matching sentences to their corresponding images using their indices
|
73 |
+
for ix in line['sentence_ixs']:
|
74 |
+
if ix == 0:
|
75 |
+
continue
|
76 |
+
extracted_content['sentences'].append(line['sentences'][ix].replace('<|beginofimage|>', ''))
|
77 |
+
extracted_content['images'].append(line['images'][ix])
|
78 |
+
all_content.append(extracted_content)
|
79 |
+
|
80 |
+
return all_content
|
81 |
+
|
82 |
+
|
83 |
+
def read_seed_content_from_folders(base_path):
|
84 |
+
"""
|
85 |
+
Reads sentences from text.txt and image paths from subfolders named val_x.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
base_path (str): Path to the main folder containing subfolders val_0 to val_179.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
list of dict: Each dictionary contains 'sentences' and 'images' from each subfolder.
|
92 |
+
"""
|
93 |
+
contents = []
|
94 |
+
|
95 |
+
# Iterate over each possible subfolder val_0 to val_179
|
96 |
+
for i in range(180): # 0 to 179 inclusive
|
97 |
+
folder_name = f"val_{i}"
|
98 |
+
folder_path = os.path.join(base_path, folder_name)
|
99 |
+
|
100 |
+
if os.path.exists(folder_path):
|
101 |
+
content_dict = {
|
102 |
+
"sentences": [],
|
103 |
+
"images": []
|
104 |
+
}
|
105 |
+
|
106 |
+
# Read sentences from text.txt
|
107 |
+
text_file_path = os.path.join(folder_path, 'text.txt')
|
108 |
+
if os.path.isfile(text_file_path):
|
109 |
+
with open(text_file_path, 'r') as file:
|
110 |
+
content_dict['sentences'] = file.read().splitlines()[:6]
|
111 |
+
content_dict['sentences'] = [s.replace('[INST]', '') for s in content_dict['sentences'] ]
|
112 |
+
|
113 |
+
# Collect paths for the images ori_01 to ori_06
|
114 |
+
for j in range(1, 7): # 1 to 6 inclusive
|
115 |
+
image_name = f"ori_0{j}.jpg" # Assuming the images are in .jpg format
|
116 |
+
image_path = os.path.join(folder_path, image_name)
|
117 |
+
if os.path.isfile(image_path):
|
118 |
+
content_dict['images'].append(image_path)
|
119 |
+
|
120 |
+
# Add the content dictionary to the list if it contains any images or sentences
|
121 |
+
if content_dict['sentences'] or content_dict['images']:
|
122 |
+
contents.append(content_dict)
|
123 |
+
|
124 |
+
return contents
|
125 |
+
|
126 |
+
|
127 |
+
def evaluate_models(assistant_a, assistant_b, instruction):
|
128 |
+
# Encode all images to base64
|
129 |
+
images_a_base64 = [encode_image(img_path) for img_path in assistant_a['images'][:5]]
|
130 |
+
images_b_base64 = [encode_image(img_path) for img_path in assistant_b['images'][:5]]
|
131 |
+
|
132 |
+
# Extract the stories from both assistants
|
133 |
+
story_a = assistant_a['sentences']
|
134 |
+
story_b = assistant_b['sentences']
|
135 |
+
|
136 |
+
messages = []
|
137 |
+
# A
|
138 |
+
messages.append(
|
139 |
+
{
|
140 |
+
"role": "user",
|
141 |
+
"content": [
|
142 |
+
{
|
143 |
+
"type": "text",
|
144 |
+
"text": "Story text from Assistant A: {}\n".format(story_a[:5])
|
145 |
+
}
|
146 |
+
]
|
147 |
+
}
|
148 |
+
)
|
149 |
+
messages.append(
|
150 |
+
{
|
151 |
+
"role": "user",
|
152 |
+
"content": [
|
153 |
+
{
|
154 |
+
"type": "text",
|
155 |
+
"text": "Images from Assistant A are encoded in base64.\n"
|
156 |
+
}
|
157 |
+
]
|
158 |
+
}
|
159 |
+
)
|
160 |
+
for img_a in images_a_base64:
|
161 |
+
messages.append({
|
162 |
+
"role": "user",
|
163 |
+
"content": [
|
164 |
+
{
|
165 |
+
"type": "image_url",
|
166 |
+
"image_url": {"url": f"data:image/jpeg;base64,{img_a}"}
|
167 |
+
}
|
168 |
+
]
|
169 |
+
})
|
170 |
+
|
171 |
+
# B
|
172 |
+
messages.append(
|
173 |
+
{
|
174 |
+
"role": "user",
|
175 |
+
"content": [
|
176 |
+
{
|
177 |
+
"type": "text",
|
178 |
+
"text": "Story text from Assistant B: {}\n".format(story_b[:5])
|
179 |
+
}
|
180 |
+
]
|
181 |
+
}
|
182 |
+
)
|
183 |
+
messages.append(
|
184 |
+
{
|
185 |
+
"role": "user",
|
186 |
+
"content": [
|
187 |
+
{
|
188 |
+
"type": "text",
|
189 |
+
"text": "Images from Assistant B are encoded in base64.\n"
|
190 |
+
}
|
191 |
+
]
|
192 |
+
}
|
193 |
+
)
|
194 |
+
for img_b in images_b_base64:
|
195 |
+
messages.append({
|
196 |
+
"role": "user",
|
197 |
+
"content": [
|
198 |
+
{
|
199 |
+
"type": "image_url",
|
200 |
+
"image_url": {"url": f"data:image/jpeg;base64,{img_b}"}
|
201 |
+
}
|
202 |
+
]
|
203 |
+
})
|
204 |
+
|
205 |
+
# INST
|
206 |
+
messages.append(
|
207 |
+
{
|
208 |
+
"role": "user",
|
209 |
+
"content": [
|
210 |
+
{
|
211 |
+
"type": "text",
|
212 |
+
"text": instruction
|
213 |
+
}
|
214 |
+
]
|
215 |
+
}
|
216 |
+
)
|
217 |
+
# Combine stories and encoded images into the evaluation instruction
|
218 |
+
result = api_call(messages)
|
219 |
+
print(result)
|
220 |
+
return result
|
221 |
+
|
222 |
+
def main():
|
223 |
+
# read mm json
|
224 |
+
mm_contents = read_json_and_extract_content('/group/40034/shuaisyang/seed_project/StorySalon/llm_eval/mm_eval.json')
|
225 |
+
seed_contents = read_seed_content_from_folders('/group/40034/shuaisyang/seed_project/StorySalon/llm_eval/gen_george_len7')
|
226 |
+
assert len(mm_contents) == len(seed_contents)
|
227 |
+
mm_win = 0
|
228 |
+
seed_win = 0
|
229 |
+
tie = 0
|
230 |
+
error = []
|
231 |
+
for i in range(len(mm_contents)):
|
232 |
+
# for i in range(2):
|
233 |
+
mm = mm_contents[i]
|
234 |
+
seed = seed_contents[i]
|
235 |
+
judgment = evaluate_models(mm, seed, instruction)
|
236 |
+
|
237 |
+
if "[[A]]" in judgment:
|
238 |
+
mm_win += 1
|
239 |
+
elif "[[B]]" in judgment:
|
240 |
+
seed_win += 1
|
241 |
+
elif "[[C]]" in judgment:
|
242 |
+
tie += 1
|
243 |
+
else:
|
244 |
+
error.append([i, judgment])
|
245 |
+
|
246 |
+
with open('coherence.txt', 'w') as f:
|
247 |
+
f.write("mm:{}\nseed:{}\ntie:{}\nerror:{}".format(mm_win, seed_win, tie, error))
|
248 |
+
|
249 |
+
main()
|
src/eval/gpt_score_eval.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from openai import OpenAI
|
3 |
+
import ast
|
4 |
+
import time
|
5 |
+
import os
|
6 |
+
import base64
|
7 |
+
# from PIL import Image
|
8 |
+
import io
|
9 |
+
import re
|
10 |
+
|
11 |
+
client = OpenAI(
|
12 |
+
base_url="YOUR_URL",
|
13 |
+
api_key="YOUR_KEY",
|
14 |
+
)
|
15 |
+
|
16 |
+
style_instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by an AI assistant. Your job is to give a score out of 10. Your evaluation should consider the style consistency of the story images. Do not allow the length of the responses to influence your evaluation. Be as objective as possible. After providing your explanation, output your final score by strictly following this format: \"[[score]]\", such as \"[[7]]\"."
|
17 |
+
|
18 |
+
engage_instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by an AI assistant. Your job is to give a score out of 10. Your evaluation should consider the engaging level of the story. Do not allow the length of the responses to influence your evaluation. Be as objective as possible. After providing your explanation, output your final score by strictly following this format: \"[[score]]\", such as \"[[7]]\"."
|
19 |
+
|
20 |
+
coherence_instruction = "Please act as an impartial judge and evaluate the quality of the generation story contents provided by an AI assistant. Your job is to give a score out of 10. Your evaluation should consider the coherence of the generated story images and text. Do not allow the length of the responses to influence your evaluation. Be as objective as possible. After providing your explanation, output your final score by strictly following this format: \"[[score]]\", such as \"[[7]]\"."
|
21 |
+
|
22 |
+
def api_call(messages):
|
23 |
+
try_times = 0
|
24 |
+
while try_times < 3:
|
25 |
+
try:
|
26 |
+
chat_completion = client.chat.completions.create(
|
27 |
+
messages=messages,
|
28 |
+
model="gpt-4-turbo-2024-04-09", #"gpt-4-0125-preview", #"claude-3-opus-20240229", #"gpt-4-1106-preview",
|
29 |
+
max_tokens=4096,
|
30 |
+
temperature=0.3,
|
31 |
+
# stop=['<wait to execute>']
|
32 |
+
)
|
33 |
+
success = True
|
34 |
+
break
|
35 |
+
except Exception as e:
|
36 |
+
print(f"Error during API call: {e}")
|
37 |
+
time.sleep(15)
|
38 |
+
try_times += 1
|
39 |
+
success = False
|
40 |
+
if success:
|
41 |
+
cleaned_string = chat_completion.choices[0].message.content.strip()
|
42 |
+
return cleaned_string
|
43 |
+
else:
|
44 |
+
return None
|
45 |
+
|
46 |
+
|
47 |
+
def encode_image(image_path):
|
48 |
+
with open(image_path, "rb") as image_file:
|
49 |
+
return base64.b64encode(image_file.read()).decode("utf-8")
|
50 |
+
|
51 |
+
|
52 |
+
def read_json_and_extract_content(filepath):
|
53 |
+
"""
|
54 |
+
Reads a JSON file and extracts sentences and images.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
filepath (str): The path to the JSON file.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
dict: A dictionary with two keys 'sentences' and 'images', containing the respective content.
|
61 |
+
"""
|
62 |
+
with open(filepath, 'r') as file:
|
63 |
+
data = json.load(file)
|
64 |
+
|
65 |
+
all_content = []
|
66 |
+
for line in data:
|
67 |
+
extracted_content = {
|
68 |
+
"sentences": [],
|
69 |
+
"images": []
|
70 |
+
}
|
71 |
+
# Matching sentences to their corresponding images using their indices
|
72 |
+
for ix in line['sentence_ixs']:
|
73 |
+
if ix == 0:
|
74 |
+
continue
|
75 |
+
extracted_content['sentences'].append(line['sentences'][ix].replace('<|beginofimage|>', ''))
|
76 |
+
extracted_content['images'].append(line['images'][ix])
|
77 |
+
all_content.append(extracted_content)
|
78 |
+
|
79 |
+
return all_content
|
80 |
+
|
81 |
+
|
82 |
+
def read_seed_content_from_folders(base_path):
|
83 |
+
"""
|
84 |
+
Reads sentences from text.txt and image paths from subfolders named val_x.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
base_path (str): Path to the main folder containing subfolders val_0 to val_179.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
list of dict: Each dictionary contains 'sentences' and 'images' from each subfolder.
|
91 |
+
"""
|
92 |
+
contents = []
|
93 |
+
|
94 |
+
# Iterate over each possible subfolder val_0 to val_179
|
95 |
+
for i in range(180): # 0 to 179 inclusive
|
96 |
+
folder_name = f"val_{i}"
|
97 |
+
folder_path = os.path.join(base_path, folder_name)
|
98 |
+
|
99 |
+
if os.path.exists(folder_path):
|
100 |
+
content_dict = {
|
101 |
+
"sentences": [],
|
102 |
+
"images": []
|
103 |
+
}
|
104 |
+
|
105 |
+
# Read sentences from text.txt
|
106 |
+
text_file_path = os.path.join(folder_path, 'text.txt')
|
107 |
+
if os.path.isfile(text_file_path):
|
108 |
+
with open(text_file_path, 'r') as file:
|
109 |
+
content_dict['sentences'] = file.read().splitlines()[:6]
|
110 |
+
content_dict['sentences'] = [s.replace('[INST]', '') for s in content_dict['sentences'] ]
|
111 |
+
|
112 |
+
# Collect paths for the images ori_01 to ori_06
|
113 |
+
for j in range(1, 7): # 1 to 6 inclusive
|
114 |
+
image_name = f"ori_0{j}.jpg" # Assuming the images are in .jpg format
|
115 |
+
image_path = os.path.join(folder_path, image_name)
|
116 |
+
if os.path.isfile(image_path):
|
117 |
+
content_dict['images'].append(image_path)
|
118 |
+
|
119 |
+
# Add the content dictionary to the list if it contains any images or sentences
|
120 |
+
if content_dict['sentences'] or content_dict['images']:
|
121 |
+
contents.append(content_dict)
|
122 |
+
|
123 |
+
return contents
|
124 |
+
|
125 |
+
|
126 |
+
def evaluate_models(assistant_a, instruction):
|
127 |
+
print(assistant_a, instruction)
|
128 |
+
# Encode all images to base64
|
129 |
+
images_a_base64 = [encode_image(img_path) for img_path in assistant_a['images'][:5]]
|
130 |
+
|
131 |
+
# Extract the stories from both assistants
|
132 |
+
story_a = assistant_a['sentences']
|
133 |
+
|
134 |
+
messages = []
|
135 |
+
# A
|
136 |
+
messages.append(
|
137 |
+
{
|
138 |
+
"role": "user",
|
139 |
+
"content": [
|
140 |
+
{
|
141 |
+
"type": "text",
|
142 |
+
"text": "Story text from Assistant A: {}\n".format(story_a[:5])
|
143 |
+
}
|
144 |
+
]
|
145 |
+
}
|
146 |
+
)
|
147 |
+
messages.append(
|
148 |
+
{
|
149 |
+
"role": "user",
|
150 |
+
"content": [
|
151 |
+
{
|
152 |
+
"type": "text",
|
153 |
+
"text": "Images are encoded in base64.\n"
|
154 |
+
}
|
155 |
+
]
|
156 |
+
}
|
157 |
+
)
|
158 |
+
for img_a in images_a_base64:
|
159 |
+
messages.append({
|
160 |
+
"role": "user",
|
161 |
+
"content": [
|
162 |
+
{
|
163 |
+
"type": "image_url",
|
164 |
+
"image_url": {"url": f"data:image/jpeg;base64,{img_a}"}
|
165 |
+
}
|
166 |
+
]
|
167 |
+
})
|
168 |
+
|
169 |
+
# INST
|
170 |
+
messages.append(
|
171 |
+
{
|
172 |
+
"role": "user",
|
173 |
+
"content": [
|
174 |
+
{
|
175 |
+
"type": "text",
|
176 |
+
"text": instruction
|
177 |
+
}
|
178 |
+
]
|
179 |
+
}
|
180 |
+
)
|
181 |
+
# Combine stories and encoded images into the evaluation instruction
|
182 |
+
result = api_call(messages)
|
183 |
+
print(result)
|
184 |
+
return result
|
185 |
+
|
186 |
+
def find_number_in_string(input_string):
|
187 |
+
# Regular expression to find [[number]]
|
188 |
+
pattern = r'\[\[(\d+)\]\]'
|
189 |
+
match = re.search(pattern, input_string)
|
190 |
+
|
191 |
+
if match:
|
192 |
+
return int(match.group(1)) # Return the number as an integer
|
193 |
+
else:
|
194 |
+
return None # No match found
|
195 |
+
|
196 |
+
|
197 |
+
def main():
|
198 |
+
# read mm json
|
199 |
+
# mm_contents = read_json_and_extract_content('/group/40034/shuaisyang/seed_project/StorySalon/llm_eval/mm_eval.json')
|
200 |
+
seed_contents = read_seed_content_from_folders('/group/40034/shuaisyang/seed_project/StorySalon/llm_eval/gen_george')
|
201 |
+
# assert len(mm_contents) == len(seed_contents)
|
202 |
+
# mm_win = 0
|
203 |
+
seed_win = 0
|
204 |
+
# tie = 0
|
205 |
+
|
206 |
+
error = []
|
207 |
+
metrics = ['style', 'engaging', 'coherence']
|
208 |
+
for idx, ins in enumerate((style_instruction, engage_instruction, coherence_instruction)):
|
209 |
+
total_score = 0
|
210 |
+
scores = ''
|
211 |
+
for i in range(len(seed_contents)):
|
212 |
+
seed = seed_contents[i]
|
213 |
+
judgment = evaluate_models(seed, ins)
|
214 |
+
number_found = find_number_in_string(judgment)
|
215 |
+
scores += str(number_found) + '\n'
|
216 |
+
total_score += number_found
|
217 |
+
|
218 |
+
with open('result_{}.txt'.format(metrics[idx]), 'w') as f:
|
219 |
+
f.write("total:{}\navg:{}\nscores:{}".format(total_score, total_score/len(seed_contents), scores))
|
220 |
+
|
221 |
+
|
222 |
+
main()
|
src/inference/gen_george.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
|
2 |
+
import hydra
|
3 |
+
from omegaconf import OmegaConf
|
4 |
+
import torch
|
5 |
+
import os
|
6 |
+
import re
|
7 |
+
import pyrootutils
|
8 |
+
from PIL import Image, ImageDraw, ImageFont
|
9 |
+
import json
|
10 |
+
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, EulerDiscreteScheduler
|
11 |
+
|
12 |
+
pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
|
13 |
+
|
14 |
+
BOI_TOKEN = '<img>'
|
15 |
+
EOI_TOKEN = '</img>'
|
16 |
+
IMG_TOKEN = '<img_{:05d}>'
|
17 |
+
|
18 |
+
device = 'cuda:0'
|
19 |
+
dtype = torch.float16
|
20 |
+
dtype_str = 'fp16'
|
21 |
+
num_img_in_tokens = 64
|
22 |
+
num_img_out_tokens = 64
|
23 |
+
instruction_prompt = '{instruction}'
|
24 |
+
|
25 |
+
tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer.yaml'
|
26 |
+
image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml'
|
27 |
+
visual_encoder_cfg_path = 'configs/visual_tokenizer/qwen_vitg_448.yaml'
|
28 |
+
|
29 |
+
llm_cfg_path = 'configs/clm_models/llama2chat7b_lora.yaml'
|
30 |
+
agent_cfg_path = 'configs/clm_models/agent_7b_sft.yaml'
|
31 |
+
|
32 |
+
adapter_cfg_path = 'configs/detokenizer/detokenizer_sdxl_qwen_vit_adapted.yaml'
|
33 |
+
discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml'
|
34 |
+
|
35 |
+
diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0'
|
36 |
+
|
37 |
+
save_dir = "output"
|
38 |
+
|
39 |
+
tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
|
40 |
+
tokenizer = hydra.utils.instantiate(tokenizer_cfg)
|
41 |
+
|
42 |
+
image_transform_cfg = OmegaConf.load(image_transform_cfg_path)
|
43 |
+
image_transform = hydra.utils.instantiate(image_transform_cfg)
|
44 |
+
|
45 |
+
visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path)
|
46 |
+
visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
|
47 |
+
visual_encoder.eval().to(device, dtype=dtype)
|
48 |
+
print('Init visual encoder done')
|
49 |
+
|
50 |
+
llm_cfg = OmegaConf.load(llm_cfg_path)
|
51 |
+
llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype_str)
|
52 |
+
print('Init llm done.')
|
53 |
+
|
54 |
+
agent_model_cfg = OmegaConf.load(agent_cfg_path)
|
55 |
+
agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm)
|
56 |
+
|
57 |
+
agent_model.eval().to(device, dtype=dtype)
|
58 |
+
print('Init agent model Done')
|
59 |
+
|
60 |
+
noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler")
|
61 |
+
print('init vae')
|
62 |
+
vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device, dtype=dtype)
|
63 |
+
print('init unet')
|
64 |
+
unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device, dtype=dtype)
|
65 |
+
|
66 |
+
adapter_cfg = OmegaConf.load(adapter_cfg_path)
|
67 |
+
adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device, dtype=dtype).eval()
|
68 |
+
print('Init adapter done')
|
69 |
+
|
70 |
+
discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path)
|
71 |
+
discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device).eval()
|
72 |
+
print('Init discrete model done')
|
73 |
+
|
74 |
+
adapter.init_pipe(vae=vae,
|
75 |
+
scheduler=noise_scheduler,
|
76 |
+
visual_encoder=visual_encoder,
|
77 |
+
image_transform=image_transform,
|
78 |
+
discrete_model=discrete_model,
|
79 |
+
dtype=dtype,
|
80 |
+
device=device)
|
81 |
+
|
82 |
+
print('Init adapter pipe done')
|
83 |
+
boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
|
84 |
+
eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
|
85 |
+
|
86 |
+
|
87 |
+
def read_jsonl_to_dict(filename):
|
88 |
+
data = []
|
89 |
+
with open(filename, 'r') as file:
|
90 |
+
for line in file:
|
91 |
+
# Each line is a valid JSON object
|
92 |
+
json_object = json.loads(line)
|
93 |
+
data.append(json_object)
|
94 |
+
return data
|
95 |
+
|
96 |
+
|
97 |
+
# data
|
98 |
+
filename = 'data/json/val.jsonl'
|
99 |
+
image_root = 'data/image/george_full'
|
100 |
+
data = read_jsonl_to_dict(filename)
|
101 |
+
image_paths = [
|
102 |
+
os.path.join(image_root, d['images'][0]) for d in data
|
103 |
+
]
|
104 |
+
questions = [
|
105 |
+
d['captions'][0] for d in data
|
106 |
+
]
|
107 |
+
|
108 |
+
|
109 |
+
# texts = [
|
110 |
+
# d['captions'][1:] for d in data
|
111 |
+
# ]
|
112 |
+
|
113 |
+
|
114 |
+
def add_subtitle(original_image, text):
|
115 |
+
# Calculate the size of the new image
|
116 |
+
text_height = 80 # Height of the black bar for the text
|
117 |
+
new_image_size = (original_image.width, original_image.height + text_height)
|
118 |
+
|
119 |
+
# Create a new image with a black background
|
120 |
+
new_image = Image.new("RGB", new_image_size, "black")
|
121 |
+
# Paste the original image onto the new image
|
122 |
+
new_image.paste(original_image, (0, 0))
|
123 |
+
|
124 |
+
# Prepare the new image for drawing
|
125 |
+
draw = ImageDraw.Draw(new_image)
|
126 |
+
|
127 |
+
# Specify the font size and font path
|
128 |
+
font_size = 14 # Adjust font size as needed
|
129 |
+
# font = ImageFont.truetype(font_path, font_size)
|
130 |
+
|
131 |
+
# Manually split the text into two lines
|
132 |
+
line1, line2 = text[:len(text) // 2], text[len(text) // 2:]
|
133 |
+
|
134 |
+
# Update the position for the first line of text to ensure both lines are centered vertically
|
135 |
+
text_position_line1 = (10, original_image.height + (text_height - font_size) // 2)
|
136 |
+
|
137 |
+
# Define the text color
|
138 |
+
text_color = "white"
|
139 |
+
|
140 |
+
# Add the first line of text to the new image
|
141 |
+
draw.text(text_position_line1, line1, fill=text_color)
|
142 |
+
|
143 |
+
# Adjust the position for the second line of text, based on the height of the first line
|
144 |
+
text_position_line2 = (10, text_position_line1[1] + font_size)
|
145 |
+
|
146 |
+
# Add the second line of text to the new image
|
147 |
+
draw.text(text_position_line2, line2, fill=text_color)
|
148 |
+
|
149 |
+
return new_image
|
150 |
+
|
151 |
+
|
152 |
+
for j in range(len(image_paths)):
|
153 |
+
image_path = image_paths[j]
|
154 |
+
question = questions[j]
|
155 |
+
image = Image.open(image_path).convert('RGB')
|
156 |
+
|
157 |
+
save_folder = '{}/val_{}'.format(save_dir, j)
|
158 |
+
|
159 |
+
os.makedirs(save_folder, exist_ok=True)
|
160 |
+
|
161 |
+
init_image = add_subtitle(image, question)
|
162 |
+
save_path = os.path.join(save_folder, '000start_image.jpg')
|
163 |
+
init_image.save(save_path)
|
164 |
+
|
165 |
+
agent_model.llm.base_model.model.use_kv_cache_head = False
|
166 |
+
image_tensor = image_transform(image).unsqueeze(0).to(device, dtype=dtype)
|
167 |
+
|
168 |
+
image_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
|
169 |
+
|
170 |
+
prompt = instruction_prompt.format_map({'instruction': question + image_tokens})
|
171 |
+
print(prompt)
|
172 |
+
print('*' * 20)
|
173 |
+
|
174 |
+
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
175 |
+
input_ids = [tokenizer.bos_token_id] + input_ids
|
176 |
+
|
177 |
+
boi_idx = input_ids.index(boi_token_id)
|
178 |
+
eoi_idx = input_ids.index(eoi_token_id)
|
179 |
+
|
180 |
+
input_ids = torch.tensor(input_ids).to(device, dtype=torch.long).unsqueeze(0)
|
181 |
+
|
182 |
+
ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
183 |
+
|
184 |
+
ids_cmp_mask[0, boi_idx + 1:eoi_idx] = True
|
185 |
+
embeds_cmp_mask = torch.tensor([True]).to(device, dtype=torch.bool)
|
186 |
+
|
187 |
+
with torch.no_grad():
|
188 |
+
image_embeds = visual_encoder(image_tensor)
|
189 |
+
output = agent_model.generate(tokenizer=tokenizer,
|
190 |
+
input_ids=input_ids,
|
191 |
+
image_embeds=image_embeds,
|
192 |
+
embeds_cmp_mask=embeds_cmp_mask,
|
193 |
+
ids_cmp_mask=ids_cmp_mask,
|
194 |
+
max_new_tokens=500,
|
195 |
+
num_img_gen_tokens=num_img_out_tokens)
|
196 |
+
text = re.sub(r'\s*<[^>]*>\s*', ' ', output['text']).strip()
|
197 |
+
|
198 |
+
with open("{}/text.txt".format(save_folder), 'a+') as text_file:
|
199 |
+
text_file.write(text + '\n')
|
200 |
+
with open("{}/token.txt".format(save_folder), 'a+') as token_file:
|
201 |
+
token_file.write("context token: {}\n".format(input_ids.shape))
|
202 |
+
print(output['text'])
|
203 |
+
print('*' * 20)
|
204 |
+
|
205 |
+
story_len = 25
|
206 |
+
window_size = 8
|
207 |
+
text_id = 1
|
208 |
+
while output['has_img_output'] and image_embeds.shape[0] < story_len:
|
209 |
+
image_embeds_gen = output['img_gen_feat']
|
210 |
+
images_gen = adapter.generate(image_embeds=output['img_gen_feat'], num_inference_steps=50)
|
211 |
+
|
212 |
+
name = '{:02d}.jpg'.format(text_id)
|
213 |
+
save_path = os.path.join(save_folder, name)
|
214 |
+
|
215 |
+
# Open the generated image
|
216 |
+
original_image = images_gen[0]
|
217 |
+
ori_path = os.path.join(save_folder, 'ori_{:02d}.jpg'.format(text_id))
|
218 |
+
original_image.save(ori_path)
|
219 |
+
|
220 |
+
new_image = add_subtitle(original_image, text)
|
221 |
+
# Save the modified image
|
222 |
+
new_image.save(save_path)
|
223 |
+
|
224 |
+
image_embeds = torch.cat((image_embeds, image_embeds_gen), dim=0)
|
225 |
+
|
226 |
+
# image_embeds = torch.cat((image_embeds, image_embeds_gen), dim=0)
|
227 |
+
|
228 |
+
if text_id >= story_len - 1:
|
229 |
+
break
|
230 |
+
|
231 |
+
prompt = prompt + text + image_tokens
|
232 |
+
text_id += 1
|
233 |
+
|
234 |
+
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
235 |
+
while image_embeds.shape[0] > window_size:
|
236 |
+
eoi_prompt_idx = prompt.index(EOI_TOKEN)
|
237 |
+
prompt = prompt[eoi_prompt_idx + len(EOI_TOKEN) + len('[INST]'):]
|
238 |
+
image_embeds = image_embeds[1:]
|
239 |
+
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
240 |
+
|
241 |
+
print(prompt)
|
242 |
+
print('*' * 20)
|
243 |
+
|
244 |
+
input_ids = [tokenizer.bos_token_id] + input_ids
|
245 |
+
|
246 |
+
boi_idx = torch.where(torch.tensor(input_ids) == boi_token_id)[0].tolist()
|
247 |
+
eoi_idx = torch.where(torch.tensor(input_ids) == eoi_token_id)[0].tolist()
|
248 |
+
|
249 |
+
input_ids = torch.tensor(input_ids).to(device, dtype=torch.long).unsqueeze(0)
|
250 |
+
|
251 |
+
ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
252 |
+
|
253 |
+
for i in range(image_embeds.shape[0]):
|
254 |
+
ids_cmp_mask[0, boi_idx[i] + 1:eoi_idx[i]] = True
|
255 |
+
embeds_cmp_mask = torch.tensor([True] * image_embeds.shape[0]).to(device, dtype=torch.bool)
|
256 |
+
|
257 |
+
output = agent_model.generate(tokenizer=tokenizer,
|
258 |
+
input_ids=input_ids,
|
259 |
+
image_embeds=image_embeds,
|
260 |
+
embeds_cmp_mask=embeds_cmp_mask,
|
261 |
+
ids_cmp_mask=ids_cmp_mask,
|
262 |
+
max_new_tokens=500,
|
263 |
+
num_img_gen_tokens=num_img_out_tokens)
|
264 |
+
text = re.sub(r'\s*<[^>]*>\s*', ' ', output['text']).strip()
|
265 |
+
print(output['text'])
|
266 |
+
print('*' * 20)
|
267 |
+
with open("{}/text.txt".format(save_folder), 'a+') as text_file:
|
268 |
+
text_file.write(text + '\n')
|
269 |
+
with open("{}/token.txt".format(save_folder), 'a+') as token_file:
|
270 |
+
token_file.write("context token: {}\n".format(input_ids.shape))
|
src/inference/vis_george_sink.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hydra
|
2 |
+
from omegaconf import OmegaConf
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import pyrootutils
|
7 |
+
from PIL import Image, ImageDraw, ImageFont
|
8 |
+
import json
|
9 |
+
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, EulerDiscreteScheduler
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import numpy as np
|
12 |
+
from collections import Counter
|
13 |
+
import time
|
14 |
+
|
15 |
+
pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
|
16 |
+
|
17 |
+
BOI_TOKEN = '<img>'
|
18 |
+
EOI_TOKEN = '</img>'
|
19 |
+
IMG_TOKEN = '<img_{:05d}>'
|
20 |
+
|
21 |
+
device = 'cuda:0'
|
22 |
+
dtype = torch.float16
|
23 |
+
dtype_str = 'fp16'
|
24 |
+
num_img_in_tokens = 64
|
25 |
+
num_img_out_tokens = 64
|
26 |
+
instruction_prompt = '{instruction}'
|
27 |
+
|
28 |
+
tokenizer_cfg_path = 'configs/tokenizer/clm_llama_tokenizer.yaml'
|
29 |
+
image_transform_cfg_path = 'configs/processer/qwen_448_transform.yaml'
|
30 |
+
visual_encoder_cfg_path = 'configs/visual_tokenizer/qwen_vitg_448.yaml'
|
31 |
+
|
32 |
+
llm_cfg_path = 'configs/clm_models/llama2chat7b_lora.yaml'
|
33 |
+
agent_cfg_path = 'configs/clm_models/agent_7b_sft.yaml'
|
34 |
+
|
35 |
+
adapter_cfg_path = 'configs/detokenizer/detokenizer_sdxl_qwen_vit_adapted.yaml'
|
36 |
+
discrete_model_cfg_path = 'configs/discrete_model/discrete_identity.yaml'
|
37 |
+
|
38 |
+
diffusion_model_path = 'pretrained/stable-diffusion-xl-base-1.0'
|
39 |
+
|
40 |
+
save_dir = "output"
|
41 |
+
|
42 |
+
cache_mode = 'img_head_tail'
|
43 |
+
# init
|
44 |
+
tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
|
45 |
+
tokenizer = hydra.utils.instantiate(tokenizer_cfg)
|
46 |
+
|
47 |
+
image_transform_cfg = OmegaConf.load(image_transform_cfg_path)
|
48 |
+
image_transform = hydra.utils.instantiate(image_transform_cfg)
|
49 |
+
|
50 |
+
visual_encoder_cfg = OmegaConf.load(visual_encoder_cfg_path)
|
51 |
+
visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
|
52 |
+
visual_encoder.eval().to(device, dtype=dtype)
|
53 |
+
print('Init visual encoder done')
|
54 |
+
|
55 |
+
llm_cfg = OmegaConf.load(llm_cfg_path)
|
56 |
+
llm = hydra.utils.instantiate(llm_cfg, torch_dtype=dtype_str)
|
57 |
+
print('Init llm done.')
|
58 |
+
|
59 |
+
agent_model_cfg = OmegaConf.load(agent_cfg_path)
|
60 |
+
agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm)
|
61 |
+
|
62 |
+
agent_model.eval().to(device, dtype=dtype)
|
63 |
+
print('Init agent model Done')
|
64 |
+
|
65 |
+
noise_scheduler = EulerDiscreteScheduler.from_pretrained(diffusion_model_path, subfolder="scheduler")
|
66 |
+
print('init vae')
|
67 |
+
vae = AutoencoderKL.from_pretrained(diffusion_model_path, subfolder="vae").to(device, dtype=dtype)
|
68 |
+
print('init unet')
|
69 |
+
unet = UNet2DConditionModel.from_pretrained(diffusion_model_path, subfolder="unet").to(device, dtype=dtype)
|
70 |
+
|
71 |
+
adapter_cfg = OmegaConf.load(adapter_cfg_path)
|
72 |
+
adapter = hydra.utils.instantiate(adapter_cfg, unet=unet).to(device, dtype=dtype).eval()
|
73 |
+
print('Init adapter done')
|
74 |
+
|
75 |
+
discrete_model_cfg = OmegaConf.load(discrete_model_cfg_path)
|
76 |
+
discrete_model = hydra.utils.instantiate(discrete_model_cfg).to(device).eval()
|
77 |
+
print('Init discrete model done')
|
78 |
+
|
79 |
+
adapter.init_pipe(vae=vae,
|
80 |
+
scheduler=noise_scheduler,
|
81 |
+
visual_encoder=visual_encoder,
|
82 |
+
image_transform=image_transform,
|
83 |
+
discrete_model=discrete_model,
|
84 |
+
dtype=dtype,
|
85 |
+
device=device)
|
86 |
+
|
87 |
+
print('Init adapter pipe done')
|
88 |
+
boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
|
89 |
+
eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
|
90 |
+
|
91 |
+
|
92 |
+
def read_jsonl_to_dict(filename):
|
93 |
+
data = []
|
94 |
+
with open(filename, 'r') as file:
|
95 |
+
for line in file:
|
96 |
+
# Each line is a valid JSON object
|
97 |
+
json_object = json.loads(line)
|
98 |
+
data.append(json_object)
|
99 |
+
return data
|
100 |
+
|
101 |
+
|
102 |
+
# data
|
103 |
+
filename = 'data/json/val.jsonl'
|
104 |
+
image_root = 'data/image/george_full'
|
105 |
+
data = read_jsonl_to_dict(filename)
|
106 |
+
image_paths = [
|
107 |
+
os.path.join(image_root, d['images'][0]) for d in data
|
108 |
+
]
|
109 |
+
starting_texts = [
|
110 |
+
d['captions'][0] for d in data
|
111 |
+
]
|
112 |
+
|
113 |
+
texts = [
|
114 |
+
d['captions'][1:] for d in data
|
115 |
+
]
|
116 |
+
|
117 |
+
def add_subtitle(original_image, text):
|
118 |
+
# Calculate the size of the new image
|
119 |
+
text_height = 80 # Height of the black bar for the text
|
120 |
+
new_image_size = (original_image.width, original_image.height + text_height)
|
121 |
+
|
122 |
+
# Create a new image with a black background
|
123 |
+
new_image = Image.new("RGB", new_image_size, "black")
|
124 |
+
# Paste the original image onto the new image
|
125 |
+
new_image.paste(original_image, (0, 0))
|
126 |
+
|
127 |
+
# Prepare the new image for drawing
|
128 |
+
draw = ImageDraw.Draw(new_image)
|
129 |
+
|
130 |
+
# Specify the font size and font path
|
131 |
+
font_size = 14 # Adjust font size as needed
|
132 |
+
# font = ImageFont.truetype(font_path, font_size)
|
133 |
+
|
134 |
+
# Manually split the text into two lines
|
135 |
+
line1, line2 = text[:len(text) // 2], text[len(text) // 2:]
|
136 |
+
|
137 |
+
# Update the position for the first line of text to ensure both lines are centered vertically
|
138 |
+
text_position_line1 = (10, original_image.height + (text_height - font_size) // 2)
|
139 |
+
|
140 |
+
# Define the text color
|
141 |
+
text_color = "white"
|
142 |
+
|
143 |
+
# Add the first line of text to the new image
|
144 |
+
draw.text(text_position_line1, line1, fill=text_color)
|
145 |
+
|
146 |
+
# Adjust the position for the second line of text, based on the height of the first line
|
147 |
+
text_position_line2 = (10, text_position_line1[1] + font_size)
|
148 |
+
|
149 |
+
# Add the second line of text to the new image
|
150 |
+
draw.text(text_position_line2, line2, fill=text_color)
|
151 |
+
|
152 |
+
return new_image
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
for j in range(len(image_paths)):
|
157 |
+
image_path = image_paths[j]
|
158 |
+
starting_text = starting_texts[j]
|
159 |
+
text_seq = texts[j]
|
160 |
+
image = Image.open(image_path).convert('RGB')
|
161 |
+
|
162 |
+
save_folder = '{}/val_{}'.format(save_dir, j)
|
163 |
+
|
164 |
+
os.makedirs(save_folder, exist_ok=True)
|
165 |
+
|
166 |
+
init_image = add_subtitle(image, starting_text)
|
167 |
+
save_path = os.path.join(save_folder, '000start_image.jpg')
|
168 |
+
init_image.save(save_path)
|
169 |
+
|
170 |
+
sink_kv_cache = []
|
171 |
+
agent_model.llm.base_model.model.kv_cache_head = None
|
172 |
+
agent_model.llm.base_model.model.past_key_values = None
|
173 |
+
agent_model.llm.base_model.model.use_kv_cache_head = False
|
174 |
+
|
175 |
+
image_tensor = image_transform(image).unsqueeze(0).to(device, dtype=dtype)
|
176 |
+
|
177 |
+
image_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN
|
178 |
+
|
179 |
+
text = text_seq[0]
|
180 |
+
prompt = instruction_prompt.format_map({'instruction': starting_text + image_tokens}) + text
|
181 |
+
print(prompt)
|
182 |
+
print('*' * 20)
|
183 |
+
|
184 |
+
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
185 |
+
input_ids = [tokenizer.bos_token_id] + input_ids
|
186 |
+
|
187 |
+
boi_idx = input_ids.index(boi_token_id)
|
188 |
+
eoi_idx = input_ids.index(eoi_token_id)
|
189 |
+
|
190 |
+
input_ids = torch.tensor(input_ids).to(device, dtype=torch.long).unsqueeze(0)
|
191 |
+
|
192 |
+
ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
193 |
+
|
194 |
+
ids_cmp_mask[0, boi_idx + 1:eoi_idx] = True
|
195 |
+
embeds_cmp_mask = torch.tensor([True]).to(device, dtype=torch.bool)
|
196 |
+
|
197 |
+
with torch.no_grad():
|
198 |
+
image_embeds = visual_encoder(image_tensor)
|
199 |
+
left = 0
|
200 |
+
right = input_ids.shape[1]
|
201 |
+
output = agent_model.generate(tokenizer=tokenizer,
|
202 |
+
input_ids=input_ids,
|
203 |
+
image_embeds=image_embeds,
|
204 |
+
embeds_cmp_mask=embeds_cmp_mask,
|
205 |
+
ids_cmp_mask=ids_cmp_mask,
|
206 |
+
max_new_tokens=500,
|
207 |
+
num_img_gen_tokens=num_img_out_tokens,
|
208 |
+
)
|
209 |
+
with open("{}/text.txt".format(save_folder), 'a+') as text_file:
|
210 |
+
text_file.write(text + '\n')
|
211 |
+
with open("{}/token.txt".format(save_folder), 'a+') as token_file:
|
212 |
+
token_file.write("context token: {} boi_idx: {}\n".format(input_ids.shape, boi_idx))
|
213 |
+
|
214 |
+
story_len = 25
|
215 |
+
window_size = 8
|
216 |
+
text_id = 1
|
217 |
+
while output['has_img_output'] and image_embeds.shape[0] < story_len:
|
218 |
+
image_embeds_gen = output['img_gen_feat']
|
219 |
+
images_gen = adapter.generate(image_embeds=output['img_gen_feat'], num_inference_steps=50)
|
220 |
+
|
221 |
+
name = '{:02d}.jpg'.format(text_id)
|
222 |
+
save_path = os.path.join(save_folder, name)
|
223 |
+
|
224 |
+
# Open the generated image
|
225 |
+
original_image = images_gen[0]
|
226 |
+
ori_path = os.path.join(save_folder, 'ori_{:02d}.jpg'.format(text_id))
|
227 |
+
original_image.save(ori_path)
|
228 |
+
|
229 |
+
new_image = add_subtitle(original_image, text)
|
230 |
+
# Save the modified image
|
231 |
+
new_image.save(save_path)
|
232 |
+
|
233 |
+
image_embeds = torch.cat((image_embeds, image_embeds_gen), dim=0)
|
234 |
+
|
235 |
+
# next gen
|
236 |
+
text = text_seq[text_id]
|
237 |
+
text_id += 1
|
238 |
+
|
239 |
+
# image_embeds = torch.cat((image_embeds, image_embeds_gen), dim=0)
|
240 |
+
if text_id >= story_len - 1:
|
241 |
+
break
|
242 |
+
|
243 |
+
past_key_values = [[kv[:, :, :input_ids.shape[1], :] for kv in l] for l in output['past_key_values']]
|
244 |
+
agent_model.llm.base_model.model.kv_cache_head = input_ids.shape[1]
|
245 |
+
|
246 |
+
prompt = prompt + image_tokens + text
|
247 |
+
next_input_ids = tokenizer.encode(image_tokens + text, add_special_tokens=False)
|
248 |
+
next_input_ids = torch.tensor(next_input_ids).to(device, dtype=torch.long).unsqueeze(0)
|
249 |
+
input_ids = torch.cat((input_ids, next_input_ids), dim=1)
|
250 |
+
left = right
|
251 |
+
right = input_ids.shape[1]
|
252 |
+
|
253 |
+
|
254 |
+
while image_embeds.shape[0] > window_size:
|
255 |
+
|
256 |
+
eoi_prompt_idx = prompt.index(EOI_TOKEN)
|
257 |
+
prompt = prompt[eoi_prompt_idx + len(EOI_TOKEN) :]
|
258 |
+
|
259 |
+
boi_idx = torch.where(input_ids == boi_token_id)[1].tolist()
|
260 |
+
eoi_idx = torch.where(input_ids == eoi_token_id)[1].tolist()
|
261 |
+
|
262 |
+
image_embeds = image_embeds[1:]
|
263 |
+
input_ids = input_ids[:, eoi_idx[0]+1:]
|
264 |
+
|
265 |
+
# slice kv cache
|
266 |
+
if cache_mode == 'img_head_tail':
|
267 |
+
if len(sink_kv_cache) == 0:
|
268 |
+
sink_kv_cache = [
|
269 |
+
[
|
270 |
+
kv[:, :, :4, :] for kv in l
|
271 |
+
] for l in past_key_values
|
272 |
+
]
|
273 |
+
sink_kv_cache = [
|
274 |
+
[
|
275 |
+
torch.cat(
|
276 |
+
(sink_kv_cache[l_idx][kv_idx],
|
277 |
+
kv[:, :, boi_idx[0] - 4:boi_idx[0] + 8, :],
|
278 |
+
kv[:, :, eoi_idx[0] - 8:eoi_idx[0] + 4, :]),
|
279 |
+
dim=2
|
280 |
+
) for kv_idx, kv in enumerate(l)
|
281 |
+
] for l_idx, l in enumerate(past_key_values)
|
282 |
+
]
|
283 |
+
past_key_values = [
|
284 |
+
[
|
285 |
+
torch.cat(
|
286 |
+
(sink_kv_cache[l_idx][kv_idx],
|
287 |
+
kv[:, :, eoi_idx[0] + sink_kv_cache[0][0].shape[2] + 1:, :]),
|
288 |
+
dim=2
|
289 |
+
) for kv_idx, kv in enumerate(l)
|
290 |
+
] for l_idx, l in enumerate(past_key_values)
|
291 |
+
]
|
292 |
+
# slice Left right
|
293 |
+
agent_model.llm.base_model.model.kv_cache_head -= eoi_idx[0] + 1
|
294 |
+
left -= eoi_idx[0] + 1
|
295 |
+
right -= eoi_idx[0] + 1
|
296 |
+
|
297 |
+
print("prompt: {}".format(prompt))
|
298 |
+
print('*' * 20)
|
299 |
+
|
300 |
+
boi_idx = torch.where(input_ids == boi_token_id)[1].tolist()
|
301 |
+
eoi_idx = torch.where(input_ids == eoi_token_id)[1].tolist()
|
302 |
+
|
303 |
+
ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
|
304 |
+
|
305 |
+
for i in range(image_embeds.shape[0]):
|
306 |
+
ids_cmp_mask[0, boi_idx[i] + 1:eoi_idx[i]] = True
|
307 |
+
embeds_cmp_mask = torch.tensor([True] * image_embeds.shape[0]).to(device, dtype=torch.bool)
|
308 |
+
|
309 |
+
output = agent_model.generate(tokenizer=tokenizer,
|
310 |
+
input_ids=input_ids,
|
311 |
+
image_embeds=image_embeds,
|
312 |
+
embeds_cmp_mask=embeds_cmp_mask,
|
313 |
+
ids_cmp_mask=ids_cmp_mask,
|
314 |
+
max_new_tokens=500,
|
315 |
+
num_img_gen_tokens=num_img_out_tokens,
|
316 |
+
past_key_values=None)
|
317 |
+
with open("{}/text.txt".format(save_folder), 'a+') as text_file:
|
318 |
+
text_file.write(text + '\n')
|
319 |
+
with open("{}/token.txt".format(save_folder), 'a+') as token_file:
|
320 |
+
token_file.write("context token: {} boi_idx: {}\n".format(input_ids.shape, boi_idx))
|
src/models/__init__.py
ADDED
File without changes
|
src/models/discrete_models.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import pyrootutils
|
4 |
+
import torch.distributed as dist
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
|
8 |
+
from src.train.dist_utils import concat_all_gather
|
9 |
+
|
10 |
+
|
11 |
+
def cosine_loss(rec, target):
|
12 |
+
target = target / target.norm(dim=-1, keepdim=True)
|
13 |
+
rec = rec / rec.norm(dim=-1, keepdim=True)
|
14 |
+
rec_loss = (1 - (target * rec).sum(-1)).mean()
|
15 |
+
return rec_loss
|
16 |
+
|
17 |
+
|
18 |
+
def contrastive_loss(image_feats, text_feats, logit_scale):
|
19 |
+
image_feats = image_feats.unsqueeze(1).contiguous()
|
20 |
+
image_feats_all = concat_all_gather(image_feats) # [batch_size*num_gpu, num_query_tokens, embed_dim]
|
21 |
+
text_feats_all = concat_all_gather(text_feats) # [batch_size*num_gpu, embed_dim]
|
22 |
+
|
23 |
+
sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feats_all.unsqueeze(-1)).squeeze()
|
24 |
+
# [batch_size, batch_size*num_gpu, num_query_tokens]
|
25 |
+
|
26 |
+
# image-text similarity: aggregate across all query tokens
|
27 |
+
# sim_i2t, _ = sim_q2t.max(-1)
|
28 |
+
# sim_i2t = sim_q2t.mean(-1)
|
29 |
+
sim_i2t = sim_q2t
|
30 |
+
sim_i2t = sim_i2t / logit_scale
|
31 |
+
|
32 |
+
# text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]
|
33 |
+
sim_t2q = torch.matmul(text_feats.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)).squeeze()
|
34 |
+
|
35 |
+
# print(image_feats_all.shape, text_feat_all.shape, sim_q2t.shape, sim_t2q.shape)
|
36 |
+
# text-image similarity: aggregate across all query tokens
|
37 |
+
# sim_t2i, _ = sim_t2q.max(-1)
|
38 |
+
# sim_t2i = sim_t2q.mean(-1)
|
39 |
+
sim_t2i = sim_t2q
|
40 |
+
sim_t2i = sim_t2i / logit_scale # [batch_size, batch_size*num_gpu]
|
41 |
+
|
42 |
+
rank = dist.get_rank()
|
43 |
+
bs = image_feats.size(0)
|
44 |
+
targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(image_feats.device)
|
45 |
+
|
46 |
+
loss_itc = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) +
|
47 |
+
F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2
|
48 |
+
|
49 |
+
i2t_acc = (sim_i2t.argmax(-1) == targets).sum() / len(sim_i2t)
|
50 |
+
t2i_acc = (sim_t2i.argmax(-1) == targets).sum() / len(sim_t2i)
|
51 |
+
|
52 |
+
return loss_itc, i2t_acc, t2i_acc
|
53 |
+
|
54 |
+
|
55 |
+
class DiscreteModleOnlyDistill(nn.Module):
|
56 |
+
|
57 |
+
def __init__(self,
|
58 |
+
qformer,
|
59 |
+
quantizer,
|
60 |
+
distiller=None,
|
61 |
+
loss_type='cosine',
|
62 |
+
scale_commit_loss=1.0,
|
63 |
+
freeze_qformer=False) -> None:
|
64 |
+
super().__init__()
|
65 |
+
self.qformer = qformer
|
66 |
+
self.quantizer = quantizer
|
67 |
+
self.distiller = distiller
|
68 |
+
self.loss_type = loss_type
|
69 |
+
self.scale_commit_loss = scale_commit_loss
|
70 |
+
|
71 |
+
self.freeze_qformer = freeze_qformer
|
72 |
+
|
73 |
+
if freeze_qformer:
|
74 |
+
self.qformer.requires_grad_(False)
|
75 |
+
|
76 |
+
def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
|
77 |
+
if self.freeze_qformer:
|
78 |
+
with torch.no_grad():
|
79 |
+
qforemr_embeds = self.qformer(image_embeds=image_embeds)
|
80 |
+
else:
|
81 |
+
qforemr_embeds = self.qformer(image_embeds=image_embeds)
|
82 |
+
|
83 |
+
quantizer_output = self.quantizer(qforemr_embeds)
|
84 |
+
recon_embeds = self.distiller(quantizer_output['quant_embeds'])
|
85 |
+
|
86 |
+
if self.loss_type == 'cosine':
|
87 |
+
distill_loss = cosine_loss(recon_embeds, image_embeds)
|
88 |
+
else:
|
89 |
+
raise NotImplementedError
|
90 |
+
|
91 |
+
total_loss = distill_loss + self.scale_commit_loss * \
|
92 |
+
quantizer_output['commit_loss']
|
93 |
+
|
94 |
+
return {
|
95 |
+
'total_loss': total_loss,
|
96 |
+
'distill_loss': distill_loss,
|
97 |
+
'commit_loss': quantizer_output['commit_loss'],
|
98 |
+
'indices': quantizer_output['indices']
|
99 |
+
}
|
100 |
+
|
101 |
+
def encode_image_embeds(self, image_embeds):
|
102 |
+
qforemr_embeds = self.qformer(image_embeds=image_embeds)
|
103 |
+
quantizer_output = self.quantizer(qforemr_embeds)
|
104 |
+
|
105 |
+
output_embeds = quantizer_output['quant_embeds']
|
106 |
+
if self.distiller is not None:
|
107 |
+
output_embeds = self.distiller(output_embeds)
|
108 |
+
return output_embeds
|
109 |
+
|
110 |
+
@classmethod
|
111 |
+
def from_pretrained(cls, qformer, quantizer, distiller=None, pretrained_model_path=None, **kwargs):
|
112 |
+
model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, **kwargs)
|
113 |
+
if pretrained_model_path is not None:
|
114 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
115 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
116 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
117 |
+
return model
|
118 |
+
|
119 |
+
|
120 |
+
class DiscreteModleIdentity(nn.Module):
|
121 |
+
|
122 |
+
def __init__(self) -> None:
|
123 |
+
super().__init__()
|
124 |
+
self.model = nn.Identity()
|
125 |
+
|
126 |
+
def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
|
127 |
+
return
|
128 |
+
|
129 |
+
def encode_image_embeds(self, image_embeds):
|
130 |
+
return self.model(image_embeds)
|
131 |
+
|
132 |
+
|
133 |
+
class DiscreteModleStageOneContrastive(nn.Module):
|
134 |
+
|
135 |
+
def __init__(self, qformer, quantizer=None, distiller=None, projection_dim=1024,
|
136 |
+
image_cls_token_type='last') -> None:
|
137 |
+
super().__init__()
|
138 |
+
self.qformer = qformer
|
139 |
+
self.quantizer = quantizer
|
140 |
+
self.distiller = distiller
|
141 |
+
self.image_cls_token_type = image_cls_token_type
|
142 |
+
self.logit_scale = nn.Parameter(0.07 * torch.ones([]))
|
143 |
+
self.image_proj = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
|
144 |
+
self.text_proj = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
|
145 |
+
|
146 |
+
def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
|
147 |
+
image_embeds = self.qformer(image_embeds=image_embeds)
|
148 |
+
if self.image_cls_token_type == 'last':
|
149 |
+
image_embeds = image_embeds[:, -1, :]
|
150 |
+
else:
|
151 |
+
raise NotImplementedError
|
152 |
+
|
153 |
+
text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
|
154 |
+
text_embeds = text_embeds[:, 0, :]
|
155 |
+
|
156 |
+
image_embeds = F.normalize(self.image_proj(image_embeds), dim=-1)
|
157 |
+
text_embeds = F.normalize(self.text_proj(text_embeds), dim=-1)
|
158 |
+
|
159 |
+
contrast_loss, i2t_acc, t2i_acc = contrastive_loss(image_feats=image_embeds,
|
160 |
+
text_feats=text_embeds,
|
161 |
+
logit_scale=self.logit_scale)
|
162 |
+
|
163 |
+
return {
|
164 |
+
'total_loss': contrast_loss,
|
165 |
+
'i2t_acc': i2t_acc,
|
166 |
+
't2i_acc': t2i_acc,
|
167 |
+
}
|
168 |
+
|
169 |
+
def encode_image_embeds(self, image_embeds):
|
170 |
+
image_embeds = self.qformer(image_embeds=image_embeds)
|
171 |
+
|
172 |
+
return image_embeds
|
173 |
+
|
174 |
+
@classmethod
|
175 |
+
def from_pretrained(cls, qformer, quantizer, distiller=None, pretrained_model_path=None, **kwargs):
|
176 |
+
model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, **kwargs)
|
177 |
+
if pretrained_model_path is not None:
|
178 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
179 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
180 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
181 |
+
return model
|
182 |
+
|
183 |
+
|
184 |
+
class DiscreteModleStageTwoContrastiveDistill(nn.Module):
|
185 |
+
|
186 |
+
def __init__(self,
|
187 |
+
qformer,
|
188 |
+
quantizer=None,
|
189 |
+
distiller=None,
|
190 |
+
contrast_head=None,
|
191 |
+
projection_dim=1024,
|
192 |
+
distill_loss_type='cosine',
|
193 |
+
freeze_qformer=True,
|
194 |
+
image_cls_token_type='last',
|
195 |
+
scale_commit_loss=1.0,
|
196 |
+
scale_contrast_loss=1.0,
|
197 |
+
scale_distill_loss=1.0) -> None:
|
198 |
+
super().__init__()
|
199 |
+
self.qformer = qformer
|
200 |
+
self.quantizer = quantizer
|
201 |
+
self.distiller = distiller
|
202 |
+
self.contrast_head = contrast_head
|
203 |
+
self.distill_loss_type = distill_loss_type
|
204 |
+
self.image_cls_token_type = image_cls_token_type
|
205 |
+
if self.contrast_head is not None:
|
206 |
+
self.logit_scale = nn.Parameter(0.07 * torch.ones([]))
|
207 |
+
self.image_proj = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
|
208 |
+
self.text_proj = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
|
209 |
+
|
210 |
+
self.freeze_qformer = freeze_qformer
|
211 |
+
if freeze_qformer:
|
212 |
+
self.qformer.requires_grad_(False)
|
213 |
+
|
214 |
+
self.scale_commit_loss = scale_commit_loss
|
215 |
+
self.scale_contrast_loss = scale_contrast_loss
|
216 |
+
self.scale_distill_loss = scale_distill_loss
|
217 |
+
|
218 |
+
def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
|
219 |
+
if self.freeze_qformer:
|
220 |
+
with torch.no_grad():
|
221 |
+
qforemr_embeds = self.qformer(image_embeds=image_embeds)
|
222 |
+
else:
|
223 |
+
qforemr_embeds = self.qformer(image_embeds=image_embeds)
|
224 |
+
|
225 |
+
quantizer_output = self.quantizer(qforemr_embeds)
|
226 |
+
|
227 |
+
output_state = {}
|
228 |
+
output_state['indices'] = quantizer_output['indices']
|
229 |
+
output_state['commit_loss'] = quantizer_output['commit_loss']
|
230 |
+
output_state['total_loss'] = self.scale_commit_loss * quantizer_output['commit_loss']
|
231 |
+
if self.distiller is not None:
|
232 |
+
recon_embeds = self.distiller(quantizer_output['quant_embeds'])
|
233 |
+
|
234 |
+
if self.distill_loss_type == 'cosine':
|
235 |
+
distill_loss = cosine_loss(recon_embeds, image_embeds)
|
236 |
+
else:
|
237 |
+
raise NotImplementedError
|
238 |
+
|
239 |
+
output_state['distill_loss'] = distill_loss
|
240 |
+
output_state['total_loss'] += self.scale_distill_loss * distill_loss
|
241 |
+
|
242 |
+
if self.contrast_head is not None:
|
243 |
+
text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
|
244 |
+
text_embeds = text_embeds[:, 0, :]
|
245 |
+
|
246 |
+
image_embeds = self.contrast_head(quantizer_output['quant_embeds'])
|
247 |
+
if self.image_cls_token_type == 'last':
|
248 |
+
image_embeds = image_embeds[:, -1, :]
|
249 |
+
else:
|
250 |
+
raise NotImplementedError
|
251 |
+
|
252 |
+
image_embeds = F.normalize(self.image_proj(image_embeds), dim=-1)
|
253 |
+
text_embeds = F.normalize(self.text_proj(text_embeds), dim=-1)
|
254 |
+
|
255 |
+
contrast_loss, i2t_acc, t2i_acc = contrastive_loss(image_feats=image_embeds,
|
256 |
+
text_feats=text_embeds,
|
257 |
+
logit_scale=self.logit_scale)
|
258 |
+
output_state['contrast_loss'] = contrast_loss
|
259 |
+
output_state['total_loss'] += self.scale_contrast_loss * contrast_loss
|
260 |
+
output_state['i2t_acc'] = i2t_acc
|
261 |
+
output_state['t2i_acc'] = t2i_acc
|
262 |
+
|
263 |
+
return output_state
|
264 |
+
|
265 |
+
def encode_image_embeds(self, image_embeds):
|
266 |
+
pass
|
267 |
+
|
268 |
+
@classmethod
|
269 |
+
def from_pretrained(cls, qformer, quantizer, distiller=None, contrast_head=None, pretrained_model_path=None,
|
270 |
+
**kwargs):
|
271 |
+
model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
|
272 |
+
if pretrained_model_path is not None:
|
273 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
274 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
275 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
276 |
+
return model
|
277 |
+
|
278 |
+
|
279 |
+
class DiscreteModleDistillWithDoubleContrastive(nn.Module):
|
280 |
+
|
281 |
+
def __init__(
|
282 |
+
self,
|
283 |
+
qformer,
|
284 |
+
quantizer=None,
|
285 |
+
distiller=None,
|
286 |
+
contrast_head=None,
|
287 |
+
projection_dim=1024,
|
288 |
+
distill_loss_type='cosine',
|
289 |
+
share_contrast_head=True, # share contrastive head with distiller
|
290 |
+
quantize_cls_token=False,
|
291 |
+
rec_qformer=False,
|
292 |
+
has_contrast=False,
|
293 |
+
freeze_qformer=False,
|
294 |
+
scale_commit_loss=1.0,
|
295 |
+
scale_contrast_loss=1.0,
|
296 |
+
scale_distill_loss=1.0) -> None:
|
297 |
+
super().__init__()
|
298 |
+
self.qformer = qformer
|
299 |
+
self.quantizer = quantizer
|
300 |
+
self.distiller = distiller
|
301 |
+
self.contrast_head = contrast_head
|
302 |
+
self.distill_loss_type = distill_loss_type
|
303 |
+
self.quantize_cls_token = quantize_cls_token
|
304 |
+
|
305 |
+
self.rec_qformer = rec_qformer
|
306 |
+
self.has_contrast = has_contrast
|
307 |
+
|
308 |
+
if freeze_qformer:
|
309 |
+
self.qformer.requires_grad_(False)
|
310 |
+
else:
|
311 |
+
self.logit_scale_qformer = nn.Parameter(0.07 * torch.ones([]))
|
312 |
+
self.image_proj_qformer = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
|
313 |
+
self.text_proj_qformer = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
|
314 |
+
self.cls_norm_qformer = nn.LayerNorm(qformer.perceiver.config.projection_dim)
|
315 |
+
|
316 |
+
if self.contrast_head is not None:
|
317 |
+
self.logit_scale_head = nn.Parameter(0.07 * torch.ones([]))
|
318 |
+
self.image_proj_head = nn.Linear(contrast_head.perceiver.config.projection_dim, projection_dim, bias=False)
|
319 |
+
self.text_proj_head = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
|
320 |
+
self.cls_norm_head = nn.LayerNorm(contrast_head.perceiver.config.projection_dim)
|
321 |
+
|
322 |
+
if share_contrast_head and distiller is not None:
|
323 |
+
self.logit_scale_head = nn.Parameter(0.07 * torch.ones([]))
|
324 |
+
self.image_proj_head = nn.Linear(distiller.perceiver.config.projection_dim, projection_dim, bias=False)
|
325 |
+
self.text_proj_head = nn.Linear(qformer.perceiver.config.projection_dim, projection_dim, bias=False)
|
326 |
+
self.cls_norm_head = nn.LayerNorm(distiller.perceiver.config.projection_dim)
|
327 |
+
|
328 |
+
self.scale_commit_loss = scale_commit_loss
|
329 |
+
self.scale_contrast_loss = scale_contrast_loss
|
330 |
+
self.scale_distill_loss = scale_distill_loss
|
331 |
+
self.share_contrast_head = share_contrast_head
|
332 |
+
self.freeze_qformer = freeze_qformer
|
333 |
+
assert int(self.share_contrast_head) + int(contrast_head is not None) <= 1
|
334 |
+
|
335 |
+
def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
|
336 |
+
|
337 |
+
if self.freeze_qformer:
|
338 |
+
with torch.no_grad():
|
339 |
+
qforemr_embeds = self.qformer(image_embeds=image_embeds)
|
340 |
+
else:
|
341 |
+
qforemr_embeds = self.qformer(image_embeds=image_embeds)
|
342 |
+
qforemr_cls_embeds = qforemr_embeds[:, -1, :]
|
343 |
+
|
344 |
+
if not self.quantize_cls_token:
|
345 |
+
qforemr_embeds = qforemr_embeds[:, :-1, :]
|
346 |
+
|
347 |
+
if self.has_contrast:
|
348 |
+
text_embeds = self.qformer(input_ids=input_ids, text_attention_mask=text_attention_mask)
|
349 |
+
text_cls_embeds = text_embeds[:, 0, :]
|
350 |
+
|
351 |
+
output_state = {}
|
352 |
+
output_state['total_loss'] = 0.0
|
353 |
+
|
354 |
+
if not self.freeze_qformer and self.has_contrast:
|
355 |
+
qforemr_cls_embeds = self.cls_norm_qformer(qforemr_cls_embeds)
|
356 |
+
qformer_image_embeds = F.normalize(self.image_proj_qformer(qforemr_cls_embeds), dim=-1)
|
357 |
+
qformer_text_embeds = F.normalize(self.text_proj_qformer(text_cls_embeds), dim=-1)
|
358 |
+
|
359 |
+
qformer_contrast_loss, \
|
360 |
+
qformer_i2t_acc, \
|
361 |
+
qformer_t2i_acc = contrastive_loss(image_feats=qformer_image_embeds,
|
362 |
+
text_feats=qformer_text_embeds,
|
363 |
+
logit_scale=self.logit_scale_qformer)
|
364 |
+
output_state['qformer_contrast_loss'] = qformer_contrast_loss
|
365 |
+
output_state['total_loss'] += self.scale_contrast_loss * qformer_contrast_loss
|
366 |
+
output_state['qformer_i2t_acc'] = qformer_i2t_acc
|
367 |
+
output_state['qformer_t2i_acc'] = qformer_t2i_acc
|
368 |
+
|
369 |
+
if self.quantizer is not None and self.distiller is not None:
|
370 |
+
quantizer_output = self.quantizer(qforemr_embeds)
|
371 |
+
|
372 |
+
recon_embeds = self.distiller(quantizer_output['quant_embeds'])
|
373 |
+
if self.share_contrast_head:
|
374 |
+
contrast_head_cls_embeds = recon_embeds[:, -1, :]
|
375 |
+
contrast_head_cls_embeds = self.cls_norm_head(contrast_head_cls_embeds)
|
376 |
+
recon_embeds = recon_embeds[:, :-1, :]
|
377 |
+
if self.contrast_head is not None:
|
378 |
+
contrast_head_embeds = self.contrast_head(quantizer_output['quant_embeds'])
|
379 |
+
contrast_head_cls_embeds = contrast_head_embeds[:, -1, :]
|
380 |
+
contrast_head_cls_embeds = self.cls_norm_head(contrast_head_cls_embeds)
|
381 |
+
|
382 |
+
output_state['indices'] = quantizer_output['indices']
|
383 |
+
output_state['commit_loss'] = quantizer_output['commit_loss']
|
384 |
+
output_state['total_loss'] += self.scale_commit_loss * quantizer_output['commit_loss']
|
385 |
+
|
386 |
+
if self.rec_qformer:
|
387 |
+
target_embeds = qforemr_embeds
|
388 |
+
else:
|
389 |
+
target_embeds = image_embeds
|
390 |
+
|
391 |
+
if self.distill_loss_type == 'cosine':
|
392 |
+
distill_loss = cosine_loss(recon_embeds, target_embeds)
|
393 |
+
else:
|
394 |
+
raise NotImplementedError
|
395 |
+
|
396 |
+
output_state['distill_loss'] = distill_loss
|
397 |
+
output_state['total_loss'] += self.scale_distill_loss * distill_loss
|
398 |
+
|
399 |
+
if self.contrast_head is not None or self.share_contrast_head:
|
400 |
+
head_image_embeds = F.normalize(self.image_proj_head(contrast_head_cls_embeds), dim=-1)
|
401 |
+
head_text_embeds = F.normalize(self.text_proj_head(text_cls_embeds), dim=-1)
|
402 |
+
|
403 |
+
head_contrast_loss, head_i2t_acc, head_t2i_acc = contrastive_loss(image_feats=head_image_embeds,
|
404 |
+
text_feats=head_text_embeds,
|
405 |
+
logit_scale=self.logit_scale_head)
|
406 |
+
output_state['head_contrast_loss'] = head_contrast_loss
|
407 |
+
output_state['total_loss'] += self.scale_contrast_loss * head_contrast_loss
|
408 |
+
output_state['head_i2t_acc'] = head_i2t_acc
|
409 |
+
output_state['head_t2i_acc'] = head_t2i_acc
|
410 |
+
|
411 |
+
return output_state
|
412 |
+
|
413 |
+
def encode_image_embeds(self, image_embeds):
|
414 |
+
qforemr_embeds = self.qformer(image_embeds=image_embeds)
|
415 |
+
return qforemr_embeds
|
416 |
+
|
417 |
+
@classmethod
|
418 |
+
def from_pretrained(cls, qformer, quantizer=None, distiller=None, contrast_head=None, pretrained_model_path=None,
|
419 |
+
**kwargs):
|
420 |
+
model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
|
421 |
+
if pretrained_model_path is not None:
|
422 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
423 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
424 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
425 |
+
return model
|
426 |
+
|
427 |
+
@classmethod
|
428 |
+
def from_pretrained_stage1_yuying(cls,
|
429 |
+
qformer,
|
430 |
+
quantizer=None,
|
431 |
+
distiller=None,
|
432 |
+
contrast_head=None,
|
433 |
+
pretrained_model_path=None,
|
434 |
+
**kwargs):
|
435 |
+
model = cls(qformer=qformer, quantizer=quantizer, distiller=distiller, contrast_head=contrast_head, **kwargs)
|
436 |
+
if pretrained_model_path is not None:
|
437 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
438 |
+
ckpt = ckpt['model']
|
439 |
+
|
440 |
+
new_ckpt = {}
|
441 |
+
new_ckpt['qformer.embed_module.query'] = ckpt['query_tokens'].squeeze(0)
|
442 |
+
new_ckpt['qformer.norm.weight'] = ckpt['ln_vision.weight']
|
443 |
+
new_ckpt['qformer.norm.bias'] = ckpt['ln_vision.bias']
|
444 |
+
|
445 |
+
for key in ckpt.keys():
|
446 |
+
if key.startswith('Qformer'):
|
447 |
+
new_key = key.replace('Qformer', 'qformer.perceiver')
|
448 |
+
new_ckpt[new_key] = ckpt[key]
|
449 |
+
del ckpt
|
450 |
+
missing, unexpected = model.load_state_dict(new_ckpt, strict=False)
|
451 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
452 |
+
print(missing)
|
453 |
+
print(unexpected)
|
454 |
+
return model
|
src/models/qwen_visual.py
ADDED
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Alibaba Cloud.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from collections import OrderedDict
|
7 |
+
import math
|
8 |
+
import requests
|
9 |
+
from io import BytesIO
|
10 |
+
from functools import partial
|
11 |
+
from PIL import Image
|
12 |
+
from typing import Callable, Optional, Sequence, Tuple, List
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
from torch.nn import functional as F
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
from torchvision import transforms
|
20 |
+
from torchvision.transforms import InterpolationMode
|
21 |
+
|
22 |
+
|
23 |
+
def get_abs_pos(abs_pos, tgt_size):
|
24 |
+
# abs_pos: L, C
|
25 |
+
# tgt_size: M
|
26 |
+
# return: M, C
|
27 |
+
src_size = int(math.sqrt(abs_pos.size(0)))
|
28 |
+
tgt_size = int(math.sqrt(tgt_size))
|
29 |
+
dtype = abs_pos.dtype
|
30 |
+
|
31 |
+
if src_size != tgt_size:
|
32 |
+
return F.interpolate(
|
33 |
+
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
|
34 |
+
size=(tgt_size, tgt_size),
|
35 |
+
mode="bicubic",
|
36 |
+
align_corners=False,
|
37 |
+
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
|
38 |
+
else:
|
39 |
+
return abs_pos
|
40 |
+
|
41 |
+
|
42 |
+
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
43 |
+
|
44 |
+
|
45 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
46 |
+
"""
|
47 |
+
grid_size: int of the grid height and width
|
48 |
+
return:
|
49 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
50 |
+
"""
|
51 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
52 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
53 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
54 |
+
grid = np.stack(grid, axis=0)
|
55 |
+
|
56 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
57 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
58 |
+
if cls_token:
|
59 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
60 |
+
return pos_embed
|
61 |
+
|
62 |
+
|
63 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
64 |
+
assert embed_dim % 2 == 0
|
65 |
+
|
66 |
+
# use half of dimensions to encode grid_h
|
67 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
68 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
69 |
+
|
70 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
71 |
+
return emb
|
72 |
+
|
73 |
+
|
74 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
75 |
+
"""
|
76 |
+
embed_dim: output dimension for each position
|
77 |
+
pos: a list of positions to be encoded: size (M,)
|
78 |
+
out: (M, D)
|
79 |
+
"""
|
80 |
+
assert embed_dim % 2 == 0
|
81 |
+
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
82 |
+
omega /= embed_dim / 2.
|
83 |
+
omega = 1. / 10000 ** omega # (D/2,)
|
84 |
+
|
85 |
+
pos = pos.reshape(-1) # (M,)
|
86 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
87 |
+
|
88 |
+
emb_sin = np.sin(out) # (M, D/2)
|
89 |
+
emb_cos = np.cos(out) # (M, D/2)
|
90 |
+
|
91 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
92 |
+
return emb
|
93 |
+
|
94 |
+
|
95 |
+
class Resampler(nn.Module):
|
96 |
+
"""
|
97 |
+
A 2D perceiver-resampler network with one cross attention layers by
|
98 |
+
(grid_size**2) learnable queries and 2d sincos pos_emb
|
99 |
+
Outputs:
|
100 |
+
A tensor with the shape of (grid_size**2, embed_dim)
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, grid_size, embed_dim, num_heads, kv_dim=None, norm_layer=nn.LayerNorm):
|
104 |
+
super().__init__()
|
105 |
+
self.num_queries = grid_size ** 2
|
106 |
+
self.embed_dim = embed_dim
|
107 |
+
self.num_heads = num_heads
|
108 |
+
|
109 |
+
self.pos_embed = nn.Parameter(torch.from_numpy(get_2d_sincos_pos_embed(embed_dim,
|
110 |
+
grid_size)).float()).requires_grad_(
|
111 |
+
False)
|
112 |
+
|
113 |
+
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
|
114 |
+
trunc_normal_(self.query, std=.02)
|
115 |
+
|
116 |
+
if kv_dim is not None and kv_dim != embed_dim:
|
117 |
+
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
|
118 |
+
self.out_dim = kv_dim
|
119 |
+
else:
|
120 |
+
self.kv_proj = nn.Identity()
|
121 |
+
self.out_dim = embed_dim
|
122 |
+
|
123 |
+
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
|
124 |
+
self.ln_q = norm_layer(embed_dim)
|
125 |
+
self.ln_kv = norm_layer(embed_dim)
|
126 |
+
|
127 |
+
self.apply(self._init_weights)
|
128 |
+
|
129 |
+
def _init_weights(self, m):
|
130 |
+
if isinstance(m, nn.Linear):
|
131 |
+
trunc_normal_(m.weight, std=.02)
|
132 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
133 |
+
nn.init.constant_(m.bias, 0)
|
134 |
+
elif isinstance(m, nn.LayerNorm):
|
135 |
+
nn.init.constant_(m.bias, 0)
|
136 |
+
nn.init.constant_(m.weight, 1.0)
|
137 |
+
|
138 |
+
def forward(self, x, attn_mask=None):
|
139 |
+
|
140 |
+
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
|
141 |
+
|
142 |
+
x = self.kv_proj(x)
|
143 |
+
x = self.ln_kv(x).permute(1, 0, 2)
|
144 |
+
|
145 |
+
N = x.shape[1]
|
146 |
+
q = self.ln_q(self.query)
|
147 |
+
out = \
|
148 |
+
self.attn(self._repeat(q, N) + self.pos_embed.unsqueeze(1), x + pos_embed.unsqueeze(1), x, attn_mask=attn_mask)[
|
149 |
+
0]
|
150 |
+
return out.permute(1, 0, 2)
|
151 |
+
|
152 |
+
def _repeat(self, query, N: int):
|
153 |
+
return query.unsqueeze(1).repeat(1, N, 1)
|
154 |
+
|
155 |
+
|
156 |
+
class VisualAttention(nn.Module):
|
157 |
+
"""self-attention layer class.
|
158 |
+
|
159 |
+
Self-attention layer takes input with size [s, b, h]
|
160 |
+
and returns output of the same size.
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(self, embed_dim, num_heads, bias=True, kdim=None, vdim=None):
|
164 |
+
super(VisualAttention, self).__init__()
|
165 |
+
self.embed_dim = embed_dim
|
166 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
167 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
168 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
169 |
+
|
170 |
+
self.num_heads = num_heads
|
171 |
+
|
172 |
+
# Per attention head and per partition values.
|
173 |
+
assert embed_dim % num_heads == 0
|
174 |
+
self.hidden_size_per_attention_head = embed_dim // num_heads
|
175 |
+
self.num_attention_heads_per_partition = num_heads
|
176 |
+
self.hidden_size_per_partition = embed_dim
|
177 |
+
|
178 |
+
# Strided linear layer.
|
179 |
+
assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently'
|
180 |
+
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
|
181 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim)
|
182 |
+
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
183 |
+
|
184 |
+
def forward(self, query, key, value, attn_mask=None):
|
185 |
+
# query/key/value: [sq, b, h]
|
186 |
+
sq, b, _ = query.size()
|
187 |
+
|
188 |
+
assert query is key, 'Only Support Self-Attention Currently'
|
189 |
+
sk = sq
|
190 |
+
mixed_x_layer = self.in_proj(query)
|
191 |
+
|
192 |
+
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
|
193 |
+
new_tensor_shape = mixed_x_layer.size()[:-1] + \
|
194 |
+
(self.num_attention_heads_per_partition,
|
195 |
+
3 * self.hidden_size_per_attention_head)
|
196 |
+
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
197 |
+
|
198 |
+
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
199 |
+
query_layer, key_layer, value_layer = mixed_x_layer.split(self.hidden_size_per_attention_head, dim=-1)
|
200 |
+
|
201 |
+
# [sq, b, np, hn] -> [sq, b * np, hn]
|
202 |
+
query_layer = query_layer.view(sq, b * self.num_attention_heads_per_partition,
|
203 |
+
self.hidden_size_per_attention_head).transpose(0, 1)
|
204 |
+
# [sk, b, np, hn] -> [sk, b * np, hn]
|
205 |
+
key_layer = key_layer.view(sk, b * self.num_attention_heads_per_partition,
|
206 |
+
self.hidden_size_per_attention_head).transpose(0, 1)
|
207 |
+
|
208 |
+
q_scaled = query_layer / self.norm_factor
|
209 |
+
if attn_mask is not None:
|
210 |
+
attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1))
|
211 |
+
else:
|
212 |
+
attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
|
213 |
+
attention_probs = attention_probs.softmax(dim=-1)
|
214 |
+
|
215 |
+
value_layer = value_layer.view(sk, b * self.num_attention_heads_per_partition,
|
216 |
+
self.hidden_size_per_attention_head).transpose(0, 1)
|
217 |
+
|
218 |
+
# matmul: [b * np, sq, hn]
|
219 |
+
context_layer = torch.bmm(attention_probs, value_layer)
|
220 |
+
|
221 |
+
# change view [b, np, sq, hn]
|
222 |
+
context_layer = context_layer.view(b, self.num_attention_heads_per_partition, sq,
|
223 |
+
self.hidden_size_per_attention_head)
|
224 |
+
|
225 |
+
# [b, np, sq, hn] --> [sq, b, np, hn]
|
226 |
+
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
227 |
+
|
228 |
+
# [sq, b, np, hn] --> [sq, b, hp]
|
229 |
+
new_context_layer_shape = context_layer.size()[:-2] + \
|
230 |
+
(self.hidden_size_per_partition,)
|
231 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
232 |
+
|
233 |
+
output = self.out_proj(context_layer)
|
234 |
+
|
235 |
+
return output
|
236 |
+
|
237 |
+
|
238 |
+
class VisualAttentionBlock(nn.Module):
|
239 |
+
|
240 |
+
def __init__(
|
241 |
+
self,
|
242 |
+
d_model: int,
|
243 |
+
n_head: int,
|
244 |
+
mlp_ratio: float = 4.0,
|
245 |
+
act_layer: Callable = nn.GELU,
|
246 |
+
norm_layer: Callable = nn.LayerNorm,
|
247 |
+
is_cross_attention: bool = False,
|
248 |
+
):
|
249 |
+
super().__init__()
|
250 |
+
|
251 |
+
self.ln_1 = norm_layer(d_model)
|
252 |
+
if is_cross_attention:
|
253 |
+
self.ln_1_kv = norm_layer(d_model)
|
254 |
+
|
255 |
+
self.ln_2 = norm_layer(d_model)
|
256 |
+
mlp_width = int(d_model * mlp_ratio)
|
257 |
+
self.attn = VisualAttention(d_model, n_head)
|
258 |
+
self.mlp = nn.Sequential(
|
259 |
+
OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()),
|
260 |
+
("c_proj", nn.Linear(mlp_width, d_model))]))
|
261 |
+
|
262 |
+
def attention(
|
263 |
+
self,
|
264 |
+
q_x: torch.Tensor,
|
265 |
+
k_x: Optional[torch.Tensor] = None,
|
266 |
+
v_x: Optional[torch.Tensor] = None,
|
267 |
+
attn_mask: Optional[torch.Tensor] = None,
|
268 |
+
):
|
269 |
+
k_x = k_x if k_x is not None else q_x
|
270 |
+
v_x = v_x if v_x is not None else q_x
|
271 |
+
|
272 |
+
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
|
273 |
+
return self.attn(q_x, k_x, v_x, attn_mask=attn_mask)
|
274 |
+
|
275 |
+
def forward(
|
276 |
+
self,
|
277 |
+
q_x: torch.Tensor,
|
278 |
+
k_x: Optional[torch.Tensor] = None,
|
279 |
+
v_x: Optional[torch.Tensor] = None,
|
280 |
+
attn_mask: Optional[torch.Tensor] = None,
|
281 |
+
):
|
282 |
+
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
|
283 |
+
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
|
284 |
+
|
285 |
+
x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
|
286 |
+
x = x + self.mlp(self.ln_2(x))
|
287 |
+
return x
|
288 |
+
|
289 |
+
|
290 |
+
class TransformerBlock(nn.Module):
|
291 |
+
|
292 |
+
def __init__(
|
293 |
+
self,
|
294 |
+
width: int,
|
295 |
+
layers: int,
|
296 |
+
heads: int,
|
297 |
+
mlp_ratio: float = 4.0,
|
298 |
+
act_layer: Callable = nn.GELU,
|
299 |
+
norm_layer: Callable = nn.LayerNorm,
|
300 |
+
):
|
301 |
+
super().__init__()
|
302 |
+
self.width = width
|
303 |
+
self.layers = layers
|
304 |
+
|
305 |
+
self.resblocks = nn.ModuleList(
|
306 |
+
[VisualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) for _ in
|
307 |
+
range(layers)])
|
308 |
+
|
309 |
+
def get_cast_dtype(self) -> torch.dtype:
|
310 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
311 |
+
|
312 |
+
def get_cast_device(self) -> torch.device:
|
313 |
+
return self.resblocks[0].mlp.c_fc.weight.device
|
314 |
+
|
315 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
316 |
+
for r in self.resblocks:
|
317 |
+
x = r(x, attn_mask=attn_mask)
|
318 |
+
return x
|
319 |
+
|
320 |
+
|
321 |
+
class VisionTransformerWithAttnPool(nn.Module):
|
322 |
+
|
323 |
+
def __init__(self,
|
324 |
+
image_size: int,
|
325 |
+
patch_size: int,
|
326 |
+
width: int,
|
327 |
+
layers: int,
|
328 |
+
heads: int,
|
329 |
+
mlp_ratio: float,
|
330 |
+
n_queries: int = 256,
|
331 |
+
output_dim: int = 512,
|
332 |
+
**kwargs):
|
333 |
+
super().__init__()
|
334 |
+
image_height, image_width = self.image_size = (image_size, image_size)
|
335 |
+
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
|
336 |
+
self.grid_size = (image_height // patch_height, image_width // patch_width)
|
337 |
+
self.output_dim = output_dim
|
338 |
+
|
339 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
340 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
341 |
+
self.image_transform = transforms.Compose([
|
342 |
+
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
343 |
+
transforms.ToTensor(),
|
344 |
+
transforms.Normalize(mean=mean, std=std),
|
345 |
+
])
|
346 |
+
|
347 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
348 |
+
|
349 |
+
# class embeddings and positional embeddings
|
350 |
+
scale = width ** -0.5
|
351 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
|
352 |
+
|
353 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
354 |
+
act_layer = nn.GELU
|
355 |
+
|
356 |
+
self.ln_pre = norm_layer(width)
|
357 |
+
self.transformer = TransformerBlock(
|
358 |
+
width,
|
359 |
+
layers,
|
360 |
+
heads,
|
361 |
+
mlp_ratio,
|
362 |
+
act_layer=act_layer,
|
363 |
+
norm_layer=norm_layer,
|
364 |
+
)
|
365 |
+
|
366 |
+
self.attn_pool = Resampler(
|
367 |
+
grid_size=int(math.sqrt(n_queries)),
|
368 |
+
embed_dim=output_dim,
|
369 |
+
num_heads=output_dim // 128,
|
370 |
+
kv_dim=width,
|
371 |
+
norm_layer=norm_layer,
|
372 |
+
)
|
373 |
+
self.ln_post = norm_layer(output_dim)
|
374 |
+
self.proj = nn.Parameter((output_dim ** -0.5) * torch.randn(output_dim, output_dim))
|
375 |
+
|
376 |
+
def forward(self, x: torch.Tensor):
|
377 |
+
x = x.to(
|
378 |
+
dtype=self.transformer.get_cast_dtype(),
|
379 |
+
device=self.transformer.get_cast_device(),
|
380 |
+
)
|
381 |
+
# to patches
|
382 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
383 |
+
# shape = [*, width, grid ** 2]
|
384 |
+
x = x.reshape(x.shape[0], x.shape[1], -1)
|
385 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
386 |
+
|
387 |
+
x = x + get_abs_pos(self.positional_embedding, x.size(1))
|
388 |
+
|
389 |
+
x = self.ln_pre(x)
|
390 |
+
|
391 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
392 |
+
x = self.transformer(x)
|
393 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
394 |
+
|
395 |
+
x = self.attn_pool(x)
|
396 |
+
x = self.ln_post(x)
|
397 |
+
x = x @ self.proj
|
398 |
+
|
399 |
+
return x
|
400 |
+
|
401 |
+
def encode(self, image_paths: List[str]):
|
402 |
+
images = []
|
403 |
+
for image_path in image_paths:
|
404 |
+
if image_path.startswith("http://") or image_path.startswith("https://"):
|
405 |
+
image = Image.open(requests.get(image_path, stream=True).raw)
|
406 |
+
else:
|
407 |
+
image = Image.open(image_path)
|
408 |
+
image = image.convert("RGB")
|
409 |
+
images.append(self.image_transform(image))
|
410 |
+
images = torch.stack(images, dim=0)
|
411 |
+
return self(images)
|
412 |
+
|
413 |
+
@classmethod
|
414 |
+
def from_pretrained(cls, pretrained_model_path=None, **kawrgs):
|
415 |
+
model = cls(**kawrgs)
|
416 |
+
if pretrained_model_path is not None:
|
417 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
418 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
419 |
+
print('Load ckpt of qwen visual encoder')
|
420 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
421 |
+
|
422 |
+
return model
|
423 |
+
|
424 |
+
|
425 |
+
class VisionTransformer(nn.Module):
|
426 |
+
|
427 |
+
def __init__(self,
|
428 |
+
image_size: int,
|
429 |
+
patch_size: int,
|
430 |
+
width: int,
|
431 |
+
layers: int,
|
432 |
+
heads: int,
|
433 |
+
mlp_ratio: float,
|
434 |
+
n_queries: int = 256,
|
435 |
+
output_dim: int = 512,
|
436 |
+
**kwargs):
|
437 |
+
super().__init__()
|
438 |
+
image_height, image_width = self.image_size = (image_size, image_size)
|
439 |
+
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
|
440 |
+
self.grid_size = (image_height // patch_height, image_width // patch_width)
|
441 |
+
self.output_dim = output_dim
|
442 |
+
|
443 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
444 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
445 |
+
self.image_transform = transforms.Compose([
|
446 |
+
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
447 |
+
transforms.ToTensor(),
|
448 |
+
transforms.Normalize(mean=mean, std=std),
|
449 |
+
])
|
450 |
+
|
451 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
452 |
+
|
453 |
+
# class embeddings and positional embeddings
|
454 |
+
scale = width ** -0.5
|
455 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
|
456 |
+
|
457 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
458 |
+
act_layer = nn.GELU
|
459 |
+
|
460 |
+
self.ln_pre = norm_layer(width)
|
461 |
+
self.transformer = TransformerBlock(
|
462 |
+
width,
|
463 |
+
layers,
|
464 |
+
heads,
|
465 |
+
mlp_ratio,
|
466 |
+
act_layer=act_layer,
|
467 |
+
norm_layer=norm_layer,
|
468 |
+
)
|
469 |
+
|
470 |
+
def forward(self, x: torch.Tensor):
|
471 |
+
x = x.to(
|
472 |
+
dtype=self.transformer.get_cast_dtype(),
|
473 |
+
device=self.transformer.get_cast_device(),
|
474 |
+
)
|
475 |
+
# to patches
|
476 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
477 |
+
# shape = [*, width, grid ** 2]
|
478 |
+
x = x.reshape(x.shape[0], x.shape[1], -1)
|
479 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
480 |
+
|
481 |
+
x = x + get_abs_pos(self.positional_embedding, x.size(1))
|
482 |
+
|
483 |
+
x = self.ln_pre(x)
|
484 |
+
|
485 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
486 |
+
x = self.transformer(x)
|
487 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
488 |
+
|
489 |
+
return x
|
490 |
+
|
491 |
+
def encode(self, image_paths: List[str]):
|
492 |
+
images = []
|
493 |
+
for image_path in image_paths:
|
494 |
+
if image_path.startswith("http://") or image_path.startswith("https://"):
|
495 |
+
image = Image.open(requests.get(image_path, stream=True).raw)
|
496 |
+
else:
|
497 |
+
image = Image.open(image_path)
|
498 |
+
image = image.convert("RGB")
|
499 |
+
images.append(self.image_transform(image))
|
500 |
+
images = torch.stack(images, dim=0)
|
501 |
+
return self(images)
|
src/models_clm/__init__.py
ADDED
File without changes
|
src/models_clm/generation.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import LogitsProcessor
|
3 |
+
|
4 |
+
BOI_TOKEN = '<img>'
|
5 |
+
EOI_TOKEN = '</img>'
|
6 |
+
IMG_TOKEN = '<img_{:05d}>'
|
7 |
+
|
8 |
+
|
9 |
+
class AutoImageTokenGenerationProcessor(LogitsProcessor):
|
10 |
+
|
11 |
+
def __init__(self, tokenizer, num_img_gen_tokens=64) -> None:
|
12 |
+
super().__init__()
|
13 |
+
# self.boi_token_id = tokenizer.encode(BOI_TOKEN)[0]
|
14 |
+
# self.eoi_token_id = tokenizer.encode(EOI_TOKEN)[0]
|
15 |
+
img_all_token_str = ''.join([BOI_TOKEN] + [IMG_TOKEN.format(int(item))
|
16 |
+
for item in range(num_img_gen_tokens)] + [EOI_TOKEN])
|
17 |
+
self.img_ids_list = tokenizer.encode(img_all_token_str, add_special_tokens=False)
|
18 |
+
|
19 |
+
def __call__(self, input_ids, scores):
|
20 |
+
bz = input_ids.shape[0]
|
21 |
+
for i in range(bz):
|
22 |
+
cur_input_id = input_ids[i, -1].item()
|
23 |
+
if cur_input_id in self.img_ids_list[:-1]:
|
24 |
+
|
25 |
+
output_id = self.img_ids_list[self.img_ids_list.index(cur_input_id) + 1]
|
26 |
+
scores[i, ..., output_id] = scores[i, ...].max() + 10.
|
27 |
+
else:
|
28 |
+
|
29 |
+
scores[i, ..., torch.tensor(self.img_ids_list[1:]).to(dtype=torch.long)] = 0.0
|
30 |
+
|
31 |
+
return scores
|
src/models_clm/modeling_llama_4_35.py
ADDED
@@ -0,0 +1,1236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
6 |
+
# and OPT implementations in this library. It has been modified from its
|
7 |
+
# original forms to accommodate minor architectural differences compared
|
8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
9 |
+
#
|
10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
11 |
+
# you may not use this file except in compliance with the License.
|
12 |
+
# You may obtain a copy of the License at
|
13 |
+
#
|
14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
15 |
+
#
|
16 |
+
# Unless required by applicable law or agreed to in writing, software
|
17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
19 |
+
# See the License for the specific language governing permissions and
|
20 |
+
# limitations under the License.
|
21 |
+
""" PyTorch LLaMA model."""
|
22 |
+
import math
|
23 |
+
import warnings
|
24 |
+
from typing import List, Optional, Tuple, Union
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn.functional as F
|
28 |
+
import torch.utils.checkpoint
|
29 |
+
from torch import nn
|
30 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
31 |
+
|
32 |
+
from transformers.activations import ACT2FN
|
33 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
|
34 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
|
35 |
+
SequenceClassifierOutputWithPast
|
36 |
+
from transformers.modeling_utils import PreTrainedModel
|
37 |
+
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
38 |
+
from transformers.utils import (
|
39 |
+
add_start_docstrings,
|
40 |
+
add_start_docstrings_to_model_forward,
|
41 |
+
is_flash_attn_2_available,
|
42 |
+
logging,
|
43 |
+
replace_return_docstrings,
|
44 |
+
)
|
45 |
+
from transformers.utils.import_utils import is_torch_fx_available
|
46 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
47 |
+
|
48 |
+
if is_flash_attn_2_available():
|
49 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
50 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
51 |
+
|
52 |
+
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
53 |
+
# It means that the function will not be traced through and simply appear as a node in the graph.
|
54 |
+
if is_torch_fx_available():
|
55 |
+
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
56 |
+
|
57 |
+
logger = logging.get_logger(__name__)
|
58 |
+
|
59 |
+
_CONFIG_FOR_DOC = "LlamaConfig"
|
60 |
+
|
61 |
+
|
62 |
+
def _get_unpad_data(attention_mask):
|
63 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
64 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
65 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
66 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
67 |
+
return (
|
68 |
+
indices,
|
69 |
+
cu_seqlens,
|
70 |
+
max_seqlen_in_batch,
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
75 |
+
warnings.warn(
|
76 |
+
"Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._prepare_4d_attention_mask" # yapf: disable # noqa
|
77 |
+
|
78 |
+
)
|
79 |
+
return AttentionMaskConverter._prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
80 |
+
|
81 |
+
|
82 |
+
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device,
|
83 |
+
past_key_values_length: int = 0):
|
84 |
+
warnings.warn(
|
85 |
+
"Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask" # yapf: disable # noqa
|
86 |
+
|
87 |
+
)
|
88 |
+
return AttentionMaskConverter._make_causal_mask(input_ids_shape=input_ids_shape,
|
89 |
+
dtype=dtype,
|
90 |
+
device=device,
|
91 |
+
past_key_values_length=past_key_values_length)
|
92 |
+
|
93 |
+
|
94 |
+
class LlamaRMSNorm(nn.Module):
|
95 |
+
|
96 |
+
def __init__(self, hidden_size, eps=1e-6):
|
97 |
+
"""
|
98 |
+
LlamaRMSNorm is equivalent to T5LayerNorm
|
99 |
+
"""
|
100 |
+
super().__init__()
|
101 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
102 |
+
self.variance_epsilon = eps
|
103 |
+
|
104 |
+
def forward(self, hidden_states):
|
105 |
+
input_dtype = hidden_states.dtype
|
106 |
+
hidden_states = hidden_states.to(torch.float32)
|
107 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
108 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
109 |
+
return self.weight * hidden_states.to(input_dtype)
|
110 |
+
|
111 |
+
|
112 |
+
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
|
113 |
+
|
114 |
+
|
115 |
+
class LlamaRotaryEmbedding(nn.Module):
|
116 |
+
|
117 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
118 |
+
super().__init__()
|
119 |
+
|
120 |
+
self.dim = dim
|
121 |
+
self.max_position_embeddings = max_position_embeddings
|
122 |
+
self.base = base
|
123 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
124 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
125 |
+
|
126 |
+
# Build here to make `torch.jit.trace` work.
|
127 |
+
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device,
|
128 |
+
dtype=torch.get_default_dtype())
|
129 |
+
|
130 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
131 |
+
self.max_seq_len_cached = seq_len
|
132 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
133 |
+
|
134 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
135 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
136 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
137 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
138 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
139 |
+
|
140 |
+
def forward(self, x, seq_len=None):
|
141 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
142 |
+
if seq_len > self.max_seq_len_cached:
|
143 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
144 |
+
|
145 |
+
return (
|
146 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
147 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
148 |
+
)
|
149 |
+
|
150 |
+
|
151 |
+
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
152 |
+
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
153 |
+
|
154 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
155 |
+
self.scaling_factor = scaling_factor
|
156 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
157 |
+
|
158 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
159 |
+
self.max_seq_len_cached = seq_len
|
160 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
161 |
+
t = t / self.scaling_factor
|
162 |
+
|
163 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
164 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
165 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
166 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
167 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
168 |
+
|
169 |
+
|
170 |
+
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
171 |
+
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
172 |
+
|
173 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
174 |
+
self.scaling_factor = scaling_factor
|
175 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
176 |
+
|
177 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
178 |
+
self.max_seq_len_cached = seq_len
|
179 |
+
|
180 |
+
if seq_len > self.max_position_embeddings:
|
181 |
+
base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
|
182 |
+
(self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
|
183 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
184 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
185 |
+
|
186 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
187 |
+
|
188 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
189 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
190 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
191 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
192 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
193 |
+
|
194 |
+
|
195 |
+
def rotate_half(x):
|
196 |
+
"""Rotates half the hidden dims of the input."""
|
197 |
+
x1 = x[..., :x.shape[-1] // 2]
|
198 |
+
x2 = x[..., x.shape[-1] // 2:]
|
199 |
+
return torch.cat((-x2, x1), dim=-1)
|
200 |
+
|
201 |
+
|
202 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
203 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
204 |
+
|
205 |
+
Args:
|
206 |
+
q (`torch.Tensor`): The query tensor.
|
207 |
+
k (`torch.Tensor`): The key tensor.
|
208 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
209 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
210 |
+
position_ids (`torch.Tensor`):
|
211 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
212 |
+
used to pass offsetted position ids when working with a KV-cache.
|
213 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
214 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
215 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
216 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
217 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
218 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
219 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
220 |
+
Returns:
|
221 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
222 |
+
"""
|
223 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
224 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
225 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
226 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
227 |
+
return q_embed, k_embed
|
228 |
+
|
229 |
+
|
230 |
+
class LlamaMLP(nn.Module):
|
231 |
+
|
232 |
+
def __init__(self, config):
|
233 |
+
super().__init__()
|
234 |
+
self.config = config
|
235 |
+
self.hidden_size = config.hidden_size
|
236 |
+
self.intermediate_size = config.intermediate_size
|
237 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
238 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
239 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
240 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
241 |
+
|
242 |
+
def forward(self, x):
|
243 |
+
if self.config.pretraining_tp > 1:
|
244 |
+
slice = self.intermediate_size // self.config.pretraining_tp
|
245 |
+
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
|
246 |
+
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
|
247 |
+
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
|
248 |
+
|
249 |
+
gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
250 |
+
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
251 |
+
|
252 |
+
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
|
253 |
+
down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in
|
254 |
+
range(self.config.pretraining_tp)]
|
255 |
+
down_proj = sum(down_proj)
|
256 |
+
else:
|
257 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
258 |
+
|
259 |
+
return down_proj
|
260 |
+
|
261 |
+
|
262 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
263 |
+
"""
|
264 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
265 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
266 |
+
"""
|
267 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
268 |
+
if n_rep == 1:
|
269 |
+
return hidden_states
|
270 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
271 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
272 |
+
|
273 |
+
|
274 |
+
class LlamaAttention(nn.Module):
|
275 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
276 |
+
|
277 |
+
def __init__(self, config: LlamaConfig):
|
278 |
+
super().__init__()
|
279 |
+
self.config = config
|
280 |
+
self.hidden_size = config.hidden_size
|
281 |
+
self.num_heads = config.num_attention_heads
|
282 |
+
self.head_dim = self.hidden_size // self.num_heads
|
283 |
+
self.num_key_value_heads = config.num_key_value_heads
|
284 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
285 |
+
self.max_position_embeddings = config.max_position_embeddings
|
286 |
+
self.rope_theta = config.rope_theta
|
287 |
+
self.is_causal = True
|
288 |
+
|
289 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
290 |
+
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
291 |
+
f" and `num_heads`: {self.num_heads}).")
|
292 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
293 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
294 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
295 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
296 |
+
self._init_rope()
|
297 |
+
|
298 |
+
def _init_rope(self):
|
299 |
+
if self.config.rope_scaling is None:
|
300 |
+
self.rotary_emb = LlamaRotaryEmbedding(
|
301 |
+
self.head_dim,
|
302 |
+
max_position_embeddings=self.max_position_embeddings,
|
303 |
+
base=self.rope_theta,
|
304 |
+
)
|
305 |
+
else:
|
306 |
+
scaling_type = self.config.rope_scaling["type"]
|
307 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
308 |
+
if scaling_type == "linear":
|
309 |
+
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
|
310 |
+
self.head_dim,
|
311 |
+
max_position_embeddings=self.max_position_embeddings,
|
312 |
+
scaling_factor=scaling_factor,
|
313 |
+
base=self.rope_theta,
|
314 |
+
)
|
315 |
+
elif scaling_type == "dynamic":
|
316 |
+
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
|
317 |
+
self.head_dim,
|
318 |
+
max_position_embeddings=self.max_position_embeddings,
|
319 |
+
scaling_factor=scaling_factor,
|
320 |
+
base=self.rope_theta,
|
321 |
+
)
|
322 |
+
else:
|
323 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
324 |
+
|
325 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
326 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
327 |
+
|
328 |
+
def forward(
|
329 |
+
self,
|
330 |
+
hidden_states: torch.Tensor,
|
331 |
+
attention_mask: Optional[torch.Tensor] = None,
|
332 |
+
position_ids: Optional[torch.LongTensor] = None,
|
333 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
334 |
+
output_attentions: bool = False,
|
335 |
+
use_cache: bool = False,
|
336 |
+
**kwargs,
|
337 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
338 |
+
if "padding_mask" in kwargs:
|
339 |
+
warnings.warn(
|
340 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
341 |
+
)
|
342 |
+
|
343 |
+
bsz, q_len, _ = hidden_states.size()
|
344 |
+
|
345 |
+
if self.config.pretraining_tp > 1:
|
346 |
+
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
347 |
+
query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp,
|
348 |
+
dim=0)
|
349 |
+
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
350 |
+
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
351 |
+
|
352 |
+
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
353 |
+
query_states = torch.cat(query_states, dim=-1)
|
354 |
+
|
355 |
+
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
356 |
+
key_states = torch.cat(key_states, dim=-1)
|
357 |
+
|
358 |
+
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
359 |
+
value_states = torch.cat(value_states, dim=-1)
|
360 |
+
|
361 |
+
else:
|
362 |
+
query_states = self.q_proj(hidden_states)
|
363 |
+
key_states = self.k_proj(hidden_states)
|
364 |
+
value_states = self.v_proj(hidden_states)
|
365 |
+
|
366 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
367 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
368 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
369 |
+
|
370 |
+
kv_seq_len = key_states.shape[-2]
|
371 |
+
if past_key_value is not None:
|
372 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
373 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
374 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
375 |
+
|
376 |
+
if past_key_value is not None:
|
377 |
+
# reuse k, v, self_attention
|
378 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
379 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
380 |
+
|
381 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
382 |
+
|
383 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
384 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
385 |
+
|
386 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
387 |
+
|
388 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
389 |
+
raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
390 |
+
f" {attn_weights.size()}")
|
391 |
+
|
392 |
+
if attention_mask is not None:
|
393 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
394 |
+
raise ValueError(
|
395 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
|
396 |
+
attn_weights = attn_weights + attention_mask
|
397 |
+
|
398 |
+
# upcast attention to fp32
|
399 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
400 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
401 |
+
|
402 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
403 |
+
raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
404 |
+
f" {attn_output.size()}")
|
405 |
+
|
406 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
407 |
+
|
408 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
409 |
+
|
410 |
+
if self.config.pretraining_tp > 1:
|
411 |
+
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
412 |
+
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
413 |
+
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
414 |
+
else:
|
415 |
+
attn_output = self.o_proj(attn_output)
|
416 |
+
|
417 |
+
if not output_attentions:
|
418 |
+
attn_weights = None
|
419 |
+
|
420 |
+
return attn_output, attn_weights, past_key_value
|
421 |
+
|
422 |
+
|
423 |
+
class LlamaFlashAttention2(LlamaAttention):
|
424 |
+
"""
|
425 |
+
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
|
426 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
427 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
428 |
+
"""
|
429 |
+
|
430 |
+
def forward(
|
431 |
+
self,
|
432 |
+
hidden_states: torch.Tensor,
|
433 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
434 |
+
position_ids: Optional[torch.LongTensor] = None,
|
435 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
436 |
+
output_attentions: bool = False,
|
437 |
+
use_cache: bool = False,
|
438 |
+
**kwargs,
|
439 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
440 |
+
# LlamaFlashAttention2 attention does not support output_attentions
|
441 |
+
if "padding_mask" in kwargs:
|
442 |
+
warnings.warn(
|
443 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
444 |
+
)
|
445 |
+
|
446 |
+
# overwrite attention_mask with padding_mask
|
447 |
+
attention_mask = kwargs.pop("padding_mask")
|
448 |
+
|
449 |
+
output_attentions = False
|
450 |
+
|
451 |
+
bsz, q_len, _ = hidden_states.size()
|
452 |
+
|
453 |
+
query_states = self.q_proj(hidden_states)
|
454 |
+
key_states = self.k_proj(hidden_states)
|
455 |
+
value_states = self.v_proj(hidden_states)
|
456 |
+
|
457 |
+
# Flash attention requires the input to have the shape
|
458 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
459 |
+
# therefore we just need to keep the original shape
|
460 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
461 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
462 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
463 |
+
|
464 |
+
kv_seq_len = key_states.shape[-2]
|
465 |
+
if past_key_value is not None:
|
466 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
467 |
+
|
468 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
469 |
+
|
470 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
471 |
+
|
472 |
+
if past_key_value is not None:
|
473 |
+
# reuse k, v, self_attention
|
474 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
475 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
476 |
+
|
477 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
478 |
+
|
479 |
+
query_states = query_states.transpose(1, 2)
|
480 |
+
key_states = key_states.transpose(1, 2)
|
481 |
+
value_states = value_states.transpose(1, 2)
|
482 |
+
|
483 |
+
# TODO: llama does not have dropout in the config??
|
484 |
+
# It is recommended to use dropout with FA according to the docs
|
485 |
+
# when training.
|
486 |
+
dropout_rate = 0.0 # if not self.training else self.attn_dropout
|
487 |
+
|
488 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
489 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
490 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
491 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
492 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
493 |
+
|
494 |
+
input_dtype = query_states.dtype
|
495 |
+
if input_dtype == torch.float32:
|
496 |
+
# Handle the case where the model is quantized
|
497 |
+
if hasattr(self.config, "_pre_quantization_dtype"):
|
498 |
+
target_dtype = self.config._pre_quantization_dtype
|
499 |
+
else:
|
500 |
+
target_dtype = self.q_proj.weight.dtype
|
501 |
+
|
502 |
+
logger.warning_once(
|
503 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
504 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
505 |
+
f" {target_dtype}.")
|
506 |
+
|
507 |
+
query_states = query_states.to(target_dtype)
|
508 |
+
key_states = key_states.to(target_dtype)
|
509 |
+
value_states = value_states.to(target_dtype)
|
510 |
+
|
511 |
+
attn_output = self._flash_attention_forward(query_states,
|
512 |
+
key_states,
|
513 |
+
value_states,
|
514 |
+
attention_mask,
|
515 |
+
q_len,
|
516 |
+
dropout=dropout_rate)
|
517 |
+
|
518 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
519 |
+
attn_output = self.o_proj(attn_output)
|
520 |
+
|
521 |
+
if not output_attentions:
|
522 |
+
attn_weights = None
|
523 |
+
|
524 |
+
return attn_output, attn_weights, past_key_value
|
525 |
+
|
526 |
+
def _flash_attention_forward(self,
|
527 |
+
query_states,
|
528 |
+
key_states,
|
529 |
+
value_states,
|
530 |
+
attention_mask,
|
531 |
+
query_length,
|
532 |
+
dropout=0.0,
|
533 |
+
softmax_scale=None):
|
534 |
+
"""
|
535 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
536 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
537 |
+
|
538 |
+
Args:
|
539 |
+
query_states (`torch.Tensor`):
|
540 |
+
Input query states to be passed to Flash Attention API
|
541 |
+
key_states (`torch.Tensor`):
|
542 |
+
Input key states to be passed to Flash Attention API
|
543 |
+
value_states (`torch.Tensor`):
|
544 |
+
Input value states to be passed to Flash Attention API
|
545 |
+
attention_mask (`torch.Tensor`):
|
546 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
547 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
548 |
+
dropout (`int`, *optional*):
|
549 |
+
Attention dropout
|
550 |
+
softmax_scale (`float`, *optional*):
|
551 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
552 |
+
"""
|
553 |
+
# Contains at least one padding token in the sequence
|
554 |
+
if attention_mask is not None:
|
555 |
+
batch_size = query_states.shape[0]
|
556 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
557 |
+
query_states, key_states, value_states, attention_mask, query_length)
|
558 |
+
|
559 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
560 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
561 |
+
|
562 |
+
attn_output_unpad = flash_attn_varlen_func(
|
563 |
+
query_states,
|
564 |
+
key_states,
|
565 |
+
value_states,
|
566 |
+
cu_seqlens_q=cu_seqlens_q,
|
567 |
+
cu_seqlens_k=cu_seqlens_k,
|
568 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
569 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
570 |
+
dropout_p=dropout,
|
571 |
+
softmax_scale=softmax_scale,
|
572 |
+
causal=self.is_causal,
|
573 |
+
)
|
574 |
+
|
575 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
576 |
+
else:
|
577 |
+
attn_output = flash_attn_func(query_states,
|
578 |
+
key_states,
|
579 |
+
value_states,
|
580 |
+
dropout,
|
581 |
+
softmax_scale=softmax_scale,
|
582 |
+
causal=self.is_causal)
|
583 |
+
|
584 |
+
return attn_output
|
585 |
+
|
586 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
587 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
588 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
589 |
+
|
590 |
+
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
591 |
+
indices_k)
|
592 |
+
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
593 |
+
indices_k)
|
594 |
+
if query_length == kv_seq_len:
|
595 |
+
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
|
596 |
+
indices_k)
|
597 |
+
cu_seqlens_q = cu_seqlens_k
|
598 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
599 |
+
indices_q = indices_k
|
600 |
+
elif query_length == 1:
|
601 |
+
max_seqlen_in_batch_q = 1
|
602 |
+
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32,
|
603 |
+
device=query_layer.device) # There is a memcpy here, that is very bad.
|
604 |
+
indices_q = cu_seqlens_q[:-1]
|
605 |
+
query_layer = query_layer.squeeze(1)
|
606 |
+
else:
|
607 |
+
# The -q_len: slice assumes left padding.
|
608 |
+
attention_mask = attention_mask[:, -query_length:]
|
609 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
610 |
+
|
611 |
+
return (
|
612 |
+
query_layer,
|
613 |
+
key_layer,
|
614 |
+
value_layer,
|
615 |
+
indices_q,
|
616 |
+
(cu_seqlens_q, cu_seqlens_k),
|
617 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
618 |
+
)
|
619 |
+
|
620 |
+
|
621 |
+
class LlamaDecoderLayer(nn.Module):
|
622 |
+
|
623 |
+
def __init__(self, config: LlamaConfig):
|
624 |
+
super().__init__()
|
625 |
+
self.hidden_size = config.hidden_size
|
626 |
+
self.self_attn = (LlamaAttention(
|
627 |
+
config=config) if not getattr(config, "_flash_attn_2_enabled", False) else LlamaFlashAttention2(
|
628 |
+
config=config))
|
629 |
+
self.mlp = LlamaMLP(config)
|
630 |
+
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
631 |
+
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
632 |
+
|
633 |
+
def forward(
|
634 |
+
self,
|
635 |
+
hidden_states: torch.Tensor,
|
636 |
+
attention_mask: Optional[torch.Tensor] = None,
|
637 |
+
position_ids: Optional[torch.LongTensor] = None,
|
638 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
639 |
+
output_attentions: Optional[bool] = False,
|
640 |
+
use_cache: Optional[bool] = False,
|
641 |
+
**kwargs,
|
642 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
643 |
+
"""
|
644 |
+
Args:
|
645 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
646 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
647 |
+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
648 |
+
query_sequence_length, key_sequence_length)` if default attention is used.
|
649 |
+
output_attentions (`bool`, *optional*):
|
650 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
651 |
+
returned tensors for more detail.
|
652 |
+
use_cache (`bool`, *optional*):
|
653 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
654 |
+
(see `past_key_values`).
|
655 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
656 |
+
"""
|
657 |
+
if "padding_mask" in kwargs:
|
658 |
+
warnings.warn(
|
659 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
660 |
+
)
|
661 |
+
|
662 |
+
residual = hidden_states
|
663 |
+
|
664 |
+
hidden_states = self.input_layernorm(hidden_states)
|
665 |
+
|
666 |
+
# Self Attention
|
667 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
668 |
+
hidden_states=hidden_states,
|
669 |
+
attention_mask=attention_mask,
|
670 |
+
position_ids=position_ids,
|
671 |
+
past_key_value=past_key_value,
|
672 |
+
output_attentions=output_attentions,
|
673 |
+
use_cache=use_cache,
|
674 |
+
**kwargs,
|
675 |
+
)
|
676 |
+
hidden_states = residual + hidden_states
|
677 |
+
|
678 |
+
# Fully Connected
|
679 |
+
residual = hidden_states
|
680 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
681 |
+
hidden_states = self.mlp(hidden_states)
|
682 |
+
hidden_states = residual + hidden_states
|
683 |
+
|
684 |
+
outputs = (hidden_states,)
|
685 |
+
|
686 |
+
if output_attentions:
|
687 |
+
outputs += (self_attn_weights,)
|
688 |
+
|
689 |
+
if use_cache:
|
690 |
+
outputs += (present_key_value,)
|
691 |
+
|
692 |
+
return outputs
|
693 |
+
|
694 |
+
|
695 |
+
LLAMA_START_DOCSTRING = r"""
|
696 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
697 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
698 |
+
etc.)
|
699 |
+
|
700 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
701 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
702 |
+
and behavior.
|
703 |
+
|
704 |
+
Parameters:
|
705 |
+
config ([`LlamaConfig`]):
|
706 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
707 |
+
load the weights associated with the model, only the configuration. Check out the
|
708 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
709 |
+
"""
|
710 |
+
|
711 |
+
|
712 |
+
@add_start_docstrings(
|
713 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
714 |
+
LLAMA_START_DOCSTRING,
|
715 |
+
)
|
716 |
+
class LlamaPreTrainedModel(PreTrainedModel):
|
717 |
+
config_class = LlamaConfig
|
718 |
+
base_model_prefix = "model"
|
719 |
+
supports_gradient_checkpointing = True
|
720 |
+
_no_split_modules = ["LlamaDecoderLayer"]
|
721 |
+
_skip_keys_device_placement = "past_key_values"
|
722 |
+
_supports_flash_attn_2 = True
|
723 |
+
|
724 |
+
def _init_weights(self, module):
|
725 |
+
std = self.config.initializer_range
|
726 |
+
if isinstance(module, nn.Linear):
|
727 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
728 |
+
if module.bias is not None:
|
729 |
+
module.bias.data.zero_()
|
730 |
+
elif isinstance(module, nn.Embedding):
|
731 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
732 |
+
if module.padding_idx is not None:
|
733 |
+
module.weight.data[module.padding_idx].zero_()
|
734 |
+
|
735 |
+
|
736 |
+
LLAMA_INPUTS_DOCSTRING = r"""
|
737 |
+
Args:
|
738 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
739 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
740 |
+
it.
|
741 |
+
|
742 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
743 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
744 |
+
|
745 |
+
[What are input IDs?](../glossary#input-ids)
|
746 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
747 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
748 |
+
|
749 |
+
- 1 for tokens that are **not masked**,
|
750 |
+
- 0 for tokens that are **masked**.
|
751 |
+
|
752 |
+
[What are attention masks?](../glossary#attention-mask)
|
753 |
+
|
754 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
755 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
756 |
+
|
757 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
758 |
+
`past_key_values`).
|
759 |
+
|
760 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
761 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
762 |
+
information on the default strategy.
|
763 |
+
|
764 |
+
- 1 indicates the head is **not masked**,
|
765 |
+
- 0 indicates the head is **masked**.
|
766 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
767 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
768 |
+
config.n_positions - 1]`.
|
769 |
+
|
770 |
+
[What are position IDs?](../glossary#position-ids)
|
771 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
772 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
773 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
774 |
+
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
|
775 |
+
|
776 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
777 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
778 |
+
|
779 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
780 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
781 |
+
of shape `(batch_size, sequence_length)`.
|
782 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
783 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
784 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
785 |
+
model's internal embedding lookup matrix.
|
786 |
+
use_cache (`bool`, *optional*):
|
787 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
788 |
+
`past_key_values`).
|
789 |
+
output_attentions (`bool`, *optional*):
|
790 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
791 |
+
tensors for more detail.
|
792 |
+
output_hidden_states (`bool`, *optional*):
|
793 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
794 |
+
more detail.
|
795 |
+
return_dict (`bool`, *optional*):
|
796 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
797 |
+
"""
|
798 |
+
|
799 |
+
|
800 |
+
@add_start_docstrings(
|
801 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
802 |
+
LLAMA_START_DOCSTRING,
|
803 |
+
)
|
804 |
+
class LlamaModel(LlamaPreTrainedModel):
|
805 |
+
"""
|
806 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
807 |
+
|
808 |
+
Args:
|
809 |
+
config: LlamaConfig
|
810 |
+
"""
|
811 |
+
|
812 |
+
def __init__(self, config: LlamaConfig):
|
813 |
+
super().__init__(config)
|
814 |
+
self.padding_idx = config.pad_token_id
|
815 |
+
self.vocab_size = config.vocab_size
|
816 |
+
|
817 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
818 |
+
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
819 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
820 |
+
|
821 |
+
self.gradient_checkpointing = False
|
822 |
+
# Initialize weights and apply final processing
|
823 |
+
self.post_init()
|
824 |
+
|
825 |
+
def get_input_embeddings(self):
|
826 |
+
return self.embed_tokens
|
827 |
+
|
828 |
+
def set_input_embeddings(self, value):
|
829 |
+
self.embed_tokens = value
|
830 |
+
|
831 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
832 |
+
def forward(
|
833 |
+
self,
|
834 |
+
input_ids: torch.LongTensor = None,
|
835 |
+
attention_mask: Optional[torch.Tensor] = None,
|
836 |
+
position_ids: Optional[torch.LongTensor] = None,
|
837 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
838 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
839 |
+
use_cache: Optional[bool] = None,
|
840 |
+
output_attentions: Optional[bool] = None,
|
841 |
+
output_hidden_states: Optional[bool] = None,
|
842 |
+
return_dict: Optional[bool] = None,
|
843 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
844 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
845 |
+
output_hidden_states = (
|
846 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
|
847 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
848 |
+
|
849 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
850 |
+
|
851 |
+
# retrieve input_ids and inputs_embeds
|
852 |
+
if input_ids is not None and inputs_embeds is not None:
|
853 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
854 |
+
elif input_ids is not None:
|
855 |
+
batch_size, seq_length = input_ids.shape[:2]
|
856 |
+
elif inputs_embeds is not None:
|
857 |
+
batch_size, seq_length = inputs_embeds.shape[:2]
|
858 |
+
else:
|
859 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
860 |
+
|
861 |
+
past_key_values_length = 0
|
862 |
+
if past_key_values is not None:
|
863 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
864 |
+
|
865 |
+
if position_ids is None:
|
866 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
867 |
+
position_ids = torch.arange(past_key_values_length,
|
868 |
+
seq_length + past_key_values_length,
|
869 |
+
dtype=torch.long,
|
870 |
+
device=device)
|
871 |
+
position_ids = position_ids.unsqueeze(0)
|
872 |
+
|
873 |
+
if inputs_embeds is None:
|
874 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
875 |
+
|
876 |
+
if getattr(self.config, "_flash_attn_2_enabled", False):
|
877 |
+
# 2d mask is passed through the layers
|
878 |
+
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
879 |
+
else:
|
880 |
+
# 4d mask is passed through the layers
|
881 |
+
attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
|
882 |
+
past_key_values_length)
|
883 |
+
|
884 |
+
# embed positions
|
885 |
+
hidden_states = inputs_embeds
|
886 |
+
|
887 |
+
if self.gradient_checkpointing and self.training:
|
888 |
+
if use_cache:
|
889 |
+
logger.warning_once(
|
890 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
891 |
+
use_cache = False
|
892 |
+
|
893 |
+
# decoder layers
|
894 |
+
all_hidden_states = () if output_hidden_states else None
|
895 |
+
all_self_attns = () if output_attentions else None
|
896 |
+
next_decoder_cache = () if use_cache else None
|
897 |
+
|
898 |
+
for idx, decoder_layer in enumerate(self.layers):
|
899 |
+
if output_hidden_states:
|
900 |
+
all_hidden_states += (hidden_states,)
|
901 |
+
|
902 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
903 |
+
|
904 |
+
if self.gradient_checkpointing and self.training:
|
905 |
+
layer_outputs = self._gradient_checkpointing_func(
|
906 |
+
decoder_layer.__call__,
|
907 |
+
hidden_states,
|
908 |
+
attention_mask,
|
909 |
+
position_ids,
|
910 |
+
past_key_value,
|
911 |
+
output_attentions,
|
912 |
+
use_cache,
|
913 |
+
)
|
914 |
+
else:
|
915 |
+
layer_outputs = decoder_layer(
|
916 |
+
hidden_states,
|
917 |
+
attention_mask=attention_mask,
|
918 |
+
position_ids=position_ids,
|
919 |
+
past_key_value=past_key_value,
|
920 |
+
output_attentions=output_attentions,
|
921 |
+
use_cache=use_cache,
|
922 |
+
)
|
923 |
+
|
924 |
+
hidden_states = layer_outputs[0]
|
925 |
+
|
926 |
+
if use_cache:
|
927 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
928 |
+
|
929 |
+
if output_attentions:
|
930 |
+
all_self_attns += (layer_outputs[1],)
|
931 |
+
|
932 |
+
hidden_states = self.norm(hidden_states)
|
933 |
+
|
934 |
+
# add hidden states from the last decoder layer
|
935 |
+
if output_hidden_states:
|
936 |
+
all_hidden_states += (hidden_states,)
|
937 |
+
|
938 |
+
next_cache = next_decoder_cache if use_cache else None
|
939 |
+
if not return_dict:
|
940 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
941 |
+
return BaseModelOutputWithPast(
|
942 |
+
last_hidden_state=hidden_states,
|
943 |
+
past_key_values=next_cache,
|
944 |
+
hidden_states=all_hidden_states,
|
945 |
+
attentions=all_self_attns,
|
946 |
+
)
|
947 |
+
|
948 |
+
|
949 |
+
class LlamaForCausalLM(LlamaPreTrainedModel):
|
950 |
+
_tied_weights_keys = ["lm_head.weight"]
|
951 |
+
|
952 |
+
def __init__(self, config):
|
953 |
+
super().__init__(config)
|
954 |
+
self.model = LlamaModel(config)
|
955 |
+
self.vocab_size = config.vocab_size
|
956 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
957 |
+
|
958 |
+
# Initialize weights and apply final processing
|
959 |
+
self.post_init()
|
960 |
+
|
961 |
+
def get_input_embeddings(self):
|
962 |
+
return self.model.embed_tokens
|
963 |
+
|
964 |
+
def set_input_embeddings(self, value):
|
965 |
+
self.model.embed_tokens = value
|
966 |
+
|
967 |
+
def get_output_embeddings(self):
|
968 |
+
return self.lm_head
|
969 |
+
|
970 |
+
def set_output_embeddings(self, new_embeddings):
|
971 |
+
self.lm_head = new_embeddings
|
972 |
+
|
973 |
+
def set_decoder(self, decoder):
|
974 |
+
self.model = decoder
|
975 |
+
|
976 |
+
def get_decoder(self):
|
977 |
+
return self.model
|
978 |
+
|
979 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
980 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
981 |
+
def forward(
|
982 |
+
self,
|
983 |
+
input_ids: torch.LongTensor = None,
|
984 |
+
attention_mask: Optional[torch.Tensor] = None,
|
985 |
+
position_ids: Optional[torch.LongTensor] = None,
|
986 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
987 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
988 |
+
labels: Optional[torch.LongTensor] = None,
|
989 |
+
use_cache: Optional[bool] = None,
|
990 |
+
output_attentions: Optional[bool] = None,
|
991 |
+
output_hidden_states: Optional[bool] = None,
|
992 |
+
return_dict: Optional[bool] = None,
|
993 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
994 |
+
r"""
|
995 |
+
Args:
|
996 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
997 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
998 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
999 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1000 |
+
|
1001 |
+
Returns:
|
1002 |
+
|
1003 |
+
Example:
|
1004 |
+
|
1005 |
+
```python
|
1006 |
+
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
1007 |
+
|
1008 |
+
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
1009 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
1010 |
+
|
1011 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
1012 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1013 |
+
|
1014 |
+
>>> # Generate
|
1015 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1016 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1017 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
1018 |
+
```"""
|
1019 |
+
|
1020 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1021 |
+
output_hidden_states = (
|
1022 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
|
1023 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1024 |
+
|
1025 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1026 |
+
outputs = self.model(
|
1027 |
+
input_ids=input_ids,
|
1028 |
+
attention_mask=attention_mask,
|
1029 |
+
position_ids=position_ids,
|
1030 |
+
past_key_values=past_key_values,
|
1031 |
+
inputs_embeds=inputs_embeds,
|
1032 |
+
use_cache=use_cache,
|
1033 |
+
output_attentions=output_attentions,
|
1034 |
+
output_hidden_states=output_hidden_states,
|
1035 |
+
return_dict=return_dict,
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
hidden_states = outputs[0]
|
1039 |
+
if self.config.pretraining_tp > 1:
|
1040 |
+
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
1041 |
+
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
1042 |
+
logits = torch.cat(logits, dim=-1)
|
1043 |
+
else:
|
1044 |
+
logits = self.lm_head(hidden_states)
|
1045 |
+
logits = logits.float()
|
1046 |
+
|
1047 |
+
loss = None
|
1048 |
+
if labels is not None:
|
1049 |
+
# Shift so that tokens < n predict n
|
1050 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1051 |
+
shift_labels = labels[..., 1:].contiguous()
|
1052 |
+
# Flatten the tokens
|
1053 |
+
loss_fct = CrossEntropyLoss()
|
1054 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
1055 |
+
shift_labels = shift_labels.view(-1)
|
1056 |
+
# Enable model parallelism
|
1057 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
1058 |
+
loss = loss_fct(shift_logits, shift_labels)
|
1059 |
+
|
1060 |
+
if not return_dict:
|
1061 |
+
output = (logits,) + outputs[1:]
|
1062 |
+
return (loss,) + output if loss is not None else output
|
1063 |
+
|
1064 |
+
return CausalLMOutputWithPast(
|
1065 |
+
loss=loss,
|
1066 |
+
logits=logits,
|
1067 |
+
past_key_values=outputs.past_key_values,
|
1068 |
+
hidden_states=outputs.hidden_states,
|
1069 |
+
attentions=outputs.attentions,
|
1070 |
+
)
|
1071 |
+
|
1072 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None,
|
1073 |
+
**kwargs):
|
1074 |
+
if past_key_values is not None:
|
1075 |
+
past_length = past_key_values[0][0].shape[2]
|
1076 |
+
|
1077 |
+
# Some generation methods already pass only the last input ID
|
1078 |
+
if input_ids.shape[1] > past_length:
|
1079 |
+
remove_prefix_length = past_length
|
1080 |
+
else:
|
1081 |
+
# Default to old behavior: keep only final ID
|
1082 |
+
remove_prefix_length = input_ids.shape[1] - 1
|
1083 |
+
|
1084 |
+
input_ids = input_ids[:, remove_prefix_length:]
|
1085 |
+
|
1086 |
+
position_ids = kwargs.get("position_ids", None)
|
1087 |
+
if attention_mask is not None and position_ids is None:
|
1088 |
+
# create position_ids on the fly for batch generation
|
1089 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1090 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1091 |
+
if past_key_values:
|
1092 |
+
position_ids = position_ids[:, -input_ids.shape[1]:]
|
1093 |
+
|
1094 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1095 |
+
if inputs_embeds is not None and past_key_values is None:
|
1096 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1097 |
+
else:
|
1098 |
+
model_inputs = {"input_ids": input_ids}
|
1099 |
+
|
1100 |
+
model_inputs.update({
|
1101 |
+
"position_ids": position_ids,
|
1102 |
+
"past_key_values": past_key_values,
|
1103 |
+
"use_cache": kwargs.get("use_cache"),
|
1104 |
+
"attention_mask": attention_mask,
|
1105 |
+
})
|
1106 |
+
return model_inputs
|
1107 |
+
|
1108 |
+
@staticmethod
|
1109 |
+
def _reorder_cache(past_key_values, beam_idx):
|
1110 |
+
reordered_past = ()
|
1111 |
+
for layer_past in past_key_values:
|
1112 |
+
reordered_past += (
|
1113 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
|
1114 |
+
return reordered_past
|
1115 |
+
|
1116 |
+
|
1117 |
+
@add_start_docstrings(
|
1118 |
+
"""
|
1119 |
+
The LLaMa Model transformer with a sequence classification head on top (linear layer).
|
1120 |
+
|
1121 |
+
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1122 |
+
(e.g. GPT-2) do.
|
1123 |
+
|
1124 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1125 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1126 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1127 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1128 |
+
each row of the batch).
|
1129 |
+
""",
|
1130 |
+
LLAMA_START_DOCSTRING,
|
1131 |
+
)
|
1132 |
+
class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
1133 |
+
|
1134 |
+
def __init__(self, config):
|
1135 |
+
super().__init__(config)
|
1136 |
+
self.num_labels = config.num_labels
|
1137 |
+
self.model = LlamaModel(config)
|
1138 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1139 |
+
|
1140 |
+
# Initialize weights and apply final processing
|
1141 |
+
self.post_init()
|
1142 |
+
|
1143 |
+
def get_input_embeddings(self):
|
1144 |
+
return self.model.embed_tokens
|
1145 |
+
|
1146 |
+
def set_input_embeddings(self, value):
|
1147 |
+
self.model.embed_tokens = value
|
1148 |
+
|
1149 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
1150 |
+
def forward(
|
1151 |
+
self,
|
1152 |
+
input_ids: torch.LongTensor = None,
|
1153 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1154 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1155 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1156 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1157 |
+
labels: Optional[torch.LongTensor] = None,
|
1158 |
+
use_cache: Optional[bool] = None,
|
1159 |
+
output_attentions: Optional[bool] = None,
|
1160 |
+
output_hidden_states: Optional[bool] = None,
|
1161 |
+
return_dict: Optional[bool] = None,
|
1162 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1163 |
+
r"""
|
1164 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1165 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1166 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1167 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1168 |
+
"""
|
1169 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1170 |
+
|
1171 |
+
transformer_outputs = self.model(
|
1172 |
+
input_ids,
|
1173 |
+
attention_mask=attention_mask,
|
1174 |
+
position_ids=position_ids,
|
1175 |
+
past_key_values=past_key_values,
|
1176 |
+
inputs_embeds=inputs_embeds,
|
1177 |
+
use_cache=use_cache,
|
1178 |
+
output_attentions=output_attentions,
|
1179 |
+
output_hidden_states=output_hidden_states,
|
1180 |
+
return_dict=return_dict,
|
1181 |
+
)
|
1182 |
+
hidden_states = transformer_outputs[0]
|
1183 |
+
logits = self.score(hidden_states)
|
1184 |
+
|
1185 |
+
if input_ids is not None:
|
1186 |
+
batch_size = input_ids.shape[0]
|
1187 |
+
else:
|
1188 |
+
batch_size = inputs_embeds.shape[0]
|
1189 |
+
|
1190 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1191 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
1192 |
+
if self.config.pad_token_id is None:
|
1193 |
+
sequence_lengths = -1
|
1194 |
+
else:
|
1195 |
+
if input_ids is not None:
|
1196 |
+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
|
1197 |
+
logits.device)
|
1198 |
+
else:
|
1199 |
+
sequence_lengths = -1
|
1200 |
+
|
1201 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1202 |
+
|
1203 |
+
loss = None
|
1204 |
+
if labels is not None:
|
1205 |
+
labels = labels.to(logits.device)
|
1206 |
+
if self.config.problem_type is None:
|
1207 |
+
if self.num_labels == 1:
|
1208 |
+
self.config.problem_type = "regression"
|
1209 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1210 |
+
self.config.problem_type = "single_label_classification"
|
1211 |
+
else:
|
1212 |
+
self.config.problem_type = "multi_label_classification"
|
1213 |
+
|
1214 |
+
if self.config.problem_type == "regression":
|
1215 |
+
loss_fct = MSELoss()
|
1216 |
+
if self.num_labels == 1:
|
1217 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1218 |
+
else:
|
1219 |
+
loss = loss_fct(pooled_logits, labels)
|
1220 |
+
elif self.config.problem_type == "single_label_classification":
|
1221 |
+
loss_fct = CrossEntropyLoss()
|
1222 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1223 |
+
elif self.config.problem_type == "multi_label_classification":
|
1224 |
+
loss_fct = BCEWithLogitsLoss()
|
1225 |
+
loss = loss_fct(pooled_logits, labels)
|
1226 |
+
if not return_dict:
|
1227 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1228 |
+
return ((loss,) + output) if loss is not None else output
|
1229 |
+
|
1230 |
+
return SequenceClassifierOutputWithPast(
|
1231 |
+
loss=loss,
|
1232 |
+
logits=pooled_logits,
|
1233 |
+
past_key_values=transformer_outputs.past_key_values,
|
1234 |
+
hidden_states=transformer_outputs.hidden_states,
|
1235 |
+
attentions=transformer_outputs.attentions,
|
1236 |
+
)
|
src/models_clm/modeling_llama_xformer.py
ADDED
@@ -0,0 +1,992 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
6 |
+
# and OPT implementations in this library. It has been modified from its
|
7 |
+
# original forms to accommodate minor architectural differences compared
|
8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
9 |
+
#
|
10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
11 |
+
# you may not use this file except in compliance with the License.
|
12 |
+
# You may obtain a copy of the License at
|
13 |
+
#
|
14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
15 |
+
#
|
16 |
+
# Unless required by applicable law or agreed to in writing, software
|
17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
19 |
+
# See the License for the specific language governing permissions and
|
20 |
+
# limitations under the License.
|
21 |
+
""" PyTorch LLaMA model."""
|
22 |
+
import math
|
23 |
+
from typing import List, Optional, Tuple, Union
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from torch import nn
|
28 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
29 |
+
|
30 |
+
from transformers.activations import ACT2FN
|
31 |
+
from transformers.modeling_outputs import (
|
32 |
+
BaseModelOutputWithPast,
|
33 |
+
CausalLMOutputWithPast,
|
34 |
+
SequenceClassifierOutputWithPast,
|
35 |
+
)
|
36 |
+
from transformers.modeling_utils import PreTrainedModel
|
37 |
+
from transformers.utils import (
|
38 |
+
add_start_docstrings,
|
39 |
+
add_start_docstrings_to_model_forward,
|
40 |
+
logging,
|
41 |
+
replace_return_docstrings,
|
42 |
+
)
|
43 |
+
from transformers.models.llama.configuration_llama import LlamaConfig
|
44 |
+
import xformers.ops as xops
|
45 |
+
|
46 |
+
logger = logging.get_logger(__name__)
|
47 |
+
|
48 |
+
_CONFIG_FOR_DOC = "LlamaConfig"
|
49 |
+
|
50 |
+
|
51 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
52 |
+
def _make_causal_mask(
|
53 |
+
input_ids_shape: torch.Size,
|
54 |
+
dtype: torch.dtype,
|
55 |
+
device: torch.device,
|
56 |
+
past_key_values_length: int = 0,
|
57 |
+
):
|
58 |
+
"""
|
59 |
+
Make causal mask used for bi-directional self-attention.
|
60 |
+
"""
|
61 |
+
bsz, tgt_len = input_ids_shape
|
62 |
+
mask = torch.full(
|
63 |
+
(tgt_len, tgt_len),
|
64 |
+
torch.tensor(torch.finfo(dtype).min, device=device),
|
65 |
+
device=device,
|
66 |
+
)
|
67 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
68 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
69 |
+
mask = mask.to(dtype)
|
70 |
+
|
71 |
+
if past_key_values_length > 0:
|
72 |
+
mask = torch.cat(
|
73 |
+
[
|
74 |
+
torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
|
75 |
+
mask,
|
76 |
+
],
|
77 |
+
dim=-1,
|
78 |
+
)
|
79 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
80 |
+
|
81 |
+
|
82 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
83 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
84 |
+
"""
|
85 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
86 |
+
"""
|
87 |
+
bsz, src_len = mask.size()
|
88 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
89 |
+
|
90 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
91 |
+
|
92 |
+
inverted_mask = 1.0 - expanded_mask
|
93 |
+
|
94 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
95 |
+
|
96 |
+
|
97 |
+
class LlamaRMSNorm(nn.Module):
|
98 |
+
|
99 |
+
def __init__(self, hidden_size, eps=1e-6):
|
100 |
+
"""
|
101 |
+
LlamaRMSNorm is equivalent to T5LayerNorm
|
102 |
+
"""
|
103 |
+
super().__init__()
|
104 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
105 |
+
self.variance_epsilon = eps
|
106 |
+
|
107 |
+
def forward(self, hidden_states):
|
108 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
109 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
110 |
+
|
111 |
+
# convert into half-precision if necessary
|
112 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
113 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
114 |
+
|
115 |
+
return self.weight * hidden_states
|
116 |
+
|
117 |
+
|
118 |
+
class LlamaRotaryEmbedding(torch.nn.Module):
|
119 |
+
|
120 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
121 |
+
super().__init__()
|
122 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
123 |
+
self.register_buffer("inv_freq", inv_freq)
|
124 |
+
|
125 |
+
# Build here to make `torch.jit.trace` work.
|
126 |
+
self.max_seq_len_cached = max_position_embeddings
|
127 |
+
t = torch.arange(
|
128 |
+
self.max_seq_len_cached,
|
129 |
+
device=self.inv_freq.device,
|
130 |
+
dtype=self.inv_freq.dtype,
|
131 |
+
)
|
132 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
133 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
134 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
135 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
136 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
137 |
+
|
138 |
+
def forward(self, x, seq_len=None):
|
139 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
140 |
+
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
141 |
+
if seq_len > self.max_seq_len_cached:
|
142 |
+
self.max_seq_len_cached = seq_len
|
143 |
+
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
144 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
145 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
146 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
147 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
148 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
149 |
+
return (
|
150 |
+
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
151 |
+
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
152 |
+
# self.cos_cached[:, :, :, ...].to(dtype=x.dtype),
|
153 |
+
# self.sin_cached[:, :, :, ...].to(dtype=x.dtype),
|
154 |
+
|
155 |
+
)
|
156 |
+
|
157 |
+
|
158 |
+
def rotate_half(x):
|
159 |
+
"""Rotates half the hidden dims of the input."""
|
160 |
+
x1 = x[..., :x.shape[-1] // 2]
|
161 |
+
x2 = x[..., x.shape[-1] // 2:]
|
162 |
+
return torch.cat((-x2, x1), dim=-1)
|
163 |
+
|
164 |
+
|
165 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
166 |
+
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
167 |
+
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
168 |
+
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
169 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
170 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
171 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
172 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
173 |
+
return q_embed, k_embed
|
174 |
+
|
175 |
+
|
176 |
+
class LlamaMLP(nn.Module):
|
177 |
+
|
178 |
+
def __init__(
|
179 |
+
self,
|
180 |
+
hidden_size: int,
|
181 |
+
intermediate_size: int,
|
182 |
+
hidden_act: str,
|
183 |
+
):
|
184 |
+
super().__init__()
|
185 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
186 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
187 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
188 |
+
self.act_fn = ACT2FN[hidden_act]
|
189 |
+
|
190 |
+
def forward(self, x):
|
191 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
192 |
+
|
193 |
+
|
194 |
+
class LlamaAttention(nn.Module):
|
195 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
196 |
+
|
197 |
+
def __init__(self, config: LlamaConfig):
|
198 |
+
super().__init__()
|
199 |
+
self.config = config
|
200 |
+
self.hidden_size = config.hidden_size
|
201 |
+
self.num_heads = config.num_attention_heads
|
202 |
+
self.head_dim = self.hidden_size // self.num_heads
|
203 |
+
self.max_position_embeddings = config.max_position_embeddings
|
204 |
+
|
205 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
206 |
+
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
207 |
+
f" and `num_heads`: {self.num_heads}).")
|
208 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
209 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
210 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
211 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
212 |
+
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
213 |
+
|
214 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
215 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
216 |
+
|
217 |
+
def forward(
|
218 |
+
self,
|
219 |
+
hidden_states: torch.Tensor,
|
220 |
+
attention_mask: Optional[torch.Tensor] = None,
|
221 |
+
position_ids: Optional[torch.LongTensor] = None,
|
222 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
223 |
+
output_attentions: bool = False,
|
224 |
+
use_cache: bool = False,
|
225 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
226 |
+
bsz, q_len, _ = hidden_states.size()
|
227 |
+
|
228 |
+
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
229 |
+
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
230 |
+
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
231 |
+
|
232 |
+
kv_seq_len = key_states.shape[-2]
|
233 |
+
if past_key_value is not None:
|
234 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
235 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
236 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
237 |
+
# [bsz, nh, t, hd]
|
238 |
+
|
239 |
+
if past_key_value is not None:
|
240 |
+
# reuse k, v, self_attention
|
241 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
242 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
243 |
+
|
244 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
245 |
+
|
246 |
+
# attn_weights
|
247 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
248 |
+
if attention_mask is None:
|
249 |
+
def lower_triangular_from_bottom_right_mask(qlen, klen, device):
|
250 |
+
"""
|
251 |
+
Create a lower triangular mask from the bottom-right corner of a matrix.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
- qlen (int): Length of the query dimension.
|
255 |
+
- klen (int): Length of the key dimension.
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
- torch.Tensor: A mask with shape (1, 1, qlen, klen) where the bottom-right triangle is True.
|
259 |
+
"""
|
260 |
+
# Create a grid of indices where rows correspond to query indices and columns to key indices
|
261 |
+
q_indices = torch.arange(qlen - 1, -1, -1, device=device).unsqueeze(1) # Reverse the query indices
|
262 |
+
k_indices = torch.arange(klen - 1, -1, -1, device=device).unsqueeze(0) # Reverse the key indices
|
263 |
+
|
264 |
+
# Generate the mask where we compare query indices to key indices
|
265 |
+
# The condition q_indices >= k_indices creates a lower triangular mask from the top-left corner
|
266 |
+
# By reversing both indices, we get the lower triangular effect from the bottom-right
|
267 |
+
mask = q_indices >= k_indices
|
268 |
+
|
269 |
+
# Reshape to (1, 1, qlen, klen) as required
|
270 |
+
return mask.unsqueeze(0).unsqueeze(0)
|
271 |
+
|
272 |
+
attention_mask = lower_triangular_from_bottom_right_mask(attn_weights.shape[-2], attn_weights.shape[-1],
|
273 |
+
device=attn_weights.device)
|
274 |
+
attn_weights = attn_weights + attention_mask
|
275 |
+
attn_weights = attn_weights[:, 0, :, :]
|
276 |
+
# attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
277 |
+
|
278 |
+
query_states = query_states.transpose(1, 2)
|
279 |
+
key_states = key_states.transpose(1, 2)
|
280 |
+
value_states = value_states.transpose(1, 2)
|
281 |
+
if self.training:
|
282 |
+
attn_output = xops.memory_efficient_attention(
|
283 |
+
query_states,
|
284 |
+
key_states,
|
285 |
+
value_states,
|
286 |
+
attn_bias=xops.LowerTriangularMask(),
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
xops_attention_mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
|
290 |
+
attn_output = xops.memory_efficient_attention(
|
291 |
+
query_states,
|
292 |
+
key_states,
|
293 |
+
value_states,
|
294 |
+
attn_bias=xops_attention_mask,
|
295 |
+
)
|
296 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
297 |
+
attn_output = self.o_proj(attn_output)
|
298 |
+
|
299 |
+
if not output_attentions:
|
300 |
+
attn_weights = None
|
301 |
+
return attn_output, attn_weights, past_key_value
|
302 |
+
|
303 |
+
|
304 |
+
class LlamaDecoderLayer(nn.Module):
|
305 |
+
|
306 |
+
def __init__(self, config: LlamaConfig):
|
307 |
+
super().__init__()
|
308 |
+
self.hidden_size = config.hidden_size
|
309 |
+
self.self_attn = LlamaAttention(config=config)
|
310 |
+
self.mlp = LlamaMLP(
|
311 |
+
hidden_size=self.hidden_size,
|
312 |
+
intermediate_size=config.intermediate_size,
|
313 |
+
hidden_act=config.hidden_act,
|
314 |
+
)
|
315 |
+
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
316 |
+
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
317 |
+
|
318 |
+
def forward(
|
319 |
+
self,
|
320 |
+
hidden_states: torch.Tensor,
|
321 |
+
attention_mask: Optional[torch.Tensor] = None,
|
322 |
+
position_ids: Optional[torch.LongTensor] = None,
|
323 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
324 |
+
output_attentions: Optional[bool] = False,
|
325 |
+
use_cache: Optional[bool] = False,
|
326 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
327 |
+
"""
|
328 |
+
Args:
|
329 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
330 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
331 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
332 |
+
output_attentions (`bool`, *optional*):
|
333 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
334 |
+
returned tensors for more detail.
|
335 |
+
use_cache (`bool`, *optional*):
|
336 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
337 |
+
(see `past_key_values`).
|
338 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
339 |
+
"""
|
340 |
+
|
341 |
+
residual = hidden_states
|
342 |
+
|
343 |
+
hidden_states = self.input_layernorm(hidden_states)
|
344 |
+
# Self Attention
|
345 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
346 |
+
hidden_states=hidden_states,
|
347 |
+
attention_mask=attention_mask,
|
348 |
+
position_ids=position_ids,
|
349 |
+
past_key_value=past_key_value,
|
350 |
+
output_attentions=output_attentions,
|
351 |
+
use_cache=use_cache,
|
352 |
+
)
|
353 |
+
hidden_states = residual + hidden_states
|
354 |
+
|
355 |
+
# Fully Connected
|
356 |
+
residual = hidden_states
|
357 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
358 |
+
hidden_states = self.mlp(hidden_states)
|
359 |
+
hidden_states = residual + hidden_states
|
360 |
+
|
361 |
+
outputs = (hidden_states,)
|
362 |
+
|
363 |
+
if output_attentions:
|
364 |
+
outputs += (self_attn_weights,)
|
365 |
+
|
366 |
+
if use_cache:
|
367 |
+
outputs += (present_key_value,)
|
368 |
+
return outputs
|
369 |
+
|
370 |
+
|
371 |
+
LLAMA_START_DOCSTRING = r"""
|
372 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
373 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
374 |
+
etc.)
|
375 |
+
|
376 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
377 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
378 |
+
and behavior.
|
379 |
+
|
380 |
+
Parameters:
|
381 |
+
config ([`LlamaConfig`]):
|
382 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
383 |
+
load the weights associated with the model, only the configuration. Check out the
|
384 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
385 |
+
"""
|
386 |
+
|
387 |
+
|
388 |
+
@add_start_docstrings(
|
389 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
390 |
+
LLAMA_START_DOCSTRING,
|
391 |
+
)
|
392 |
+
class LlamaPreTrainedModel(PreTrainedModel):
|
393 |
+
config_class = LlamaConfig
|
394 |
+
base_model_prefix = "model"
|
395 |
+
supports_gradient_checkpointing = True
|
396 |
+
_no_split_modules = ["LlamaDecoderLayer"]
|
397 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
398 |
+
|
399 |
+
def _init_weights(self, module):
|
400 |
+
std = self.config.initializer_range
|
401 |
+
if isinstance(module, nn.Linear):
|
402 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
403 |
+
if module.bias is not None:
|
404 |
+
module.bias.data.zero_()
|
405 |
+
elif isinstance(module, nn.Embedding):
|
406 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
407 |
+
if module.padding_idx is not None:
|
408 |
+
module.weight.data[module.padding_idx].zero_()
|
409 |
+
|
410 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
411 |
+
if isinstance(module, LlamaModel):
|
412 |
+
module.gradient_checkpointing = value
|
413 |
+
|
414 |
+
|
415 |
+
LLAMA_INPUTS_DOCSTRING = r"""
|
416 |
+
Args:
|
417 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
418 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
419 |
+
it.
|
420 |
+
|
421 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
422 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
423 |
+
|
424 |
+
[What are input IDs?](../glossary#input-ids)
|
425 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
426 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
427 |
+
|
428 |
+
- 1 for tokens that are **not masked**,
|
429 |
+
- 0 for tokens that are **masked**.
|
430 |
+
|
431 |
+
[What are attention masks?](../glossary#attention-mask)
|
432 |
+
|
433 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
434 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
435 |
+
|
436 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
437 |
+
`past_key_values`).
|
438 |
+
|
439 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
440 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
441 |
+
information on the default strategy.
|
442 |
+
|
443 |
+
- 1 indicates the head is **not masked**,
|
444 |
+
- 0 indicates the head is **masked**.
|
445 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
446 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
447 |
+
config.n_positions - 1]`.
|
448 |
+
|
449 |
+
[What are position IDs?](../glossary#position-ids)
|
450 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
451 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
452 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
453 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
454 |
+
|
455 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
456 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
457 |
+
|
458 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
459 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
460 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
461 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
462 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
463 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
464 |
+
model's internal embedding lookup matrix.
|
465 |
+
use_cache (`bool`, *optional*):
|
466 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
467 |
+
`past_key_values`).
|
468 |
+
output_attentions (`bool`, *optional*):
|
469 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
470 |
+
tensors for more detail.
|
471 |
+
output_hidden_states (`bool`, *optional*):
|
472 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
473 |
+
more detail.
|
474 |
+
return_dict (`bool`, *optional*):
|
475 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
476 |
+
"""
|
477 |
+
|
478 |
+
|
479 |
+
@add_start_docstrings(
|
480 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
481 |
+
LLAMA_START_DOCSTRING,
|
482 |
+
)
|
483 |
+
class LlamaModel(LlamaPreTrainedModel):
|
484 |
+
"""
|
485 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
486 |
+
|
487 |
+
Args:
|
488 |
+
config: LlamaConfig
|
489 |
+
"""
|
490 |
+
|
491 |
+
def __init__(self, config: LlamaConfig):
|
492 |
+
super().__init__(config)
|
493 |
+
self.padding_idx = config.pad_token_id
|
494 |
+
self.vocab_size = config.vocab_size
|
495 |
+
|
496 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
497 |
+
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
498 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
499 |
+
|
500 |
+
self.gradient_checkpointing = False
|
501 |
+
# Initialize weights and apply final processing
|
502 |
+
self.post_init()
|
503 |
+
|
504 |
+
def get_input_embeddings(self):
|
505 |
+
return self.embed_tokens
|
506 |
+
|
507 |
+
def set_input_embeddings(self, value):
|
508 |
+
self.embed_tokens = value
|
509 |
+
|
510 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
511 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
512 |
+
# create causal mask
|
513 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
514 |
+
combined_attention_mask = None
|
515 |
+
if input_shape[-1] > 1:
|
516 |
+
combined_attention_mask = _make_causal_mask(
|
517 |
+
input_shape,
|
518 |
+
inputs_embeds.dtype,
|
519 |
+
device=inputs_embeds.device,
|
520 |
+
past_key_values_length=past_key_values_length,
|
521 |
+
)
|
522 |
+
|
523 |
+
if attention_mask is not None:
|
524 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
525 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
|
526 |
+
tgt_len=input_shape[-1]).to(inputs_embeds.device)
|
527 |
+
combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
528 |
+
|
529 |
+
return combined_attention_mask
|
530 |
+
|
531 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
532 |
+
def forward(
|
533 |
+
self,
|
534 |
+
input_ids: torch.LongTensor = None,
|
535 |
+
attention_mask: Optional[torch.Tensor] = None,
|
536 |
+
position_ids: Optional[torch.LongTensor] = None,
|
537 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
538 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
539 |
+
use_cache: Optional[bool] = None,
|
540 |
+
output_attentions: Optional[bool] = None,
|
541 |
+
output_hidden_states: Optional[bool] = None,
|
542 |
+
return_dict: Optional[bool] = None,
|
543 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
544 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
545 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
546 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
547 |
+
|
548 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
549 |
+
|
550 |
+
# retrieve input_ids and inputs_embeds
|
551 |
+
# if input_ids is not None and inputs_embeds is not None:
|
552 |
+
# raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
553 |
+
# elif input_ids is not None:
|
554 |
+
if input_ids is not None:
|
555 |
+
batch_size, seq_length = input_ids.shape
|
556 |
+
elif inputs_embeds is not None:
|
557 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
558 |
+
else:
|
559 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
560 |
+
|
561 |
+
seq_length_with_past = seq_length
|
562 |
+
past_key_values_length = 0
|
563 |
+
|
564 |
+
if past_key_values is not None:
|
565 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
566 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
567 |
+
|
568 |
+
if position_ids is None:
|
569 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
570 |
+
position_ids = torch.arange(
|
571 |
+
past_key_values_length,
|
572 |
+
seq_length + past_key_values_length,
|
573 |
+
dtype=torch.long,
|
574 |
+
device=device,
|
575 |
+
)
|
576 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
577 |
+
else:
|
578 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
579 |
+
|
580 |
+
if inputs_embeds is None:
|
581 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
582 |
+
# embed positions
|
583 |
+
|
584 |
+
# rm when use streaming
|
585 |
+
# if attention_mask is None:
|
586 |
+
# attention_mask = torch.ones(
|
587 |
+
# (batch_size, seq_length_with_past),
|
588 |
+
# dtype=torch.bool,
|
589 |
+
# device=inputs_embeds.device,
|
590 |
+
# )
|
591 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
592 |
+
attention_mask,
|
593 |
+
(batch_size, seq_length),
|
594 |
+
inputs_embeds,
|
595 |
+
past_key_values_length,
|
596 |
+
)
|
597 |
+
|
598 |
+
hidden_states = inputs_embeds
|
599 |
+
|
600 |
+
if self.gradient_checkpointing and self.training:
|
601 |
+
if use_cache:
|
602 |
+
logger.warning_once(
|
603 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
604 |
+
use_cache = False
|
605 |
+
|
606 |
+
# decoder layers
|
607 |
+
all_hidden_states = () if output_hidden_states else None
|
608 |
+
all_self_attns = () if output_attentions else None
|
609 |
+
next_decoder_cache = () if use_cache else None
|
610 |
+
|
611 |
+
for idx, decoder_layer in enumerate(self.layers):
|
612 |
+
if output_hidden_states:
|
613 |
+
all_hidden_states += (hidden_states,)
|
614 |
+
|
615 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
616 |
+
|
617 |
+
if self.gradient_checkpointing and self.training:
|
618 |
+
|
619 |
+
def create_custom_forward(module):
|
620 |
+
|
621 |
+
def custom_forward(*inputs):
|
622 |
+
# None for past_key_value
|
623 |
+
return module(*inputs, output_attentions, None)
|
624 |
+
|
625 |
+
return custom_forward
|
626 |
+
|
627 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
628 |
+
create_custom_forward(decoder_layer),
|
629 |
+
hidden_states,
|
630 |
+
attention_mask,
|
631 |
+
position_ids,
|
632 |
+
None,
|
633 |
+
)
|
634 |
+
else:
|
635 |
+
layer_outputs = decoder_layer(
|
636 |
+
hidden_states,
|
637 |
+
attention_mask=attention_mask,
|
638 |
+
position_ids=position_ids,
|
639 |
+
past_key_value=past_key_value,
|
640 |
+
output_attentions=output_attentions,
|
641 |
+
use_cache=use_cache,
|
642 |
+
)
|
643 |
+
|
644 |
+
hidden_states = layer_outputs[0]
|
645 |
+
|
646 |
+
if use_cache:
|
647 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
648 |
+
|
649 |
+
if output_attentions:
|
650 |
+
all_self_attns += (layer_outputs[1],)
|
651 |
+
|
652 |
+
hidden_states = self.norm(hidden_states)
|
653 |
+
|
654 |
+
# add hidden states from the last decoder layer
|
655 |
+
if output_hidden_states:
|
656 |
+
all_hidden_states += (hidden_states,)
|
657 |
+
|
658 |
+
next_cache = next_decoder_cache if use_cache else None
|
659 |
+
if not return_dict:
|
660 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
661 |
+
return BaseModelOutputWithPast(
|
662 |
+
last_hidden_state=hidden_states,
|
663 |
+
past_key_values=next_cache,
|
664 |
+
hidden_states=all_hidden_states,
|
665 |
+
attentions=all_self_attns,
|
666 |
+
)
|
667 |
+
|
668 |
+
|
669 |
+
class LlamaForCausalLM(LlamaPreTrainedModel):
|
670 |
+
|
671 |
+
def __init__(self, config):
|
672 |
+
super().__init__(config)
|
673 |
+
self.model = LlamaModel(config)
|
674 |
+
|
675 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
676 |
+
self.past_key_values = None
|
677 |
+
self.kv_cache_head = None
|
678 |
+
self.use_kv_cache_head = True
|
679 |
+
# self.position_ids = None
|
680 |
+
# Initialize weights and apply final processing
|
681 |
+
self.post_init()
|
682 |
+
|
683 |
+
def get_input_embeddings(self):
|
684 |
+
return self.model.embed_tokens
|
685 |
+
|
686 |
+
def set_input_embeddings(self, value):
|
687 |
+
self.model.embed_tokens = value
|
688 |
+
|
689 |
+
def get_output_embeddings(self):
|
690 |
+
return self.lm_head
|
691 |
+
|
692 |
+
def set_output_embeddings(self, new_embeddings):
|
693 |
+
self.lm_head = new_embeddings
|
694 |
+
|
695 |
+
def set_decoder(self, decoder):
|
696 |
+
self.model = decoder
|
697 |
+
|
698 |
+
def get_decoder(self):
|
699 |
+
return self.model
|
700 |
+
|
701 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
702 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
703 |
+
def forward(
|
704 |
+
self,
|
705 |
+
input_ids: torch.LongTensor = None,
|
706 |
+
attention_mask: Optional[torch.Tensor] = None,
|
707 |
+
position_ids: Optional[torch.LongTensor] = None,
|
708 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
709 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
710 |
+
labels: Optional[torch.LongTensor] = None,
|
711 |
+
use_cache: Optional[bool] = None,
|
712 |
+
output_attentions: Optional[bool] = None,
|
713 |
+
output_hidden_states: Optional[bool] = None,
|
714 |
+
return_dict: Optional[bool] = None,
|
715 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
716 |
+
r"""
|
717 |
+
Args:
|
718 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
719 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
720 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
721 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
722 |
+
|
723 |
+
Returns:
|
724 |
+
|
725 |
+
Example:
|
726 |
+
|
727 |
+
```python
|
728 |
+
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
729 |
+
|
730 |
+
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
731 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
732 |
+
|
733 |
+
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
734 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
735 |
+
|
736 |
+
>>> # Generate
|
737 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
738 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
739 |
+
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
740 |
+
```"""
|
741 |
+
|
742 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
743 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
744 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
745 |
+
|
746 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
747 |
+
outputs = self.model(
|
748 |
+
input_ids=input_ids,
|
749 |
+
attention_mask=attention_mask,
|
750 |
+
position_ids=position_ids,
|
751 |
+
past_key_values=past_key_values,
|
752 |
+
inputs_embeds=inputs_embeds,
|
753 |
+
use_cache=use_cache,
|
754 |
+
output_attentions=output_attentions,
|
755 |
+
output_hidden_states=output_hidden_states,
|
756 |
+
return_dict=return_dict,
|
757 |
+
)
|
758 |
+
hidden_states = outputs[0]
|
759 |
+
logits = self.lm_head(hidden_states)
|
760 |
+
|
761 |
+
loss = None
|
762 |
+
if labels is not None:
|
763 |
+
# Shift so that tokens < n predict n
|
764 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
765 |
+
shift_labels = labels[..., 1:].contiguous()
|
766 |
+
# Flatten the tokens
|
767 |
+
loss_fct = CrossEntropyLoss()
|
768 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
769 |
+
shift_labels = shift_labels.view(-1)
|
770 |
+
# Enable model parallelism
|
771 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
772 |
+
loss = loss_fct(shift_logits, shift_labels)
|
773 |
+
|
774 |
+
if not return_dict:
|
775 |
+
output = (logits,) + outputs[1:]
|
776 |
+
return (loss,) + output if loss is not None else output
|
777 |
+
|
778 |
+
self.past_key_values = outputs.past_key_values
|
779 |
+
|
780 |
+
if self.use_kv_cache_head and not self.training:
|
781 |
+
if self.kv_cache_head is None:
|
782 |
+
self.kv_cache_head = input_ids.shape[1]
|
783 |
+
else:
|
784 |
+
self.kv_cache_head += input_ids.shape[1]
|
785 |
+
|
786 |
+
# new_position_ids = torch.ones((1, 1), device=self.position_ids.device) * (self.position_ids[0, -1].item() + 1)
|
787 |
+
# self.position_ids = torch.cat((self.position_ids, new_position_ids), dim=1)
|
788 |
+
return CausalLMOutputWithPast(
|
789 |
+
loss=loss,
|
790 |
+
logits=logits,
|
791 |
+
past_key_values=outputs.past_key_values,
|
792 |
+
hidden_states=outputs.hidden_states,
|
793 |
+
attentions=outputs.attentions,
|
794 |
+
)
|
795 |
+
|
796 |
+
def prepare_inputs_for_generation(
|
797 |
+
self,
|
798 |
+
input_ids,
|
799 |
+
past_key_values=None,
|
800 |
+
attention_mask=None,
|
801 |
+
inputs_embeds=None,
|
802 |
+
**kwargs,
|
803 |
+
):
|
804 |
+
if self.use_kv_cache_head and not self.training:
|
805 |
+
if past_key_values:
|
806 |
+
input_ids = input_ids[:, self.kv_cache_head:]
|
807 |
+
if inputs_embeds is not None:
|
808 |
+
inputs_embeds = inputs_embeds[:, self.kv_cache_head:]
|
809 |
+
|
810 |
+
position_ids = kwargs.get("position_ids", None)
|
811 |
+
if attention_mask is not None and position_ids is None:
|
812 |
+
# create position_ids on the fly for batch generation
|
813 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
814 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
815 |
+
if past_key_values:
|
816 |
+
position_ids = position_ids[:, self.kv_cache_head:].unsqueeze(-1)
|
817 |
+
|
818 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
819 |
+
if inputs_embeds is not None and past_key_values is None:
|
820 |
+
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
|
821 |
+
elif past_key_values is not None and input_ids.shape[1] > 1:
|
822 |
+
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
|
823 |
+
else:
|
824 |
+
model_inputs = {"input_ids": input_ids}
|
825 |
+
|
826 |
+
attention_mask = None
|
827 |
+
else:
|
828 |
+
if past_key_values:
|
829 |
+
input_ids = input_ids[:, -1:]
|
830 |
+
|
831 |
+
position_ids = kwargs.get("position_ids", None)
|
832 |
+
if attention_mask is not None and position_ids is None:
|
833 |
+
# create position_ids on the fly for batch generation
|
834 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
835 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
836 |
+
if past_key_values:
|
837 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
838 |
+
|
839 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
840 |
+
if inputs_embeds is not None and past_key_values is None:
|
841 |
+
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
|
842 |
+
else:
|
843 |
+
model_inputs = {"input_ids": input_ids}
|
844 |
+
attention_mask = None
|
845 |
+
|
846 |
+
model_inputs.update({
|
847 |
+
"position_ids": position_ids,
|
848 |
+
"past_key_values": past_key_values,
|
849 |
+
"use_cache": kwargs.get("use_cache"),
|
850 |
+
"attention_mask": attention_mask,
|
851 |
+
})
|
852 |
+
return model_inputs
|
853 |
+
|
854 |
+
@staticmethod
|
855 |
+
def _reorder_cache(past_key_values, beam_idx):
|
856 |
+
reordered_past = ()
|
857 |
+
for layer_past in past_key_values:
|
858 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
859 |
+
return reordered_past
|
860 |
+
|
861 |
+
|
862 |
+
@add_start_docstrings(
|
863 |
+
"""
|
864 |
+
The LLaMa Model transformer with a sequence classification head on top (linear layer).
|
865 |
+
|
866 |
+
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
867 |
+
(e.g. GPT-2) do.
|
868 |
+
|
869 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
870 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
871 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
872 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
873 |
+
each row of the batch).
|
874 |
+
""",
|
875 |
+
LLAMA_START_DOCSTRING,
|
876 |
+
)
|
877 |
+
class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
878 |
+
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
879 |
+
|
880 |
+
def __init__(self, config):
|
881 |
+
super().__init__(config)
|
882 |
+
self.num_labels = config.num_labels
|
883 |
+
self.model = LlamaModel(config)
|
884 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
885 |
+
|
886 |
+
# Initialize weights and apply final processing
|
887 |
+
self.post_init()
|
888 |
+
|
889 |
+
def get_input_embeddings(self):
|
890 |
+
return self.model.embed_tokens
|
891 |
+
|
892 |
+
def set_input_embeddings(self, value):
|
893 |
+
self.model.embed_tokens = value
|
894 |
+
|
895 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
896 |
+
def forward(
|
897 |
+
self,
|
898 |
+
input_ids: torch.LongTensor = None,
|
899 |
+
attention_mask: Optional[torch.Tensor] = None,
|
900 |
+
position_ids: Optional[torch.LongTensor] = None,
|
901 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
902 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
903 |
+
labels: Optional[torch.LongTensor] = None,
|
904 |
+
use_cache: Optional[bool] = None,
|
905 |
+
output_attentions: Optional[bool] = None,
|
906 |
+
output_hidden_states: Optional[bool] = None,
|
907 |
+
return_dict: Optional[bool] = None,
|
908 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
909 |
+
r"""
|
910 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
911 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
912 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
913 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
914 |
+
"""
|
915 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
916 |
+
|
917 |
+
transformer_outputs = self.model(
|
918 |
+
input_ids,
|
919 |
+
attention_mask=attention_mask,
|
920 |
+
position_ids=position_ids,
|
921 |
+
past_key_values=past_key_values,
|
922 |
+
inputs_embeds=inputs_embeds,
|
923 |
+
use_cache=use_cache,
|
924 |
+
output_attentions=output_attentions,
|
925 |
+
output_hidden_states=output_hidden_states,
|
926 |
+
return_dict=return_dict,
|
927 |
+
)
|
928 |
+
hidden_states = transformer_outputs[0]
|
929 |
+
logits = self.score(hidden_states)
|
930 |
+
|
931 |
+
if input_ids is not None:
|
932 |
+
batch_size = input_ids.shape[0]
|
933 |
+
else:
|
934 |
+
batch_size = inputs_embeds.shape[0]
|
935 |
+
|
936 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
937 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
938 |
+
if self.config.pad_token_id is None:
|
939 |
+
sequence_lengths = -1
|
940 |
+
else:
|
941 |
+
if input_ids is not None:
|
942 |
+
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
|
943 |
+
else:
|
944 |
+
sequence_lengths = -1
|
945 |
+
|
946 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
947 |
+
|
948 |
+
loss = None
|
949 |
+
if labels is not None:
|
950 |
+
labels = labels.to(logits.device)
|
951 |
+
if self.config.problem_type is None:
|
952 |
+
if self.num_labels == 1:
|
953 |
+
self.config.problem_type = "regression"
|
954 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
955 |
+
self.config.problem_type = "single_label_classification"
|
956 |
+
else:
|
957 |
+
self.config.problem_type = "multi_label_classification"
|
958 |
+
|
959 |
+
if self.config.problem_type == "regression":
|
960 |
+
loss_fct = MSELoss()
|
961 |
+
if self.num_labels == 1:
|
962 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
963 |
+
else:
|
964 |
+
loss = loss_fct(pooled_logits, labels)
|
965 |
+
elif self.config.problem_type == "single_label_classification":
|
966 |
+
loss_fct = CrossEntropyLoss()
|
967 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
968 |
+
elif self.config.problem_type == "multi_label_classification":
|
969 |
+
loss_fct = BCEWithLogitsLoss()
|
970 |
+
loss = loss_fct(pooled_logits, labels)
|
971 |
+
if not return_dict:
|
972 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
973 |
+
return ((loss,) + output) if loss is not None else output
|
974 |
+
|
975 |
+
return SequenceClassifierOutputWithPast(
|
976 |
+
loss=loss,
|
977 |
+
logits=pooled_logits,
|
978 |
+
past_key_values=transformer_outputs.past_key_values,
|
979 |
+
hidden_states=transformer_outputs.hidden_states,
|
980 |
+
attentions=transformer_outputs.attentions,
|
981 |
+
)
|
982 |
+
|
983 |
+
|
984 |
+
if __name__ == "__main__":
|
985 |
+
from transformers import LlamaTokenizer
|
986 |
+
|
987 |
+
model = LlamaForCausalLM.from_pretrained("luodian/llama-7b-hf", device_map="auto")
|
988 |
+
tokenizer = LlamaTokenizer.from_pretrained("luodian/llama-7b-hf")
|
989 |
+
prompt = "Hey, are you consciours? Can you talk to me?"
|
990 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
991 |
+
generate_ids = model.generate(inputs.input_ids, max_length=30)
|
992 |
+
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
|
src/models_clm/models.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import LlamaForCausalLM, LlamaConfig
|
4 |
+
from transformers import LogitsProcessor, LogitsProcessorList
|
5 |
+
from .generation import AutoImageTokenGenerationProcessor
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
BOI_TOKEN = '<img>'
|
9 |
+
EOI_TOKEN = '</img>'
|
10 |
+
IMG_TOKEN = '<img_{:05d}>'
|
11 |
+
|
12 |
+
|
13 |
+
def cosine_loss(rec, target):
|
14 |
+
target = target / target.norm(dim=-1, keepdim=True)
|
15 |
+
rec = rec / rec.norm(dim=-1, keepdim=True)
|
16 |
+
rec_loss = (1 - (target * rec).sum(-1)).mean()
|
17 |
+
return rec_loss
|
18 |
+
|
19 |
+
|
20 |
+
class ContinuousLVLM(nn.Module):
|
21 |
+
|
22 |
+
def __init__(self, llm, input_resampler, output_resampler, lm_loss_scale=1.0, rec_loss_scale=1.0) -> None:
|
23 |
+
super().__init__()
|
24 |
+
self.llm = llm
|
25 |
+
self.input_resampler = input_resampler
|
26 |
+
self.output_resampler = output_resampler
|
27 |
+
self.lm_loss_scale = lm_loss_scale
|
28 |
+
self.rec_loss_scale = rec_loss_scale
|
29 |
+
|
30 |
+
# input_resampler.requires_grad_(False)
|
31 |
+
# output_resampler.requires_grad_(False)
|
32 |
+
|
33 |
+
def forward(self, input_ids, attention_mask, labels, image_embeds, embeds_gen_mask, embeds_cmp_mask, ids_gen_mask,
|
34 |
+
ids_cmp_mask, return_recon_image_embeds=False):
|
35 |
+
|
36 |
+
input_embeds = self.llm.get_input_embeddings()(input_ids) # bz x seq_len x dim, 4 x 160 x 4096
|
37 |
+
|
38 |
+
bz, sq, dim = input_embeds.shape
|
39 |
+
|
40 |
+
if image_embeds is not None:
|
41 |
+
image_embeds_lm = self.input_resampler(image_embeds) # num_imgs_in_batch x nq x dim, 4 x 64 x 4096
|
42 |
+
has_image = True
|
43 |
+
else:
|
44 |
+
image_embeds = torch.randn(bz, self.output_resampler.num_queries,
|
45 |
+
self.output_resampler.embed_dim).to(input_embeds.device,
|
46 |
+
dtype=input_embeds.dtype)
|
47 |
+
image_embeds_lm = self.input_resampler(image_embeds)
|
48 |
+
has_image = False
|
49 |
+
|
50 |
+
has_image_input = has_image and embeds_cmp_mask.sum().item() > 0
|
51 |
+
has_image_output = has_image and embeds_gen_mask.sum().item() > 0
|
52 |
+
|
53 |
+
if has_image_input:
|
54 |
+
input_embeds[ids_cmp_mask] = image_embeds_lm[embeds_cmp_mask].view(-1, dim) # eg, 128 x 4096
|
55 |
+
# zero_loss = 0.0
|
56 |
+
else:
|
57 |
+
min_bz = min(input_embeds.shape[0], image_embeds_lm.shape[0])
|
58 |
+
input_embeds[:min_bz, :self.input_resampler.
|
59 |
+
num_queries, :] = input_embeds[:min_bz, :self.input_resampler.
|
60 |
+
num_queries, :] + 0.0 * image_embeds_lm[:min_bz, :, :]
|
61 |
+
|
62 |
+
output_lm = self.llm(attention_mask=attention_mask,
|
63 |
+
inputs_embeds=input_embeds,
|
64 |
+
labels=labels,
|
65 |
+
output_hidden_states=True,
|
66 |
+
return_dict=True)
|
67 |
+
lm_loss = output_lm['loss']
|
68 |
+
|
69 |
+
last_hidden_state = output_lm.hidden_states[-1] # 4 x 160 x 4096
|
70 |
+
|
71 |
+
if has_image_output:
|
72 |
+
target_embeds = image_embeds[embeds_gen_mask] # num_imgs_gen_target x nq_in x dim_in, 2 x 256 x 4096
|
73 |
+
num_imgs_for_rec = target_embeds.shape[0]
|
74 |
+
output_image_embeds = last_hidden_state[ids_gen_mask].view(num_imgs_for_rec, -1,
|
75 |
+
dim) # 128 x 4096 -> 2 x 64 x 4096
|
76 |
+
|
77 |
+
recon_image_embeds = self.output_resampler(output_image_embeds) # 2 x 256 x 4096
|
78 |
+
|
79 |
+
rec_loss = cosine_loss(recon_image_embeds, target_embeds)
|
80 |
+
else:
|
81 |
+
output_image_embeds = torch.randn(bz, self.input_resampler.num_queries,
|
82 |
+
self.input_resampler.embed_dim).to(input_embeds.device,
|
83 |
+
dtype=input_embeds.dtype)
|
84 |
+
recon_image_embeds = self.output_resampler(output_image_embeds)
|
85 |
+
target_embeds = torch.randn(bz, self.output_resampler.num_queries,
|
86 |
+
self.output_resampler.embed_dim).to(input_embeds.device,
|
87 |
+
dtype=input_embeds.dtype)
|
88 |
+
rec_loss = cosine_loss(recon_image_embeds, target_embeds) * 0.0
|
89 |
+
|
90 |
+
total_loss = self.lm_loss_scale * lm_loss + self.rec_loss_scale * rec_loss
|
91 |
+
|
92 |
+
if return_recon_image_embeds and has_image_output:
|
93 |
+
return {'total_loss': total_loss, 'lm_loss': lm_loss, 'rec_loss': rec_loss,
|
94 |
+
'recon_image_embeds': recon_image_embeds}
|
95 |
+
else:
|
96 |
+
return {'total_loss': total_loss, 'lm_loss': lm_loss, 'rec_loss': rec_loss}
|
97 |
+
|
98 |
+
def generate(self,
|
99 |
+
tokenizer,
|
100 |
+
prompt=None,
|
101 |
+
input_ids=None,
|
102 |
+
image_embeds=None,
|
103 |
+
embeds_cmp_mask=None,
|
104 |
+
ids_cmp_mask=None,
|
105 |
+
logits_processor=None,
|
106 |
+
num_img_gen_tokens=64,
|
107 |
+
temperature=0.7,
|
108 |
+
num_beams=1,
|
109 |
+
max_new_tokens=120,
|
110 |
+
top_p=0.5,
|
111 |
+
past_key_values=None,
|
112 |
+
# position_ids=None,
|
113 |
+
dtype=torch.float16,
|
114 |
+
device='cuda'):
|
115 |
+
if logits_processor is None:
|
116 |
+
logits_processor = LogitsProcessorList()
|
117 |
+
logits_processor.append(
|
118 |
+
AutoImageTokenGenerationProcessor(tokenizer=tokenizer, num_img_gen_tokens=num_img_gen_tokens))
|
119 |
+
|
120 |
+
if prompt is not None:
|
121 |
+
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
122 |
+
|
123 |
+
if isinstance(input_ids, list):
|
124 |
+
input_ids = torch.tensor(input_ids)
|
125 |
+
|
126 |
+
input_ids = input_ids.to(device=device)
|
127 |
+
input_embeds = self.llm.get_input_embeddings()(input_ids)
|
128 |
+
bz, sq, dim = input_embeds.shape
|
129 |
+
|
130 |
+
if image_embeds is not None:
|
131 |
+
assert embeds_cmp_mask is not None and ids_cmp_mask is not None
|
132 |
+
with torch.no_grad():
|
133 |
+
image_embeds_lm = self.input_resampler(image_embeds)
|
134 |
+
|
135 |
+
input_embeds[ids_cmp_mask] = image_embeds_lm[embeds_cmp_mask].view(-1, dim)
|
136 |
+
|
137 |
+
generation_config = {
|
138 |
+
'temperature': temperature,
|
139 |
+
'num_beams': num_beams,
|
140 |
+
'max_new_tokens': max_new_tokens,
|
141 |
+
'top_p': top_p,
|
142 |
+
'do_sample': False
|
143 |
+
}
|
144 |
+
|
145 |
+
# generate_ids = self.llm.generate(input_ids=input_ids, **generation_config)
|
146 |
+
output = self.llm.generate(input_ids=input_ids,
|
147 |
+
inputs_embeds=input_embeds,
|
148 |
+
output_hidden_states=True,
|
149 |
+
return_dict_in_generate=True,
|
150 |
+
logits_processor=logits_processor,
|
151 |
+
past_key_values=past_key_values,
|
152 |
+
# position_ids=position_ids,
|
153 |
+
**generation_config)
|
154 |
+
# self.llm.base_model.model.position_ids = self.llm.base_model.model.position_ids[:, :-2]
|
155 |
+
|
156 |
+
output_past_key_values = self.llm.past_key_values
|
157 |
+
generate_ids = output.sequences[0][input_ids.shape[1]:]
|
158 |
+
generate_id_list = generate_ids.tolist()
|
159 |
+
boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
|
160 |
+
eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
|
161 |
+
|
162 |
+
attn_weights = ()
|
163 |
+
|
164 |
+
def merge_attn_weights(attn_weights):
|
165 |
+
merged_attn_weights = attn_weights[0]
|
166 |
+
|
167 |
+
# Iterate through the remaining attention weight tensors
|
168 |
+
for i, attn_weight in enumerate(attn_weights[1:]):
|
169 |
+
merged_attn_weights = F.pad(merged_attn_weights, (0, 1), "constant", float('nan'))
|
170 |
+
# Concatenate the expanded tensor to the merged tensor along the kv_len dimension
|
171 |
+
merged_attn_weights = torch.cat([merged_attn_weights, attn_weight], dim=1)
|
172 |
+
|
173 |
+
return merged_attn_weights
|
174 |
+
|
175 |
+
if output.attentions is not None:
|
176 |
+
# for idx in [0, 1, 2, 9, 16, 23, 31]:
|
177 |
+
for idx in range(32):
|
178 |
+
attn_weights += (
|
179 |
+
merge_attn_weights([output.attentions[j][idx] for j in range(len(output.attentions))]),)
|
180 |
+
|
181 |
+
# for skip image multi turn kvcache
|
182 |
+
last_hidden_states = torch.cat([hidden_state[-1] for hidden_state in output.hidden_states], dim=1)
|
183 |
+
if past_key_values is None:
|
184 |
+
last_hidden_states = last_hidden_states[0, input_ids.shape[1]:, :]
|
185 |
+
eoi_indices = torch.where(generate_ids == eoi_token_id)[0].tolist()
|
186 |
+
else:
|
187 |
+
last_hidden_states = last_hidden_states[0, :, :]
|
188 |
+
hidden_len = last_hidden_states.shape[0]
|
189 |
+
eoi_indices = torch.where(output.sequences[0][-hidden_len:] == eoi_token_id)[0].tolist()
|
190 |
+
|
191 |
+
num_gen_imgs = 1 if len(eoi_indices) > 0 else 0
|
192 |
+
|
193 |
+
text_mask = torch.ones_like(generate_ids, dtype=torch.bool)
|
194 |
+
has_img_output = num_gen_imgs > 0
|
195 |
+
if has_img_output:
|
196 |
+
img_gen_feats = []
|
197 |
+
img_gen_feats.append(last_hidden_states[eoi_indices[-1] - num_img_gen_tokens:eoi_indices[-1]])
|
198 |
+
text_mask[eoi_indices[-1] - num_img_gen_tokens:eoi_indices[-1]] = False
|
199 |
+
|
200 |
+
# for eoi_idx in eoi_indices:
|
201 |
+
# img_gen_feats.append(last_hidden_states[eoi_idx - num_img_gen_tokens:eoi_idx])
|
202 |
+
# text_mask[eoi_idx - num_img_gen_tokens:eoi_idx] = False
|
203 |
+
|
204 |
+
img_gen_feats = torch.stack(img_gen_feats)
|
205 |
+
img_gen_feat = self.output_resampler(img_gen_feats)
|
206 |
+
else:
|
207 |
+
img_gen_feat = None
|
208 |
+
|
209 |
+
text_mask[generate_ids == boi_token_id] = False
|
210 |
+
# generate_ids = generate_ids[text_mask]
|
211 |
+
generate_text = tokenizer.decode(generate_ids, skip_special_tokens=False)
|
212 |
+
|
213 |
+
return {
|
214 |
+
'text': generate_text,
|
215 |
+
'generate_ids': generate_ids,
|
216 |
+
'has_img_output': has_img_output,
|
217 |
+
'img_gen_feat': img_gen_feat,
|
218 |
+
'num_gen_imgs': num_gen_imgs,
|
219 |
+
'attn_weights': attn_weights,
|
220 |
+
'past_key_values': output_past_key_values
|
221 |
+
}
|
222 |
+
|
223 |
+
@classmethod
|
224 |
+
def from_pretrained(cls, llm, input_resampler, output_resampler, pretrained_model_path=None, **kwargs):
|
225 |
+
model = cls(llm=llm, input_resampler=input_resampler, output_resampler=output_resampler, **kwargs)
|
226 |
+
if pretrained_model_path is not None:
|
227 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
228 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
229 |
+
print('agent model, missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
230 |
+
return model
|
231 |
+
|
232 |
+
|
233 |
+
class SEEDLLaMAAlignGeneration(nn.Module):
|
234 |
+
|
235 |
+
def __init__(self, llm, output_resampler) -> None:
|
236 |
+
super().__init__()
|
237 |
+
|
238 |
+
self.llm = llm
|
239 |
+
self.output_resampler = output_resampler
|
240 |
+
# self.rec_loss_scale = rec_loss_scale
|
241 |
+
|
242 |
+
self.llm.requires_grad_(False)
|
243 |
+
|
244 |
+
def forward(self, input_ids, attention_mask, labels, image_embeds, embeds_gen_mask, embeds_cmp_mask, ids_gen_mask,
|
245 |
+
ids_cmp_mask):
|
246 |
+
|
247 |
+
input_embeds = self.llm.get_input_embeddings()(input_ids) # bz x seq_len x dim, 4 x 160 x 4096
|
248 |
+
|
249 |
+
bz, sq, dim = input_embeds.shape
|
250 |
+
|
251 |
+
output_lm = self.llm(attention_mask=attention_mask,
|
252 |
+
inputs_embeds=input_embeds,
|
253 |
+
labels=labels,
|
254 |
+
output_hidden_states=True,
|
255 |
+
return_dict=True)
|
256 |
+
|
257 |
+
last_hidden_state = output_lm.hidden_states[-1] # 4 x 160 x 4096
|
258 |
+
|
259 |
+
target_embeds = image_embeds[embeds_gen_mask] # num_imgs_gen_target x nq_in x dim_in, 2 x 256 x 4096
|
260 |
+
num_imgs_for_rec = target_embeds.shape[0]
|
261 |
+
output_image_embeds = last_hidden_state[ids_gen_mask].view(num_imgs_for_rec, -1,
|
262 |
+
dim) # 128 x 4096 -> 2 x 64 x 4096
|
263 |
+
|
264 |
+
recon_image_embeds = self.output_resampler(output_image_embeds) # 2 x 256 x 4096
|
265 |
+
|
266 |
+
rec_loss = cosine_loss(recon_image_embeds, target_embeds)
|
267 |
+
|
268 |
+
return {'total_loss': rec_loss, 'rec_loss': rec_loss}
|
269 |
+
|
270 |
+
@classmethod
|
271 |
+
def from_pretrained(cls, llm, output_resampler, pretrained_model_path=None, **kwargs):
|
272 |
+
model = cls(llm=llm, output_resampler=output_resampler, **kwargs)
|
273 |
+
if pretrained_model_path is not None:
|
274 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
275 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
276 |
+
print('agent model, missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
277 |
+
return model
|
278 |
+
|
279 |
+
def generate(self,
|
280 |
+
tokenizer,
|
281 |
+
input_ids=None,
|
282 |
+
temperature=0.7,
|
283 |
+
num_beams=1,
|
284 |
+
max_new_tokens=120,
|
285 |
+
num_img_gen_tokens=64,
|
286 |
+
top_p=0.5,
|
287 |
+
dtype=torch.float16,
|
288 |
+
device='cuda'):
|
289 |
+
input_ids = input_ids.to(device=device)
|
290 |
+
input_embeds = self.llm.get_input_embeddings()(input_ids) # bz x seq_len x dim, 4 x 160 x 4096
|
291 |
+
|
292 |
+
generation_config = {
|
293 |
+
'temperature': temperature,
|
294 |
+
'num_beams': num_beams,
|
295 |
+
'max_new_tokens': max_new_tokens,
|
296 |
+
'top_p': top_p,
|
297 |
+
'do_sample': False
|
298 |
+
}
|
299 |
+
output = self.llm.generate(input_ids=input_ids,
|
300 |
+
inputs_embeds=input_embeds,
|
301 |
+
output_hidden_states=True,
|
302 |
+
return_dict_in_generate=True,
|
303 |
+
**generation_config)
|
304 |
+
|
305 |
+
generate_ids = output.sequences[0][input_ids.shape[1]:]
|
306 |
+
generate_id_list = generate_ids.tolist()
|
307 |
+
# boi_token_id = tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
|
308 |
+
eoi_token_id = tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
|
309 |
+
|
310 |
+
# print('output ids: ', generate_ids, generate_ids.shape)
|
311 |
+
# last_hidden_states = output.hidden_states[-1]
|
312 |
+
|
313 |
+
last_hidden_states = torch.cat([hidden_state[-1] for hidden_state in output.hidden_states],
|
314 |
+
dim=1)[:1, input_ids.shape[1]:, :]
|
315 |
+
|
316 |
+
has_img_output = eoi_token_id in generate_id_list
|
317 |
+
|
318 |
+
if has_img_output:
|
319 |
+
# print(boi_token_id, generate_id_list, generate_id_list.index(boi_token_id))
|
320 |
+
# boi_idx = generate_id_list.index(boi_token_id)
|
321 |
+
eoi_idx = generate_id_list.index(eoi_token_id)
|
322 |
+
print(len(generate_id_list), generate_id_list, eoi_idx)
|
323 |
+
# print(generate_id_list[boi_idx + 1:boi_idx + 1 + num_img_gen_tokens])
|
324 |
+
|
325 |
+
# img_gen_feat = last_hidden_states[:, eoi_idx - num_img_gen_tokens:eoi_idx]
|
326 |
+
img_gen_feat = last_hidden_states[:, 0:eoi_idx]
|
327 |
+
print('img_gen_feat', img_gen_feat.shape, last_hidden_states.shape, num_img_gen_tokens)
|
328 |
+
img_gen_feat = self.output_resampler(img_gen_feat)
|
329 |
+
|
330 |
+
else:
|
331 |
+
img_gen_feat = None
|
332 |
+
|
333 |
+
generate_text = tokenizer.decode(generate_ids, skip_special_tokens=False)
|
334 |
+
# print('output keys: ', output.keys())
|
335 |
+
|
336 |
+
return {'text': generate_text, 'has_img_output': has_img_output, 'img_gen_feat': img_gen_feat}
|
src/models_clm/peft_models.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from peft import (
|
2 |
+
LoraConfig,
|
3 |
+
PeftModel,
|
4 |
+
LoraModel,
|
5 |
+
PeftModelForCausalLM,
|
6 |
+
get_peft_model,
|
7 |
+
get_peft_model_state_dict,
|
8 |
+
prepare_model_for_int8_training,
|
9 |
+
set_peft_model_state_dict,
|
10 |
+
)
|
11 |
+
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
|
12 |
+
from peft.utils import _set_trainable, PromptLearningConfig
|
13 |
+
from peft.utils import PeftConfig
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from transformers import LlamaForCausalLM
|
17 |
+
from omegaconf import DictConfig
|
18 |
+
import hydra
|
19 |
+
|
20 |
+
|
21 |
+
def get_peft_model_with_resize_embedding(
|
22 |
+
model,
|
23 |
+
peft_config=None,
|
24 |
+
model_id=None,
|
25 |
+
vocab_size=None,
|
26 |
+
torch_dtype='bf16'
|
27 |
+
):
|
28 |
+
if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
|
29 |
+
torch_dtype = torch.bfloat16
|
30 |
+
elif torch_dtype == 'fp16' or torch_dtype == 'float16':
|
31 |
+
torch_dtype = torch.float16
|
32 |
+
else:
|
33 |
+
torch_dtype = torch.float32
|
34 |
+
|
35 |
+
if isinstance(model, DictConfig):
|
36 |
+
model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)
|
37 |
+
|
38 |
+
# model.gradient_checkpointing_enable()
|
39 |
+
|
40 |
+
assert (peft_config is None) + (model_id is None) == 1
|
41 |
+
|
42 |
+
# print(type(peft_config.target_modules))
|
43 |
+
if vocab_size is not None:
|
44 |
+
print(f'Length of tokenizer and resize embedding: {vocab_size}')
|
45 |
+
model.resize_token_embeddings(vocab_size)
|
46 |
+
|
47 |
+
if peft_config is not None:
|
48 |
+
print('peft config: ', peft_config)
|
49 |
+
peft_model = get_peft_model(model=model, peft_config=peft_config)
|
50 |
+
peft_model.get_input_embeddings().requires_grad_(True)
|
51 |
+
peft_model.get_output_embeddings().requires_grad_(True)
|
52 |
+
|
53 |
+
peft_model.print_trainable_parameters()
|
54 |
+
|
55 |
+
# param_count = 0
|
56 |
+
# if peft_model.modules_to_save is not None:
|
57 |
+
# for name, param in peft_model.named_parameters():
|
58 |
+
# if any(module_name in name for module_name in peft_model.modules_to_save):
|
59 |
+
# param_count += param.numel()
|
60 |
+
# print(name, param.numel())
|
61 |
+
|
62 |
+
else:
|
63 |
+
peft_model = PeftModel.from_pretrained(model=model, model_id=model_id)
|
64 |
+
|
65 |
+
return peft_model
|
66 |
+
|
67 |
+
|
68 |
+
def get_model_with_resize_embedding(model, vocab_size=None, torch_dtype='bf16'):
|
69 |
+
if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
|
70 |
+
torch_dtype = torch.bfloat16
|
71 |
+
elif torch_dtype == 'fp16' or torch_dtype == 'float16':
|
72 |
+
torch_dtype = torch.float16
|
73 |
+
else:
|
74 |
+
torch_dtype = torch.float32
|
75 |
+
|
76 |
+
if isinstance(model, DictConfig):
|
77 |
+
model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)
|
78 |
+
|
79 |
+
model.requires_grad_(False)
|
80 |
+
if vocab_size is not None:
|
81 |
+
print(f'Length of tokenizer and resize embedding: {vocab_size}')
|
82 |
+
model.resize_token_embeddings(vocab_size)
|
83 |
+
model.get_input_embeddings().requires_grad_(True)
|
84 |
+
model.get_output_embeddings().requires_grad_(True)
|
85 |
+
|
86 |
+
return model
|
87 |
+
|
88 |
+
|
89 |
+
def get_full_model_with_resize_embedding(model, vocab_size=None, torch_dtype='bf16'):
|
90 |
+
if torch_dtype == 'bf16' or torch_dtype == 'bfloat16':
|
91 |
+
torch_dtype = torch.bfloat16
|
92 |
+
elif torch_dtype == 'fp16' or torch_dtype == 'float16':
|
93 |
+
torch_dtype = torch.float16
|
94 |
+
else:
|
95 |
+
torch_dtype = torch.float32
|
96 |
+
|
97 |
+
if isinstance(model, DictConfig):
|
98 |
+
model = hydra.utils.instantiate(model, torch_dtype=torch_dtype)
|
99 |
+
|
100 |
+
if vocab_size is not None:
|
101 |
+
print(f'Length of tokenizer and resize embedding: {vocab_size}')
|
102 |
+
model.resize_token_embeddings(vocab_size)
|
103 |
+
|
104 |
+
return model
|
src/models_ipa/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
src/models_ipa/adapter_modules.py
ADDED
@@ -0,0 +1,920 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import itertools
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from typing import List
|
6 |
+
from diffusers import (
|
7 |
+
StableDiffusionPipeline,
|
8 |
+
StableDiffusionXLPipeline,
|
9 |
+
StableDiffusionXLInstructPix2PixPipeline,
|
10 |
+
StableDiffusionInstructPix2PixPipeline,
|
11 |
+
)
|
12 |
+
from PIL import Image
|
13 |
+
from .ipa_utils import is_torch2_available
|
14 |
+
|
15 |
+
if is_torch2_available():
|
16 |
+
from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
|
17 |
+
else:
|
18 |
+
from .attention_processor import IPAttnProcessor, AttnProcessor
|
19 |
+
|
20 |
+
from diffusers.loaders import LoraLoaderMixin
|
21 |
+
from diffusers.models.lora import LoRALinearLayer
|
22 |
+
from diffusers.models.unet_2d_blocks import DownBlock2D
|
23 |
+
|
24 |
+
|
25 |
+
# from .pipeline_stable_diffusion_xl_t2i_edit import StableDiffusionXLText2ImageAndEditPipeline
|
26 |
+
# from .pipeline_stable_diffusion_t2i_edit import StableDiffusionText2ImageAndEditPipeline
|
27 |
+
|
28 |
+
|
29 |
+
class IPAdapterSD(nn.Module):
|
30 |
+
|
31 |
+
def __init__(self, unet, resampler) -> None:
|
32 |
+
super().__init__()
|
33 |
+
self.unet = unet
|
34 |
+
self.resampler = resampler
|
35 |
+
self.set_ip_adapter()
|
36 |
+
self.set_trainable()
|
37 |
+
|
38 |
+
def set_ip_adapter(self):
|
39 |
+
attn_procs = {}
|
40 |
+
unet_sd = self.unet.state_dict()
|
41 |
+
for name in self.unet.attn_processors.keys():
|
42 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
|
43 |
+
if name.startswith("mid_block"):
|
44 |
+
hidden_size = self.unet.config.block_out_channels[-1]
|
45 |
+
elif name.startswith("up_blocks"):
|
46 |
+
block_id = int(name[len("up_blocks.")])
|
47 |
+
hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
|
48 |
+
elif name.startswith("down_blocks"):
|
49 |
+
block_id = int(name[len("down_blocks.")])
|
50 |
+
hidden_size = self.unet.config.block_out_channels[block_id]
|
51 |
+
if cross_attention_dim is None:
|
52 |
+
attn_procs[name] = AttnProcessor()
|
53 |
+
else:
|
54 |
+
layer_name = name.split(".processor")[0]
|
55 |
+
weights = {
|
56 |
+
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
57 |
+
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
58 |
+
}
|
59 |
+
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
|
60 |
+
attn_procs[name].load_state_dict(weights)
|
61 |
+
self.unet.set_attn_processor(attn_procs)
|
62 |
+
self.adapter = torch.nn.ModuleList(self.unet.attn_processors.values())
|
63 |
+
|
64 |
+
def set_trainable(self):
|
65 |
+
self.unet.requires_grad_(False)
|
66 |
+
self.resampler.requires_grad_(True)
|
67 |
+
self.adapter.requires_grad_(True)
|
68 |
+
|
69 |
+
def params_to_opt(self):
|
70 |
+
return itertools.chain(self.resampler.parameters(), self.adapter.parameters())
|
71 |
+
|
72 |
+
def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise):
|
73 |
+
|
74 |
+
image_embeds = self.resampler(image_embeds)
|
75 |
+
# image_embeds = image_embeds.to(dtype=text_embeds.dtype)
|
76 |
+
|
77 |
+
text_embeds = torch.cat([text_embeds, image_embeds], dim=1)
|
78 |
+
# Predict the noise residual and compute loss
|
79 |
+
noise_pred = self.unet(noisy_latents, timesteps, text_embeds).sample
|
80 |
+
|
81 |
+
# if noise is not None:
|
82 |
+
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
83 |
+
# else:
|
84 |
+
# loss = torch.tensor(0.0, device=noisy_latents)
|
85 |
+
|
86 |
+
return {'total_loss': loss, 'noise_pred': noise_pred}
|
87 |
+
|
88 |
+
def encode_image_embeds(self, image_embeds):
|
89 |
+
dtype = image_embeds.dtype
|
90 |
+
image_embeds = self.resampler(image_embeds)
|
91 |
+
image_embeds = image_embeds.to(dtype=dtype)
|
92 |
+
return image_embeds
|
93 |
+
|
94 |
+
@classmethod
|
95 |
+
def from_pretrained(cls,
|
96 |
+
unet,
|
97 |
+
resampler,
|
98 |
+
pretrained_model_path=None,
|
99 |
+
pretrained_resampler_path=None,
|
100 |
+
pretrained_adapter_path=None):
|
101 |
+
model = cls(unet=unet, resampler=resampler)
|
102 |
+
if pretrained_model_path is not None:
|
103 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
104 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
105 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
106 |
+
if pretrained_resampler_path is not None:
|
107 |
+
ckpt = torch.load(pretrained_resampler_path, map_location='cpu')
|
108 |
+
missing, unexpected = model.resampler.load_state_dict(ckpt, strict=True)
|
109 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
110 |
+
if pretrained_adapter_path is not None:
|
111 |
+
ckpt = torch.load(pretrained_adapter_path, map_location='cpu')
|
112 |
+
missing, unexpected = model.adapter.load_state_dict(ckpt, strict=True)
|
113 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
114 |
+
return model
|
115 |
+
|
116 |
+
@classmethod
|
117 |
+
def from_pretrained_legacy(cls, unet, resampler, pretrained_model_path=None):
|
118 |
+
model = cls(unet=unet, resampler=resampler)
|
119 |
+
if pretrained_model_path is not None:
|
120 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
121 |
+
ckpt_image_proj = {}
|
122 |
+
ckpt_ip_layers = {}
|
123 |
+
|
124 |
+
for key, value in ckpt.items():
|
125 |
+
if key.startswith('image_proj_model'):
|
126 |
+
new_key = key.replace('image_proj_model.', '')
|
127 |
+
ckpt_image_proj[new_key] = value
|
128 |
+
elif key.startswith('adapter_modules.'):
|
129 |
+
new_key = key.replace('adapter_modules.', '')
|
130 |
+
ckpt_ip_layers[new_key] = value
|
131 |
+
|
132 |
+
missing, unexpected = model.resampler.load_state_dict(ckpt_image_proj, strict=True)
|
133 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
134 |
+
missing, unexpected = model.adapter.load_state_dict(ckpt_ip_layers, strict=True)
|
135 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
136 |
+
|
137 |
+
return model
|
138 |
+
|
139 |
+
|
140 |
+
class IPAdapterSDPipe(nn.Module):
|
141 |
+
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
ip_adapter,
|
145 |
+
discrete_model,
|
146 |
+
vae,
|
147 |
+
visual_encoder,
|
148 |
+
text_encoder,
|
149 |
+
tokenizer,
|
150 |
+
scheduler,
|
151 |
+
image_transform,
|
152 |
+
device,
|
153 |
+
dtype,
|
154 |
+
) -> None:
|
155 |
+
super().__init__()
|
156 |
+
|
157 |
+
self.ip_adapter = ip_adapter
|
158 |
+
self.vae = vae
|
159 |
+
self.visual_encoder = visual_encoder
|
160 |
+
self.text_encoder = text_encoder
|
161 |
+
self.tokenizer = tokenizer
|
162 |
+
self.scheduler = scheduler
|
163 |
+
self.image_transform = image_transform
|
164 |
+
self.discrete_model = discrete_model
|
165 |
+
self.device = device
|
166 |
+
self.dtype = dtype
|
167 |
+
|
168 |
+
self.sd_pipe = StableDiffusionPipeline(vae=vae,
|
169 |
+
text_encoder=text_encoder,
|
170 |
+
tokenizer=tokenizer,
|
171 |
+
unet=ip_adapter.unet,
|
172 |
+
scheduler=scheduler,
|
173 |
+
safety_checker=None,
|
174 |
+
feature_extractor=None,
|
175 |
+
requires_safety_checker=False)
|
176 |
+
|
177 |
+
def set_scale(self, scale):
|
178 |
+
for attn_processor in self.sd_pipe.unet.attn_processors.values():
|
179 |
+
if isinstance(attn_processor, IPAttnProcessor):
|
180 |
+
attn_processor.scale = scale
|
181 |
+
|
182 |
+
@torch.inference_mode()
|
183 |
+
def get_image_embeds(self, image_pil=None, image_tensor=None, return_negative=True):
|
184 |
+
assert int(image_pil is not None) + int(image_tensor is not None) == 1
|
185 |
+
if image_pil is not None:
|
186 |
+
image_tensor = self.image_transform(image_pil).unsqueeze(0).to(self.device, dtype=self.dtype)
|
187 |
+
if return_negative:
|
188 |
+
image_tensor_neg = torch.zeros_like(image_tensor)
|
189 |
+
image_tensor = torch.cat([image_tensor, image_tensor_neg], dim=0)
|
190 |
+
with torch.cuda.amp.autocast(dtype=self.dtype):
|
191 |
+
image_embeds = self.visual_encoder(image_tensor)
|
192 |
+
image_embeds = self.discrete_model.encode_image_embeds(image_embeds)
|
193 |
+
image_embeds = self.ip_adapter.encode_image_embeds(image_embeds)
|
194 |
+
|
195 |
+
if return_negative:
|
196 |
+
# bz = image_embeds.shape[0]
|
197 |
+
# image_embeds_neg = image_embeds[bz // 2:]
|
198 |
+
# image_embeds = image_embeds[0:bz // 2]
|
199 |
+
image_embeds, image_embeds_neg = image_embeds.chunk(2)
|
200 |
+
else:
|
201 |
+
image_embeds_neg = None
|
202 |
+
|
203 |
+
return image_embeds, image_embeds_neg
|
204 |
+
|
205 |
+
def generate(self,
|
206 |
+
image_pil=None,
|
207 |
+
image_tensor=None,
|
208 |
+
prompt=None,
|
209 |
+
negative_prompt=None,
|
210 |
+
scale=1.0,
|
211 |
+
num_samples=1,
|
212 |
+
seed=42,
|
213 |
+
guidance_scale=7.5,
|
214 |
+
num_inference_steps=30,
|
215 |
+
**kwargs):
|
216 |
+
self.set_scale(scale)
|
217 |
+
assert int(image_pil is not None) + int(image_tensor is not None) == 1
|
218 |
+
|
219 |
+
if image_pil is not None:
|
220 |
+
assert isinstance(image_pil, Image.Image)
|
221 |
+
num_prompts = 1
|
222 |
+
else:
|
223 |
+
num_prompts = image_tensor.shape[0]
|
224 |
+
|
225 |
+
if prompt is None:
|
226 |
+
# prompt = "best quality, high quality"
|
227 |
+
prompt = ""
|
228 |
+
if negative_prompt is None:
|
229 |
+
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
230 |
+
|
231 |
+
if not isinstance(prompt, List):
|
232 |
+
prompt = [prompt] * num_prompts
|
233 |
+
if not isinstance(negative_prompt, List):
|
234 |
+
negative_prompt = [negative_prompt] * num_prompts
|
235 |
+
|
236 |
+
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
|
237 |
+
image_pil=image_pil,
|
238 |
+
image_tensor=image_tensor,
|
239 |
+
return_negative=True,
|
240 |
+
)
|
241 |
+
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
242 |
+
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
243 |
+
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
244 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
245 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
246 |
+
|
247 |
+
with torch.inference_mode():
|
248 |
+
prompt_embeds, negative_prompt_embeds = self.sd_pipe.encode_prompt(
|
249 |
+
prompt,
|
250 |
+
device=self.device,
|
251 |
+
num_images_per_prompt=num_samples,
|
252 |
+
do_classifier_free_guidance=True,
|
253 |
+
negative_prompt=negative_prompt,
|
254 |
+
)
|
255 |
+
|
256 |
+
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
257 |
+
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
258 |
+
|
259 |
+
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
260 |
+
images = self.sd_pipe(
|
261 |
+
prompt_embeds=prompt_embeds,
|
262 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
263 |
+
guidance_scale=guidance_scale,
|
264 |
+
num_inference_steps=num_inference_steps,
|
265 |
+
generator=generator,
|
266 |
+
**kwargs,
|
267 |
+
).images
|
268 |
+
|
269 |
+
return images
|
270 |
+
|
271 |
+
|
272 |
+
def compute_time_ids(original_size, crops_coords_top_left, target_resolution):
|
273 |
+
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
274 |
+
target_size = (target_resolution, target_resolution)
|
275 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
276 |
+
add_time_ids = torch.tensor([add_time_ids])
|
277 |
+
# add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
278 |
+
return add_time_ids
|
279 |
+
|
280 |
+
|
281 |
+
class SDXLAdapter(nn.Module):
|
282 |
+
|
283 |
+
def __init__(self, unet, resampler, full_ft=False) -> None:
|
284 |
+
super().__init__()
|
285 |
+
self.unet = unet
|
286 |
+
self.resampler = resampler
|
287 |
+
self.full_ft = full_ft
|
288 |
+
self.set_trainable_v2()
|
289 |
+
# self.set_adapter()
|
290 |
+
|
291 |
+
# self.set_trainable()
|
292 |
+
|
293 |
+
# def set_adapter(self):
|
294 |
+
|
295 |
+
# adapter = []
|
296 |
+
# for name, module in self.unet.named_modules():
|
297 |
+
# if name.endswith('to_k') or name.endswith('to_v'):
|
298 |
+
# if module is not None:
|
299 |
+
# adapter.append(module)
|
300 |
+
|
301 |
+
# self.adapter = torch.nn.ModuleList(adapter)
|
302 |
+
# print(f'adapter: {self.adapter}')
|
303 |
+
|
304 |
+
# def set_trainable(self):
|
305 |
+
# self.unet.requires_grad_(False)
|
306 |
+
# self.resampler.requires_grad_(True)
|
307 |
+
# self.adapter.requires_grad_(True)
|
308 |
+
|
309 |
+
def set_trainable_v2(self):
|
310 |
+
self.resampler.requires_grad_(True)
|
311 |
+
adapter_parameters = []
|
312 |
+
if self.full_ft:
|
313 |
+
self.unet.requires_grad_(True)
|
314 |
+
adapter_parameters.extend(self.unet.parameters())
|
315 |
+
else:
|
316 |
+
self.unet.requires_grad_(False)
|
317 |
+
for name, module in self.unet.named_modules():
|
318 |
+
if name.endswith('to_k') or name.endswith('to_v'):
|
319 |
+
if module is not None:
|
320 |
+
adapter_parameters.extend(module.parameters())
|
321 |
+
self.adapter_parameters = adapter_parameters
|
322 |
+
for param in self.adapter_parameters:
|
323 |
+
param.requires_grad_(True)
|
324 |
+
|
325 |
+
# def params_to_opt(self):
|
326 |
+
# return itertools.chain(self.resampler.parameters(), self.adapter.parameters())
|
327 |
+
def params_to_opt(self):
|
328 |
+
return itertools.chain(self.resampler.parameters(), self.adapter_parameters)
|
329 |
+
|
330 |
+
def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise, time_ids):
|
331 |
+
|
332 |
+
image_embeds, pooled_image_embeds = self.resampler(image_embeds)
|
333 |
+
|
334 |
+
unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_image_embeds}
|
335 |
+
|
336 |
+
noise_pred = self.unet(noisy_latents, timesteps, image_embeds, added_cond_kwargs=unet_added_conditions).sample
|
337 |
+
|
338 |
+
# if noise is not None:
|
339 |
+
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
340 |
+
# else:
|
341 |
+
# loss = torch.tensor(0.0, device=noisy_latents)
|
342 |
+
|
343 |
+
return {'total_loss': loss, 'noise_pred': noise_pred}
|
344 |
+
|
345 |
+
def encode_image_embeds(self, image_embeds):
|
346 |
+
image_embeds, pooled_image_embeds = self.resampler(image_embeds)
|
347 |
+
|
348 |
+
return image_embeds, pooled_image_embeds
|
349 |
+
|
350 |
+
@classmethod
|
351 |
+
def from_pretrained(cls, unet, resampler, pretrained_model_path=None, **kwargs):
|
352 |
+
model = cls(unet=unet, resampler=resampler, **kwargs)
|
353 |
+
if pretrained_model_path is not None:
|
354 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
355 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
356 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
357 |
+
return model
|
358 |
+
|
359 |
+
def init_pipe(self,
|
360 |
+
vae,
|
361 |
+
scheduler,
|
362 |
+
visual_encoder,
|
363 |
+
image_transform,
|
364 |
+
discrete_model=None,
|
365 |
+
dtype=torch.float16,
|
366 |
+
device='cuda'):
|
367 |
+
self.device = device
|
368 |
+
self.dtype = dtype
|
369 |
+
sdxl_pipe = StableDiffusionXLPipeline(tokenizer=None,
|
370 |
+
tokenizer_2=None,
|
371 |
+
text_encoder=None,
|
372 |
+
text_encoder_2=None,
|
373 |
+
vae=vae,
|
374 |
+
unet=self.unet,
|
375 |
+
scheduler=scheduler)
|
376 |
+
|
377 |
+
self.sdxl_pipe = sdxl_pipe # .to(self.device, dtype=self.dtype)
|
378 |
+
# print(sdxl_pipe.text_encoder_2, sdxl_pipe.text_encoder)
|
379 |
+
|
380 |
+
self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype)
|
381 |
+
if discrete_model is not None:
|
382 |
+
self.discrete_model = discrete_model.to(self.device, dtype=self.dtype)
|
383 |
+
else:
|
384 |
+
self.discrete_model = None
|
385 |
+
self.image_transform = image_transform
|
386 |
+
|
387 |
+
@torch.inference_mode()
|
388 |
+
def get_image_embeds(self,
|
389 |
+
image_pil=None,
|
390 |
+
image_tensor=None,
|
391 |
+
image_embeds=None,
|
392 |
+
return_negative=True,
|
393 |
+
image_size=448
|
394 |
+
):
|
395 |
+
assert int(image_pil is not None) + int(image_tensor is not None) + int(image_embeds is not None) == 1
|
396 |
+
|
397 |
+
if image_pil is not None:
|
398 |
+
image_tensor = self.image_transform(image_pil).unsqueeze(0).to(self.device, dtype=self.dtype)
|
399 |
+
|
400 |
+
if image_tensor is not None:
|
401 |
+
if return_negative:
|
402 |
+
image_tensor_neg = torch.zeros_like(image_tensor)
|
403 |
+
image_tensor = torch.cat([image_tensor, image_tensor_neg], dim=0)
|
404 |
+
|
405 |
+
image_embeds = self.visual_encoder(image_tensor)
|
406 |
+
elif return_negative:
|
407 |
+
image_tensor_neg = torch.zeros(
|
408 |
+
1, 3,
|
409 |
+
image_size, image_size
|
410 |
+
).to(
|
411 |
+
image_embeds.device, dtype=image_embeds.dtype
|
412 |
+
)
|
413 |
+
image_embeds_neg = self.visual_encoder(image_tensor_neg)
|
414 |
+
image_embeds = torch.cat([image_embeds, image_embeds_neg], dim=0)
|
415 |
+
|
416 |
+
if self.discrete_model is not None:
|
417 |
+
image_embeds = self.discrete_model.encode_image_embeds(image_embeds)
|
418 |
+
image_embeds, pooled_image_embeds = self.encode_image_embeds(image_embeds)
|
419 |
+
|
420 |
+
if return_negative:
|
421 |
+
image_embeds, image_embeds_neg = image_embeds.chunk(2)
|
422 |
+
pooled_image_embeds, pooled_image_embeds_neg = pooled_image_embeds.chunk(2)
|
423 |
+
|
424 |
+
else:
|
425 |
+
image_embeds_neg = None
|
426 |
+
pooled_image_embeds_neg = None
|
427 |
+
|
428 |
+
return image_embeds, image_embeds_neg, pooled_image_embeds, pooled_image_embeds_neg
|
429 |
+
|
430 |
+
def generate(self,
|
431 |
+
image_pil=None,
|
432 |
+
image_tensor=None,
|
433 |
+
image_embeds=None,
|
434 |
+
seed=42,
|
435 |
+
height=1024,
|
436 |
+
width=1024,
|
437 |
+
guidance_scale=7.5,
|
438 |
+
num_inference_steps=30,
|
439 |
+
input_image_size=448,
|
440 |
+
**kwargs):
|
441 |
+
if image_pil is not None:
|
442 |
+
assert isinstance(image_pil, Image.Image)
|
443 |
+
|
444 |
+
image_prompt_embeds, uncond_image_prompt_embeds, pooled_image_prompt_embeds, \
|
445 |
+
pooled_uncond_image_prompt_embeds = self.get_image_embeds(
|
446 |
+
image_pil=image_pil,
|
447 |
+
image_tensor=image_tensor,
|
448 |
+
image_embeds=image_embeds,
|
449 |
+
return_negative=True,
|
450 |
+
image_size=input_image_size,
|
451 |
+
)
|
452 |
+
# print(image_prompt_embeds.shape, pooled_image_prompt_embeds.shape)
|
453 |
+
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
454 |
+
|
455 |
+
images = self.sdxl_pipe(
|
456 |
+
prompt_embeds=image_prompt_embeds,
|
457 |
+
negative_prompt_embeds=uncond_image_prompt_embeds,
|
458 |
+
pooled_prompt_embeds=pooled_image_prompt_embeds,
|
459 |
+
negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds,
|
460 |
+
guidance_scale=guidance_scale,
|
461 |
+
num_inference_steps=num_inference_steps,
|
462 |
+
generator=generator,
|
463 |
+
height=height,
|
464 |
+
width=width,
|
465 |
+
**kwargs,
|
466 |
+
).images
|
467 |
+
|
468 |
+
return images
|
469 |
+
|
470 |
+
|
471 |
+
class SDXLText2ImageAndEditAdapter(nn.Module):
|
472 |
+
|
473 |
+
def __init__(self, unet, resampler, lora_rank=16, fully_ft=False) -> None:
|
474 |
+
super().__init__()
|
475 |
+
|
476 |
+
self.unet = unet
|
477 |
+
self.resampler = resampler
|
478 |
+
self.lora_rank = lora_rank
|
479 |
+
|
480 |
+
if fully_ft:
|
481 |
+
self.set_fully_trainable()
|
482 |
+
else:
|
483 |
+
self.set_adapter()
|
484 |
+
|
485 |
+
def set_adapter(self):
|
486 |
+
self.unet.requires_grad_(False)
|
487 |
+
adapter_parameters = []
|
488 |
+
|
489 |
+
in_channels = 8
|
490 |
+
out_channels = self.unet.conv_in.out_channels
|
491 |
+
self.unet.register_to_config(in_channels=in_channels)
|
492 |
+
|
493 |
+
with torch.no_grad():
|
494 |
+
new_conv_in = nn.Conv2d(in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride,
|
495 |
+
self.unet.conv_in.padding)
|
496 |
+
|
497 |
+
new_conv_in.weight.zero_()
|
498 |
+
new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
|
499 |
+
self.unet.conv_in = new_conv_in
|
500 |
+
self.unet.conv_in.requires_grad_(True)
|
501 |
+
print('Make conv_in trainable.')
|
502 |
+
adapter_parameters.extend(self.unet.conv_in.parameters())
|
503 |
+
|
504 |
+
for name, module in self.unet.named_modules():
|
505 |
+
if isinstance(module, DownBlock2D):
|
506 |
+
module.requires_grad_(True)
|
507 |
+
adapter_parameters.extend(module.parameters())
|
508 |
+
print('Make DownBlock2D trainable.')
|
509 |
+
|
510 |
+
for attn_processor_name, attn_processor in self.unet.attn_processors.items():
|
511 |
+
# Parse the attention module.
|
512 |
+
attn_module = self.unet
|
513 |
+
for n in attn_processor_name.split(".")[:-1]:
|
514 |
+
attn_module = getattr(attn_module, n)
|
515 |
+
|
516 |
+
# Set the `lora_layer` attribute of the attention-related matrices.
|
517 |
+
attn_module.to_q.set_lora_layer(
|
518 |
+
LoRALinearLayer(in_features=attn_module.to_q.in_features,
|
519 |
+
out_features=attn_module.to_q.out_features,
|
520 |
+
rank=self.lora_rank))
|
521 |
+
# attn_module.to_k.set_lora_layer(
|
522 |
+
# LoRALinearLayer(in_features=attn_module.to_k.in_features,
|
523 |
+
# out_features=attn_module.to_k.out_features,
|
524 |
+
# rank=self.lora_rank))
|
525 |
+
# attn_module.to_v.set_lora_layer(
|
526 |
+
# LoRALinearLayer(in_features=attn_module.to_v.in_features,
|
527 |
+
# out_features=attn_module.to_v.out_features,
|
528 |
+
# rank=self.lora_rank))
|
529 |
+
attn_module.to_out[0].set_lora_layer(
|
530 |
+
LoRALinearLayer(
|
531 |
+
in_features=attn_module.to_out[0].in_features,
|
532 |
+
out_features=attn_module.to_out[0].out_features,
|
533 |
+
rank=self.lora_rank,
|
534 |
+
))
|
535 |
+
|
536 |
+
attn_module.to_k.requires_grad_(True)
|
537 |
+
attn_module.to_v.requires_grad_(True)
|
538 |
+
|
539 |
+
adapter_parameters.extend(attn_module.to_q.lora_layer.parameters())
|
540 |
+
adapter_parameters.extend(attn_module.to_k.parameters())
|
541 |
+
adapter_parameters.extend(attn_module.to_v.parameters())
|
542 |
+
adapter_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
|
543 |
+
|
544 |
+
self.adapter_parameters = adapter_parameters
|
545 |
+
|
546 |
+
def set_fully_trainable(self):
|
547 |
+
|
548 |
+
in_channels = 8
|
549 |
+
out_channels = self.unet.conv_in.out_channels
|
550 |
+
self.unet.register_to_config(in_channels=in_channels)
|
551 |
+
with torch.no_grad():
|
552 |
+
new_conv_in = nn.Conv2d(in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride,
|
553 |
+
self.unet.conv_in.padding)
|
554 |
+
|
555 |
+
new_conv_in.weight.zero_()
|
556 |
+
new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
|
557 |
+
self.unet.conv_in = new_conv_in
|
558 |
+
|
559 |
+
self.unet.requires_grad_(True)
|
560 |
+
self.adapter_parameters = self.unet.parameters()
|
561 |
+
|
562 |
+
def params_to_opt(self):
|
563 |
+
return itertools.chain(self.resampler.parameters(), self.adapter_parameters)
|
564 |
+
|
565 |
+
def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise, time_ids, pooled_text_embeds=None):
|
566 |
+
|
567 |
+
text_embeds, pooled_text_embeds = self.resampler(text_embeds, pooled_text_embeds=pooled_text_embeds)
|
568 |
+
unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_text_embeds}
|
569 |
+
|
570 |
+
noise_pred = self.unet(noisy_latents, timesteps, text_embeds, added_cond_kwargs=unet_added_conditions).sample
|
571 |
+
|
572 |
+
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
573 |
+
return {'total_loss': loss, 'noise_pred': noise_pred}
|
574 |
+
|
575 |
+
def encode_text_embeds(self, text_embeds, pooled_text_embeds=None):
|
576 |
+
text_embeds, pooled_text_embeds = self.resampler(text_embeds, pooled_text_embeds=pooled_text_embeds)
|
577 |
+
|
578 |
+
return text_embeds, pooled_text_embeds
|
579 |
+
|
580 |
+
@classmethod
|
581 |
+
def from_pretrained(cls, unet, resampler, pretrained_model_path=None, **kwargs):
|
582 |
+
model = cls(unet=unet, resampler=resampler, **kwargs)
|
583 |
+
if pretrained_model_path is not None:
|
584 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
585 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
586 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
587 |
+
return model
|
588 |
+
|
589 |
+
def init_pipe(self,
|
590 |
+
vae,
|
591 |
+
scheduler,
|
592 |
+
text_encoder,
|
593 |
+
text_encoder_2,
|
594 |
+
tokenizer,
|
595 |
+
tokenizer_2,
|
596 |
+
dtype=torch.float16,
|
597 |
+
device='cuda'):
|
598 |
+
self.device = device
|
599 |
+
self.dtype = dtype
|
600 |
+
|
601 |
+
sdxl_pipe = StableDiffusionXLText2ImageAndEditPipeline(
|
602 |
+
tokenizer=None,
|
603 |
+
tokenizer_2=None,
|
604 |
+
text_encoder=None,
|
605 |
+
text_encoder_2=None,
|
606 |
+
vae=vae,
|
607 |
+
unet=self.unet,
|
608 |
+
scheduler=scheduler,
|
609 |
+
)
|
610 |
+
|
611 |
+
self.sdxl_pipe = sdxl_pipe
|
612 |
+
self.sdxl_pipe.to(device, dtype=dtype)
|
613 |
+
|
614 |
+
self.tokenizer = tokenizer
|
615 |
+
self.tokenizer_2 = tokenizer_2
|
616 |
+
self.text_encoder = text_encoder
|
617 |
+
self.text_encoder_2 = text_encoder_2
|
618 |
+
|
619 |
+
@torch.inference_mode()
|
620 |
+
def get_text_embeds(self, prompt=None, negative_prompt='', text_embeds=None):
|
621 |
+
assert int(prompt is not None) + int(text_embeds is not None) == 1
|
622 |
+
|
623 |
+
if prompt is not None:
|
624 |
+
text_input_ids = self.tokenizer([prompt, negative_prompt],
|
625 |
+
max_length=self.tokenizer.model_max_length,
|
626 |
+
padding="max_length",
|
627 |
+
truncation=True,
|
628 |
+
return_tensors="pt").input_ids
|
629 |
+
text_input_ids_2 = self.tokenizer_2([prompt, negative_prompt],
|
630 |
+
max_length=self.tokenizer.model_max_length,
|
631 |
+
padding="max_length",
|
632 |
+
truncation=True,
|
633 |
+
return_tensors="pt").input_ids
|
634 |
+
encoder_output = self.text_encoder(text_input_ids.to(self.device), output_hidden_states=True)
|
635 |
+
text_embeds = encoder_output.hidden_states[-2]
|
636 |
+
|
637 |
+
encoder_output_2 = self.text_encoder_2(text_input_ids_2.to(self.device), output_hidden_states=True)
|
638 |
+
pooled_text_embeds = encoder_output_2[0]
|
639 |
+
text_embeds_2 = encoder_output_2.hidden_states[-2]
|
640 |
+
|
641 |
+
text_embeds = torch.cat([text_embeds, text_embeds_2], dim=-1)
|
642 |
+
else:
|
643 |
+
text_input_ids = self.tokenizer(negative_prompt,
|
644 |
+
max_length=self.tokenizer.model_max_length,
|
645 |
+
padding="max_length",
|
646 |
+
truncation=True,
|
647 |
+
return_tensors="pt").input_ids
|
648 |
+
text_input_ids_2 = self.tokenizer_2(negative_prompt,
|
649 |
+
max_length=self.tokenizer.model_max_length,
|
650 |
+
padding="max_length",
|
651 |
+
truncation=True,
|
652 |
+
return_tensors="pt").input_ids
|
653 |
+
encoder_output = self.text_encoder(text_input_ids.to(self.device), output_hidden_states=True)
|
654 |
+
text_embeds_neg = encoder_output.hidden_states[-2]
|
655 |
+
|
656 |
+
encoder_output_2 = self.text_encoder_2(text_input_ids_2.to(self.device), output_hidden_states=True)
|
657 |
+
text_embeds_neg_2 = encoder_output_2.hidden_states[-2]
|
658 |
+
pooled_text_embeds = encoder_output_2[0]
|
659 |
+
|
660 |
+
text_embeds_neg = torch.cat([text_embeds_neg, text_embeds_neg_2], dim=-1)
|
661 |
+
|
662 |
+
text_embeds = torch.cat([text_embeds, text_embeds_neg], dim=0)
|
663 |
+
|
664 |
+
text_embeds, pooled_text_embeds = self.encode_text_embeds(text_embeds, pooled_text_embeds=pooled_text_embeds)
|
665 |
+
text_embeds, text_embeds_neg = text_embeds.chunk(2)
|
666 |
+
pooled_text_embeds, pooled_text_embeds_neg = pooled_text_embeds.chunk(2)
|
667 |
+
|
668 |
+
return text_embeds, text_embeds_neg, pooled_text_embeds, pooled_text_embeds_neg
|
669 |
+
|
670 |
+
def generate(self,
|
671 |
+
prompt=None,
|
672 |
+
negative_prompt='',
|
673 |
+
image=None,
|
674 |
+
text_embeds=None,
|
675 |
+
seed=42,
|
676 |
+
height=1024,
|
677 |
+
width=1024,
|
678 |
+
guidance_scale=7.5,
|
679 |
+
num_inference_steps=30,
|
680 |
+
**kwargs):
|
681 |
+
|
682 |
+
text_embeds, text_embeds_neg, pooled_text_embeds, pooled_text_embeds_neg = self.get_text_embeds(
|
683 |
+
prompt=prompt, negative_prompt=negative_prompt, text_embeds=text_embeds)
|
684 |
+
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
685 |
+
|
686 |
+
images = self.sdxl_pipe(
|
687 |
+
image=image,
|
688 |
+
prompt_embeds=text_embeds,
|
689 |
+
negative_prompt_embeds=text_embeds_neg,
|
690 |
+
pooled_prompt_embeds=pooled_text_embeds,
|
691 |
+
negative_pooled_prompt_embeds=pooled_text_embeds_neg,
|
692 |
+
guidance_scale=guidance_scale,
|
693 |
+
num_inference_steps=num_inference_steps,
|
694 |
+
generator=generator,
|
695 |
+
height=height,
|
696 |
+
width=width,
|
697 |
+
**kwargs,
|
698 |
+
).images
|
699 |
+
|
700 |
+
return images
|
701 |
+
|
702 |
+
|
703 |
+
class SD21Text2ImageAndEditAdapter(SDXLText2ImageAndEditAdapter):
|
704 |
+
|
705 |
+
def forward(self, noisy_latents, timesteps, image_embeds, text_embeds, noise):
|
706 |
+
|
707 |
+
text_embeds, _ = self.resampler(text_embeds)
|
708 |
+
# unet_added_conditions = {"time_ids": time_ids, 'text_embeds': pooled_text_embeds}
|
709 |
+
|
710 |
+
noise_pred = self.unet(noisy_latents, timesteps, text_embeds).sample
|
711 |
+
|
712 |
+
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
713 |
+
return {'total_loss': loss, 'noise_pred': noise_pred}
|
714 |
+
|
715 |
+
def init_pipe(self,
|
716 |
+
vae,
|
717 |
+
scheduler,
|
718 |
+
text_encoder,
|
719 |
+
tokenizer,
|
720 |
+
feature_extractor,
|
721 |
+
dtype=torch.float16,
|
722 |
+
device='cuda'):
|
723 |
+
self.device = device
|
724 |
+
self.dtype = dtype
|
725 |
+
|
726 |
+
sd_pipe = StableDiffusionText2ImageAndEditPipeline(
|
727 |
+
tokenizer=tokenizer,
|
728 |
+
text_encoder=text_encoder,
|
729 |
+
vae=vae,
|
730 |
+
unet=self.unet,
|
731 |
+
feature_extractor=feature_extractor,
|
732 |
+
safety_checker=None,
|
733 |
+
requires_safety_checker=False,
|
734 |
+
scheduler=scheduler,
|
735 |
+
)
|
736 |
+
|
737 |
+
self.sd_pipe = sd_pipe
|
738 |
+
self.sd_pipe.to(device, dtype=dtype)
|
739 |
+
|
740 |
+
self.tokenizer = tokenizer
|
741 |
+
self.text_encoder = text_encoder
|
742 |
+
|
743 |
+
@torch.inference_mode()
|
744 |
+
def get_text_embeds(self, prompt=None, negative_prompt='', text_embeds=None):
|
745 |
+
assert int(prompt is not None) + int(text_embeds is not None) == 1
|
746 |
+
|
747 |
+
if prompt is not None:
|
748 |
+
text_input_ids = self.tokenizer([prompt, negative_prompt],
|
749 |
+
max_length=self.tokenizer.model_max_length,
|
750 |
+
padding="max_length",
|
751 |
+
truncation=True,
|
752 |
+
return_tensors="pt").input_ids
|
753 |
+
encoder_output = self.text_encoder(text_input_ids.to(self.device))
|
754 |
+
text_embeds = encoder_output[0]
|
755 |
+
|
756 |
+
else:
|
757 |
+
text_input_ids = self.tokenizer(negative_prompt,
|
758 |
+
max_length=self.tokenizer.model_max_length,
|
759 |
+
padding="max_length",
|
760 |
+
truncation=True,
|
761 |
+
return_tensors="pt").input_ids
|
762 |
+
encoder_output = self.text_encoder(text_input_ids.to(self.device))
|
763 |
+
text_embeds_neg = encoder_output[0]
|
764 |
+
|
765 |
+
text_embeds = torch.cat([text_embeds, text_embeds_neg], dim=0)
|
766 |
+
|
767 |
+
text_embeds, _ = self.encode_text_embeds(text_embeds)
|
768 |
+
text_embeds, text_embeds_neg = text_embeds.chunk(2)
|
769 |
+
|
770 |
+
return text_embeds, text_embeds_neg
|
771 |
+
|
772 |
+
def generate(self,
|
773 |
+
prompt=None,
|
774 |
+
negative_prompt='',
|
775 |
+
image=None,
|
776 |
+
text_embeds=None,
|
777 |
+
seed=42,
|
778 |
+
height=1024,
|
779 |
+
width=1024,
|
780 |
+
guidance_scale=7.5,
|
781 |
+
num_inference_steps=30,
|
782 |
+
**kwargs):
|
783 |
+
|
784 |
+
text_embeds, text_embeds_neg = self.get_text_embeds(
|
785 |
+
prompt=prompt, negative_prompt=negative_prompt, text_embeds=text_embeds)
|
786 |
+
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
787 |
+
|
788 |
+
print(f'text_embeds: {text_embeds.shape}')
|
789 |
+
print(f'text_embeds_neg: {text_embeds_neg.shape}')
|
790 |
+
images = self.sd_pipe(
|
791 |
+
image=image,
|
792 |
+
prompt_embeds=text_embeds,
|
793 |
+
negative_prompt_embeds=text_embeds_neg,
|
794 |
+
guidance_scale=guidance_scale,
|
795 |
+
num_inference_steps=num_inference_steps,
|
796 |
+
generator=generator,
|
797 |
+
height=height,
|
798 |
+
width=width,
|
799 |
+
**kwargs,
|
800 |
+
).images
|
801 |
+
|
802 |
+
return images
|
803 |
+
|
804 |
+
|
805 |
+
class SDXLAdapterWithLatentImage(SDXLAdapter):
|
806 |
+
def __init__(self, unet, resampler, full_ft=False, set_trainable_late=False) -> None:
|
807 |
+
nn.Module.__init__(self)
|
808 |
+
self.unet = unet
|
809 |
+
self.resampler = resampler
|
810 |
+
self.full_ft = full_ft
|
811 |
+
if not set_trainable_late:
|
812 |
+
self.set_trainable()
|
813 |
+
|
814 |
+
def set_trainable(self):
|
815 |
+
self.resampler.requires_grad_(True)
|
816 |
+
adapter_parameters = []
|
817 |
+
|
818 |
+
in_channels = 8
|
819 |
+
out_channels = self.unet.conv_in.out_channels
|
820 |
+
self.unet.register_to_config(in_channels=in_channels)
|
821 |
+
self.unet.requires_grad_(False)
|
822 |
+
with torch.no_grad():
|
823 |
+
new_conv_in = nn.Conv2d(in_channels, out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride,
|
824 |
+
self.unet.conv_in.padding)
|
825 |
+
|
826 |
+
new_conv_in.weight.zero_()
|
827 |
+
new_conv_in.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
|
828 |
+
self.unet.conv_in = new_conv_in
|
829 |
+
self.unet.conv_in.requires_grad_(True)
|
830 |
+
|
831 |
+
if self.full_ft:
|
832 |
+
self.unet.requires_grad_(True)
|
833 |
+
adapter_parameters.extend(self.unet.parameters())
|
834 |
+
else:
|
835 |
+
adapter_parameters.extend(self.unet.conv_in.parameters())
|
836 |
+
for name, module in self.unet.named_modules():
|
837 |
+
if name.endswith('to_k') or name.endswith('to_v'):
|
838 |
+
if module is not None:
|
839 |
+
adapter_parameters.extend(module.parameters())
|
840 |
+
self.adapter_parameters = adapter_parameters
|
841 |
+
|
842 |
+
@classmethod
|
843 |
+
def from_pretrained(cls, unet, resampler, pretrained_model_path=None, set_trainable_late=False, **kwargs):
|
844 |
+
model = cls(unet=unet, resampler=resampler, set_trainable_late=set_trainable_late, **kwargs)
|
845 |
+
if pretrained_model_path is not None:
|
846 |
+
ckpt = torch.load(pretrained_model_path, map_location='cpu')
|
847 |
+
missing, unexpected = model.load_state_dict(ckpt, strict=False)
|
848 |
+
print('missing keys: ', len(missing), 'unexpected keys:', len(unexpected))
|
849 |
+
if set_trainable_late:
|
850 |
+
model.set_trainable()
|
851 |
+
return model
|
852 |
+
|
853 |
+
def init_pipe(self,
|
854 |
+
vae,
|
855 |
+
scheduler,
|
856 |
+
visual_encoder,
|
857 |
+
image_transform,
|
858 |
+
dtype=torch.float16,
|
859 |
+
device='cuda'):
|
860 |
+
self.device = device
|
861 |
+
self.dtype = dtype
|
862 |
+
|
863 |
+
sdxl_pipe = StableDiffusionXLText2ImageAndEditPipeline(
|
864 |
+
tokenizer=None,
|
865 |
+
tokenizer_2=None,
|
866 |
+
text_encoder=None,
|
867 |
+
text_encoder_2=None,
|
868 |
+
vae=vae,
|
869 |
+
unet=self.unet,
|
870 |
+
scheduler=scheduler,
|
871 |
+
)
|
872 |
+
|
873 |
+
self.sdxl_pipe = sdxl_pipe
|
874 |
+
self.sdxl_pipe.to(device, dtype=dtype)
|
875 |
+
self.discrete_model = None
|
876 |
+
|
877 |
+
self.visual_encoder = visual_encoder.to(self.device, dtype=self.dtype)
|
878 |
+
self.image_transform = image_transform
|
879 |
+
|
880 |
+
def generate(self,
|
881 |
+
image_pil=None,
|
882 |
+
image_tensor=None,
|
883 |
+
image_embeds=None,
|
884 |
+
latent_image=None,
|
885 |
+
seed=42,
|
886 |
+
height=1024,
|
887 |
+
width=1024,
|
888 |
+
guidance_scale=7.5,
|
889 |
+
num_inference_steps=30,
|
890 |
+
input_image_size=448,
|
891 |
+
**kwargs):
|
892 |
+
if image_pil is not None:
|
893 |
+
assert isinstance(image_pil, Image.Image)
|
894 |
+
|
895 |
+
image_prompt_embeds, uncond_image_prompt_embeds, \
|
896 |
+
pooled_image_prompt_embeds, pooled_uncond_image_prompt_embeds = self.get_image_embeds(
|
897 |
+
image_pil=image_pil,
|
898 |
+
image_tensor=image_tensor,
|
899 |
+
image_embeds=image_embeds,
|
900 |
+
return_negative=True,
|
901 |
+
image_size=input_image_size,
|
902 |
+
)
|
903 |
+
# print(image_prompt_embeds.shape, pooled_image_prompt_embeds.shape)
|
904 |
+
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
905 |
+
|
906 |
+
images = self.sdxl_pipe(
|
907 |
+
image=latent_image,
|
908 |
+
prompt_embeds=image_prompt_embeds,
|
909 |
+
negative_prompt_embeds=uncond_image_prompt_embeds,
|
910 |
+
pooled_prompt_embeds=pooled_image_prompt_embeds,
|
911 |
+
negative_pooled_prompt_embeds=pooled_uncond_image_prompt_embeds,
|
912 |
+
guidance_scale=guidance_scale,
|
913 |
+
num_inference_steps=num_inference_steps,
|
914 |
+
generator=generator,
|
915 |
+
height=height,
|
916 |
+
width=width,
|
917 |
+
**kwargs,
|
918 |
+
).images
|
919 |
+
|
920 |
+
return images
|
src/models_ipa/attention_processor.py
ADDED
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
class AttnProcessor(nn.Module):
|
8 |
+
r"""
|
9 |
+
Default processor for performing attention-related computations.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
hidden_size=None,
|
15 |
+
cross_attention_dim=None,
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
def __call__(
|
20 |
+
self,
|
21 |
+
attn,
|
22 |
+
hidden_states,
|
23 |
+
encoder_hidden_states=None,
|
24 |
+
attention_mask=None,
|
25 |
+
temb=None,
|
26 |
+
):
|
27 |
+
residual = hidden_states
|
28 |
+
|
29 |
+
if attn.spatial_norm is not None:
|
30 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
31 |
+
|
32 |
+
input_ndim = hidden_states.ndim
|
33 |
+
|
34 |
+
if input_ndim == 4:
|
35 |
+
batch_size, channel, height, width = hidden_states.shape
|
36 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
37 |
+
|
38 |
+
batch_size, sequence_length, _ = (
|
39 |
+
hidden_states.shape
|
40 |
+
if encoder_hidden_states is None
|
41 |
+
else encoder_hidden_states.shape
|
42 |
+
)
|
43 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
44 |
+
|
45 |
+
if attn.group_norm is not None:
|
46 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
47 |
+
|
48 |
+
query = attn.to_q(hidden_states)
|
49 |
+
|
50 |
+
if encoder_hidden_states is None:
|
51 |
+
encoder_hidden_states = hidden_states
|
52 |
+
elif attn.norm_cross:
|
53 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
54 |
+
|
55 |
+
key = attn.to_k(encoder_hidden_states)
|
56 |
+
value = attn.to_v(encoder_hidden_states)
|
57 |
+
|
58 |
+
query = attn.head_to_batch_dim(query)
|
59 |
+
key = attn.head_to_batch_dim(key)
|
60 |
+
value = attn.head_to_batch_dim(value)
|
61 |
+
|
62 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
63 |
+
hidden_states = torch.bmm(attention_probs, value)
|
64 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
65 |
+
|
66 |
+
# linear proj
|
67 |
+
hidden_states = attn.to_out[0](hidden_states)
|
68 |
+
# dropout
|
69 |
+
hidden_states = attn.to_out[1](hidden_states)
|
70 |
+
|
71 |
+
if input_ndim == 4:
|
72 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
73 |
+
|
74 |
+
if attn.residual_connection:
|
75 |
+
hidden_states = hidden_states + residual
|
76 |
+
|
77 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
78 |
+
|
79 |
+
return hidden_states
|
80 |
+
|
81 |
+
|
82 |
+
class IPAttnProcessor(nn.Module):
|
83 |
+
r"""
|
84 |
+
Attention processor for IP-Adapater.
|
85 |
+
Args:
|
86 |
+
hidden_size (`int`):
|
87 |
+
The hidden size of the attention layer.
|
88 |
+
cross_attention_dim (`int`):
|
89 |
+
The number of channels in the `encoder_hidden_states`.
|
90 |
+
text_context_len (`int`, defaults to 77):
|
91 |
+
The context length of the text features.
|
92 |
+
scale (`float`, defaults to 1.0):
|
93 |
+
the weight scale of image prompt.
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
|
97 |
+
super().__init__()
|
98 |
+
|
99 |
+
self.hidden_size = hidden_size
|
100 |
+
self.cross_attention_dim = cross_attention_dim
|
101 |
+
self.text_context_len = text_context_len
|
102 |
+
self.scale = scale
|
103 |
+
|
104 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
105 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
106 |
+
|
107 |
+
def __call__(
|
108 |
+
self,
|
109 |
+
attn,
|
110 |
+
hidden_states,
|
111 |
+
encoder_hidden_states=None,
|
112 |
+
attention_mask=None,
|
113 |
+
temb=None,
|
114 |
+
):
|
115 |
+
residual = hidden_states
|
116 |
+
|
117 |
+
if attn.spatial_norm is not None:
|
118 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
119 |
+
|
120 |
+
input_ndim = hidden_states.ndim
|
121 |
+
|
122 |
+
if input_ndim == 4:
|
123 |
+
batch_size, channel, height, width = hidden_states.shape
|
124 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
125 |
+
|
126 |
+
batch_size, sequence_length, _ = (
|
127 |
+
hidden_states.shape
|
128 |
+
if encoder_hidden_states is None
|
129 |
+
else encoder_hidden_states.shape
|
130 |
+
)
|
131 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
132 |
+
|
133 |
+
if attn.group_norm is not None:
|
134 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
135 |
+
|
136 |
+
query = attn.to_q(hidden_states)
|
137 |
+
|
138 |
+
if encoder_hidden_states is None:
|
139 |
+
encoder_hidden_states = hidden_states
|
140 |
+
elif attn.norm_cross:
|
141 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
142 |
+
|
143 |
+
# split hidden states
|
144 |
+
encoder_hidden_states, \
|
145 |
+
ip_hidden_states = \
|
146 |
+
encoder_hidden_states[:, :self.text_context_len, :], \
|
147 |
+
encoder_hidden_states[:, self.text_context_len:, :]
|
148 |
+
|
149 |
+
key = attn.to_k(encoder_hidden_states)
|
150 |
+
value = attn.to_v(encoder_hidden_states)
|
151 |
+
|
152 |
+
query = attn.head_to_batch_dim(query)
|
153 |
+
key = attn.head_to_batch_dim(key)
|
154 |
+
value = attn.head_to_batch_dim(value)
|
155 |
+
|
156 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
157 |
+
hidden_states = torch.bmm(attention_probs, value)
|
158 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
159 |
+
|
160 |
+
# for ip-adapter
|
161 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
162 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
163 |
+
|
164 |
+
ip_key = attn.head_to_batch_dim(ip_key)
|
165 |
+
ip_value = attn.head_to_batch_dim(ip_value)
|
166 |
+
|
167 |
+
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
168 |
+
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
169 |
+
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
170 |
+
|
171 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
172 |
+
|
173 |
+
# linear proj
|
174 |
+
hidden_states = attn.to_out[0](hidden_states)
|
175 |
+
# dropout
|
176 |
+
hidden_states = attn.to_out[1](hidden_states)
|
177 |
+
|
178 |
+
if input_ndim == 4:
|
179 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
180 |
+
|
181 |
+
if attn.residual_connection:
|
182 |
+
hidden_states = hidden_states + residual
|
183 |
+
|
184 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
185 |
+
|
186 |
+
return hidden_states
|
187 |
+
|
188 |
+
|
189 |
+
class AttnProcessor2_0(torch.nn.Module):
|
190 |
+
r"""
|
191 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
192 |
+
"""
|
193 |
+
|
194 |
+
def __init__(
|
195 |
+
self,
|
196 |
+
hidden_size=None,
|
197 |
+
cross_attention_dim=None,
|
198 |
+
):
|
199 |
+
super().__init__()
|
200 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
201 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
202 |
+
|
203 |
+
def __call__(
|
204 |
+
self,
|
205 |
+
attn,
|
206 |
+
hidden_states,
|
207 |
+
encoder_hidden_states=None,
|
208 |
+
attention_mask=None,
|
209 |
+
temb=None,
|
210 |
+
):
|
211 |
+
residual = hidden_states
|
212 |
+
|
213 |
+
if attn.spatial_norm is not None:
|
214 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
215 |
+
|
216 |
+
input_ndim = hidden_states.ndim
|
217 |
+
|
218 |
+
if input_ndim == 4:
|
219 |
+
batch_size, channel, height, width = hidden_states.shape
|
220 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
221 |
+
|
222 |
+
batch_size, sequence_length, _ = (
|
223 |
+
hidden_states.shape
|
224 |
+
if encoder_hidden_states is None
|
225 |
+
else encoder_hidden_states.shape
|
226 |
+
)
|
227 |
+
|
228 |
+
if attention_mask is not None:
|
229 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
230 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
231 |
+
# (batch, heads, source_length, target_length)
|
232 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
233 |
+
|
234 |
+
if attn.group_norm is not None:
|
235 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
236 |
+
|
237 |
+
query = attn.to_q(hidden_states)
|
238 |
+
|
239 |
+
if encoder_hidden_states is None:
|
240 |
+
encoder_hidden_states = hidden_states
|
241 |
+
elif attn.norm_cross:
|
242 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
243 |
+
|
244 |
+
key = attn.to_k(encoder_hidden_states)
|
245 |
+
value = attn.to_v(encoder_hidden_states)
|
246 |
+
|
247 |
+
inner_dim = key.shape[-1]
|
248 |
+
head_dim = inner_dim // attn.heads
|
249 |
+
|
250 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
251 |
+
|
252 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
253 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
254 |
+
|
255 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
256 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
257 |
+
hidden_states = F.scaled_dot_product_attention(query,
|
258 |
+
key,
|
259 |
+
value,
|
260 |
+
attn_mask=attention_mask,
|
261 |
+
dropout_p=0.0,
|
262 |
+
is_causal=False)
|
263 |
+
|
264 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
265 |
+
hidden_states = hidden_states.to(query.dtype)
|
266 |
+
|
267 |
+
# linear proj
|
268 |
+
hidden_states = attn.to_out[0](hidden_states)
|
269 |
+
# dropout
|
270 |
+
hidden_states = attn.to_out[1](hidden_states)
|
271 |
+
|
272 |
+
if input_ndim == 4:
|
273 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
274 |
+
|
275 |
+
if attn.residual_connection:
|
276 |
+
hidden_states = hidden_states + residual
|
277 |
+
|
278 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
279 |
+
|
280 |
+
return hidden_states
|
281 |
+
|
282 |
+
|
283 |
+
class IPAttnProcessor2_0(torch.nn.Module):
|
284 |
+
r"""
|
285 |
+
Attention processor for IP-Adapater for PyTorch 2.0.
|
286 |
+
Args:
|
287 |
+
hidden_size (`int`):
|
288 |
+
The hidden size of the attention layer.
|
289 |
+
cross_attention_dim (`int`):
|
290 |
+
The number of channels in the `encoder_hidden_states`.
|
291 |
+
text_context_len (`int`, defaults to 77):
|
292 |
+
The context length of the text features.
|
293 |
+
scale (`float`, defaults to 1.0):
|
294 |
+
the weight scale of image prompt.
|
295 |
+
"""
|
296 |
+
|
297 |
+
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
|
298 |
+
super().__init__()
|
299 |
+
|
300 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
301 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
302 |
+
|
303 |
+
self.hidden_size = hidden_size
|
304 |
+
self.cross_attention_dim = cross_attention_dim
|
305 |
+
self.text_context_len = text_context_len
|
306 |
+
self.scale = scale
|
307 |
+
|
308 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
309 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
310 |
+
|
311 |
+
def __call__(
|
312 |
+
self,
|
313 |
+
attn,
|
314 |
+
hidden_states,
|
315 |
+
encoder_hidden_states=None,
|
316 |
+
attention_mask=None,
|
317 |
+
temb=None,
|
318 |
+
):
|
319 |
+
residual = hidden_states
|
320 |
+
|
321 |
+
if attn.spatial_norm is not None:
|
322 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
323 |
+
|
324 |
+
input_ndim = hidden_states.ndim
|
325 |
+
|
326 |
+
if input_ndim == 4:
|
327 |
+
batch_size, channel, height, width = hidden_states.shape
|
328 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
329 |
+
|
330 |
+
batch_size, sequence_length, _ = (
|
331 |
+
hidden_states.shape
|
332 |
+
if encoder_hidden_states is None
|
333 |
+
else encoder_hidden_states.shape
|
334 |
+
)
|
335 |
+
if attention_mask is not None:
|
336 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
337 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
338 |
+
# (batch, heads, source_length, target_length)
|
339 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
340 |
+
|
341 |
+
if attn.group_norm is not None:
|
342 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
343 |
+
|
344 |
+
query = attn.to_q(hidden_states)
|
345 |
+
|
346 |
+
if encoder_hidden_states is None:
|
347 |
+
encoder_hidden_states = hidden_states
|
348 |
+
elif attn.norm_cross:
|
349 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
350 |
+
|
351 |
+
# split hidden states
|
352 |
+
encoder_hidden_states, \
|
353 |
+
ip_hidden_states = \
|
354 |
+
encoder_hidden_states[:, :self.text_context_len, :], \
|
355 |
+
encoder_hidden_states[:, self.text_context_len:, :]
|
356 |
+
|
357 |
+
key = attn.to_k(encoder_hidden_states)
|
358 |
+
value = attn.to_v(encoder_hidden_states)
|
359 |
+
|
360 |
+
inner_dim = key.shape[-1]
|
361 |
+
head_dim = inner_dim // attn.heads
|
362 |
+
|
363 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
364 |
+
|
365 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
366 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
367 |
+
|
368 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
369 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
370 |
+
hidden_states = F.scaled_dot_product_attention(query,
|
371 |
+
key,
|
372 |
+
value,
|
373 |
+
attn_mask=attention_mask,
|
374 |
+
dropout_p=0.0,
|
375 |
+
is_causal=False)
|
376 |
+
|
377 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
378 |
+
hidden_states = hidden_states.to(query.dtype)
|
379 |
+
|
380 |
+
# for ip-adapter
|
381 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
382 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
383 |
+
|
384 |
+
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
385 |
+
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
386 |
+
|
387 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
388 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
389 |
+
ip_hidden_states = F.scaled_dot_product_attention(query,
|
390 |
+
ip_key,
|
391 |
+
ip_value,
|
392 |
+
attn_mask=None,
|
393 |
+
dropout_p=0.0,
|
394 |
+
is_causal=False)
|
395 |
+
|
396 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
397 |
+
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
398 |
+
|
399 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
400 |
+
|
401 |
+
# linear proj
|
402 |
+
hidden_states = attn.to_out[0](hidden_states)
|
403 |
+
# dropout
|
404 |
+
hidden_states = attn.to_out[1](hidden_states)
|
405 |
+
|
406 |
+
if input_ndim == 4:
|
407 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
408 |
+
|
409 |
+
if attn.residual_connection:
|
410 |
+
hidden_states = hidden_states + residual
|
411 |
+
|
412 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
413 |
+
|
414 |
+
return hidden_states
|
src/models_ipa/ipa_utils.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
|
3 |
+
|
4 |
+
def is_torch2_available():
|
5 |
+
return hasattr(F, "scaled_dot_product_attention")
|
src/models_ipa/resampler.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
# FFN
|
10 |
+
def FeedForward(dim, mult=4):
|
11 |
+
inner_dim = int(dim * mult)
|
12 |
+
return nn.Sequential(
|
13 |
+
nn.LayerNorm(dim),
|
14 |
+
nn.Linear(dim, inner_dim, bias=False),
|
15 |
+
nn.GELU(),
|
16 |
+
nn.Linear(inner_dim, dim, bias=False),
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
def reshape_tensor(x, heads):
|
21 |
+
bs, length, width = x.shape
|
22 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
23 |
+
x = x.view(bs, length, heads, -1)
|
24 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
25 |
+
x = x.transpose(1, 2)
|
26 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
27 |
+
x = x.reshape(bs, heads, length, -1)
|
28 |
+
return x
|
29 |
+
|
30 |
+
|
31 |
+
class PerceiverAttention(nn.Module):
|
32 |
+
|
33 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
34 |
+
super().__init__()
|
35 |
+
self.scale = dim_head ** -0.5
|
36 |
+
self.dim_head = dim_head
|
37 |
+
self.heads = heads
|
38 |
+
inner_dim = dim_head * heads
|
39 |
+
|
40 |
+
self.norm1 = nn.LayerNorm(dim)
|
41 |
+
self.norm2 = nn.LayerNorm(dim)
|
42 |
+
|
43 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
44 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
45 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
46 |
+
|
47 |
+
def forward(self, x, latents):
|
48 |
+
"""
|
49 |
+
Args:
|
50 |
+
x (torch.Tensor): image features
|
51 |
+
shape (b, n1, D)
|
52 |
+
latent (torch.Tensor): latent features
|
53 |
+
shape (b, n2, D)
|
54 |
+
"""
|
55 |
+
x = self.norm1(x)
|
56 |
+
latents = self.norm2(latents)
|
57 |
+
|
58 |
+
b, l, _ = latents.shape
|
59 |
+
|
60 |
+
q = self.to_q(latents)
|
61 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
62 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
63 |
+
|
64 |
+
q = reshape_tensor(q, self.heads)
|
65 |
+
k = reshape_tensor(k, self.heads)
|
66 |
+
v = reshape_tensor(v, self.heads)
|
67 |
+
|
68 |
+
# attention
|
69 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
70 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
71 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
72 |
+
out = weight @ v
|
73 |
+
|
74 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
75 |
+
|
76 |
+
return self.to_out(out)
|
77 |
+
|
78 |
+
|
79 |
+
class AttentionPool2d(nn.Module):
|
80 |
+
|
81 |
+
def __init__(self, seq_len: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
82 |
+
super().__init__()
|
83 |
+
self.positional_embedding = nn.Parameter(torch.randn(seq_len + 1, embed_dim) / embed_dim ** 0.5)
|
84 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
85 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
86 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
87 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
88 |
+
self.num_heads = num_heads
|
89 |
+
|
90 |
+
def forward(self, x, return_all_tokens=False):
|
91 |
+
# x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
92 |
+
x = x.permute(1, 0, 2) # (N(HW)C) => (HW)NC
|
93 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
94 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
95 |
+
x, _ = F.multi_head_attention_forward(query=x,
|
96 |
+
key=x,
|
97 |
+
value=x,
|
98 |
+
embed_dim_to_check=x.shape[-1],
|
99 |
+
num_heads=self.num_heads,
|
100 |
+
q_proj_weight=self.q_proj.weight,
|
101 |
+
k_proj_weight=self.k_proj.weight,
|
102 |
+
v_proj_weight=self.v_proj.weight,
|
103 |
+
in_proj_weight=None,
|
104 |
+
in_proj_bias=torch.cat(
|
105 |
+
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
106 |
+
bias_k=None,
|
107 |
+
bias_v=None,
|
108 |
+
add_zero_attn=False,
|
109 |
+
dropout_p=0,
|
110 |
+
out_proj_weight=self.c_proj.weight,
|
111 |
+
out_proj_bias=self.c_proj.bias,
|
112 |
+
use_separate_proj_weight=True,
|
113 |
+
training=self.training,
|
114 |
+
need_weights=False)
|
115 |
+
if return_all_tokens:
|
116 |
+
return x
|
117 |
+
else:
|
118 |
+
return x[0]
|
119 |
+
|
120 |
+
|
121 |
+
class Resampler(nn.Module):
|
122 |
+
|
123 |
+
def __init__(
|
124 |
+
self,
|
125 |
+
dim=1024,
|
126 |
+
depth=8,
|
127 |
+
dim_head=64,
|
128 |
+
heads=16,
|
129 |
+
num_queries=8,
|
130 |
+
embedding_dim=768,
|
131 |
+
output_dim=1024,
|
132 |
+
ff_mult=4,
|
133 |
+
):
|
134 |
+
super().__init__()
|
135 |
+
|
136 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
|
137 |
+
|
138 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
139 |
+
|
140 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
141 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
142 |
+
|
143 |
+
self.in_dim = dim
|
144 |
+
self.out_dim = output_dim
|
145 |
+
|
146 |
+
self.layers = nn.ModuleList([])
|
147 |
+
for _ in range(depth):
|
148 |
+
self.layers.append(
|
149 |
+
nn.ModuleList([
|
150 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
151 |
+
FeedForward(dim=dim, mult=ff_mult),
|
152 |
+
]))
|
153 |
+
|
154 |
+
def forward(self, x):
|
155 |
+
|
156 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
157 |
+
|
158 |
+
x = self.proj_in(x)
|
159 |
+
|
160 |
+
for attn, ff in self.layers:
|
161 |
+
latents = attn(x, latents) + latents
|
162 |
+
latents = ff(latents) + latents
|
163 |
+
|
164 |
+
latents = self.proj_out(latents)
|
165 |
+
output_embeds = self.norm_out(latents)
|
166 |
+
|
167 |
+
return output_embeds
|
168 |
+
|
169 |
+
|
170 |
+
class ResamplerXL(nn.Module):
|
171 |
+
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
dim=1024,
|
175 |
+
depth=8,
|
176 |
+
dim_head=64,
|
177 |
+
heads=16,
|
178 |
+
num_queries=8,
|
179 |
+
embedding_dim=768,
|
180 |
+
output1_dim=768,
|
181 |
+
output2_dim=1280,
|
182 |
+
ff_mult=4,
|
183 |
+
):
|
184 |
+
super().__init__()
|
185 |
+
|
186 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
|
187 |
+
|
188 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
189 |
+
|
190 |
+
# self.proj_out = nn.Linear(dim, output_dim)
|
191 |
+
self.norm_out = nn.LayerNorm(dim)
|
192 |
+
|
193 |
+
self.in_dim = dim
|
194 |
+
self.out_dim = output1_dim + output2_dim
|
195 |
+
|
196 |
+
self.layers = nn.ModuleList([])
|
197 |
+
for _ in range(depth):
|
198 |
+
self.layers.append(
|
199 |
+
nn.ModuleList([
|
200 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
201 |
+
FeedForward(dim=dim, mult=ff_mult),
|
202 |
+
]))
|
203 |
+
|
204 |
+
self.unet_proj_1 = nn.Linear(self.in_dim, output1_dim)
|
205 |
+
self.unet_proj_2 = nn.Linear(self.in_dim, output2_dim)
|
206 |
+
self.unet_attnpool = AttentionPool2d(num_queries, self.in_dim, heads, output2_dim)
|
207 |
+
|
208 |
+
def forward(self, x):
|
209 |
+
|
210 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
211 |
+
|
212 |
+
x = self.proj_in(x)
|
213 |
+
|
214 |
+
for attn, ff in self.layers:
|
215 |
+
latents = attn(x, latents) + latents
|
216 |
+
latents = ff(latents) + latents
|
217 |
+
|
218 |
+
hidden_embeds = self.norm_out(latents)
|
219 |
+
|
220 |
+
encoder_hidden_1 = self.unet_proj_1(hidden_embeds) # [bs, 256, 768]
|
221 |
+
encoder_hidden_2 = self.unet_proj_2(hidden_embeds) # [bs, 256, 1280]
|
222 |
+
prompt_embeds = torch.cat([encoder_hidden_1, encoder_hidden_2], dim=-1) # [bs, 256, 2048]
|
223 |
+
pooled_prompt_embeds = self.unet_attnpool(hidden_embeds) # [bs, 1280]
|
224 |
+
|
225 |
+
return prompt_embeds, pooled_prompt_embeds
|
226 |
+
|
227 |
+
|
228 |
+
class ResamplerXLV2(nn.Module):
|
229 |
+
|
230 |
+
def __init__(
|
231 |
+
self,
|
232 |
+
dim=1024,
|
233 |
+
depth=8,
|
234 |
+
dim_head=64,
|
235 |
+
heads=16,
|
236 |
+
num_queries=8,
|
237 |
+
embedding_dim=768,
|
238 |
+
output1_dim=768,
|
239 |
+
output2_dim=1280,
|
240 |
+
ff_mult=4,
|
241 |
+
):
|
242 |
+
super().__init__()
|
243 |
+
|
244 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
|
245 |
+
|
246 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
247 |
+
|
248 |
+
# self.proj_out = nn.Linear(dim, output_dim)
|
249 |
+
self.norm_out = nn.LayerNorm(dim)
|
250 |
+
|
251 |
+
self.in_dim = dim
|
252 |
+
self.out_dim = output1_dim + output2_dim
|
253 |
+
|
254 |
+
self.layers = nn.ModuleList([])
|
255 |
+
for _ in range(depth):
|
256 |
+
self.layers.append(
|
257 |
+
nn.ModuleList([
|
258 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
259 |
+
FeedForward(dim=dim, mult=ff_mult),
|
260 |
+
]))
|
261 |
+
|
262 |
+
self.unet_proj_1 = nn.Linear(self.in_dim, output1_dim)
|
263 |
+
self.unet_proj_2 = nn.Linear(self.in_dim, output2_dim)
|
264 |
+
self.unet_attnpool = AttentionPool2d(num_queries, self.in_dim, heads, output2_dim)
|
265 |
+
|
266 |
+
def forward(self, x, pooled_text_embeds=None):
|
267 |
+
|
268 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
269 |
+
x = F.normalize(x)
|
270 |
+
|
271 |
+
x = self.proj_in(x)
|
272 |
+
|
273 |
+
for attn, ff in self.layers:
|
274 |
+
latents = attn(x, latents) + latents
|
275 |
+
latents = ff(latents) + latents
|
276 |
+
|
277 |
+
hidden_embeds = self.norm_out(latents)
|
278 |
+
|
279 |
+
encoder_hidden_1 = self.unet_proj_1(hidden_embeds) # [bs, 256, 768]
|
280 |
+
encoder_hidden_2 = self.unet_proj_2(hidden_embeds) # [bs, 256, 1280]
|
281 |
+
prompt_embeds = torch.cat([encoder_hidden_1, encoder_hidden_2], dim=-1) # [bs, 256, 2048]
|
282 |
+
pooled_prompt_embeds = self.unet_attnpool(hidden_embeds) # [bs, 1280]
|
283 |
+
|
284 |
+
return prompt_embeds, pooled_prompt_embeds
|
285 |
+
|
286 |
+
|
287 |
+
class ResamplerXLIdentity(nn.Module):
|
288 |
+
def __init__(self) -> None:
|
289 |
+
super().__init__()
|
290 |
+
|
291 |
+
def forward(self, x, pooled_text_embeds=None):
|
292 |
+
return x, pooled_text_embeds
|
293 |
+
|
294 |
+
|
295 |
+
if __name__ == '__main__':
|
296 |
+
image_proj_model = Resampler(dim=1024,
|
297 |
+
depth=4,
|
298 |
+
dim_head=64,
|
299 |
+
heads=12,
|
300 |
+
num_queries=1024,
|
301 |
+
embedding_dim=1024,
|
302 |
+
output_dim=1024,
|
303 |
+
ff_mult=4)
|
304 |
+
numel = 0
|
305 |
+
for name, param in image_proj_model.named_parameters():
|
306 |
+
numel += param.numel()
|
307 |
+
|
308 |
+
print(f'Total params: {numel}')
|
src/processer/tokenizer.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import BertTokenizer
|
2 |
+
|
3 |
+
|
4 |
+
def bert_tokenizer(pretrained_model_name_or_path):
|
5 |
+
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path,
|
6 |
+
truncation_side='right')
|
7 |
+
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
8 |
+
return tokenizer
|
src/processer/transforms.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision import transforms
|
2 |
+
|
3 |
+
|
4 |
+
def get_transform(type='clip', keep_ratio=True, image_size=224):
|
5 |
+
if type == 'clip':
|
6 |
+
transform = []
|
7 |
+
if keep_ratio:
|
8 |
+
transform.extend([
|
9 |
+
transforms.Resize(image_size),
|
10 |
+
transforms.CenterCrop(image_size),
|
11 |
+
])
|
12 |
+
else:
|
13 |
+
transform.append(transforms.Resize((image_size, image_size)))
|
14 |
+
transform.extend([
|
15 |
+
transforms.ToTensor(),
|
16 |
+
transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
|
17 |
+
])
|
18 |
+
|
19 |
+
return transforms.Compose(transform)
|
20 |
+
elif type == 'clipa':
|
21 |
+
transform = []
|
22 |
+
if keep_ratio:
|
23 |
+
transform.extend([
|
24 |
+
transforms.Resize(image_size),
|
25 |
+
transforms.CenterCrop(image_size),
|
26 |
+
])
|
27 |
+
else:
|
28 |
+
transform.append(transforms.Resize((image_size, image_size)))
|
29 |
+
transform.extend(
|
30 |
+
[transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])
|
31 |
+
|
32 |
+
return transforms.Compose(transform)
|
33 |
+
elif type == 'sd':
|
34 |
+
transform = []
|
35 |
+
if keep_ratio:
|
36 |
+
transform.extend([
|
37 |
+
transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
|
38 |
+
transforms.CenterCrop(image_size),
|
39 |
+
])
|
40 |
+
else:
|
41 |
+
transform.append(
|
42 |
+
transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC))
|
43 |
+
transform.extend([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
|
44 |
+
|
45 |
+
return transforms.Compose(transform)
|
46 |
+
else:
|
47 |
+
raise NotImplementedError
|
src/tools/reload_qwen_vit.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForCausalLM
|
3 |
+
|
4 |
+
torch.manual_seed(1234)
|
5 |
+
|
6 |
+
qwen_model_path = 'pretrained/Qwen-VL-Chat'
|
7 |
+
save_path = 'pretrained/QwenViT/qwen_vit_G.pt'
|
8 |
+
|
9 |
+
model = AutoModelForCausalLM.from_pretrained(qwen_model_path, device_map="cpu", trust_remote_code=True).eval()
|
10 |
+
|
11 |
+
visual_encoder = model.transformer.visual
|
12 |
+
print(visual_encoder)
|
13 |
+
|
14 |
+
torch.save(visual_encoder.state_dict(), save_path)
|
src/train/dist_utils.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.distributed as dist
|
3 |
+
|
4 |
+
|
5 |
+
def all_gather(tensor):
|
6 |
+
world_size = dist.get_world_size()
|
7 |
+
tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
|
8 |
+
dist.all_gather(tensor_list, tensor)
|
9 |
+
return tensor_list
|
10 |
+
|
11 |
+
|
12 |
+
def is_dist_avail_and_initialized():
|
13 |
+
if not dist.is_available():
|
14 |
+
return False
|
15 |
+
if not dist.is_initialized():
|
16 |
+
return False
|
17 |
+
return True
|
18 |
+
|
19 |
+
|
20 |
+
@torch.no_grad()
|
21 |
+
def concat_all_gather(tensor):
|
22 |
+
"""
|
23 |
+
Performs all_gather operation on the provided tensors.
|
24 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
25 |
+
"""
|
26 |
+
# if use distributed training
|
27 |
+
if not is_dist_avail_and_initialized():
|
28 |
+
return tensor
|
29 |
+
|
30 |
+
tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
|
31 |
+
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
32 |
+
|
33 |
+
output = torch.cat(tensors_gather, dim=0)
|
34 |
+
return output
|
src/train/schedular.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from functools import partial
|
4 |
+
from typing import Callable, Iterable, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.optim import Optimizer
|
9 |
+
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
|
10 |
+
from transformers.trainer_utils import SchedulerType
|
11 |
+
from transformers.utils import logging
|
12 |
+
|
13 |
+
from transformers.optimization import get_linear_schedule_with_warmup, \
|
14 |
+
get_cosine_with_hard_restarts_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, \
|
15 |
+
get_constant_schedule, get_constant_schedule_with_warmup, get_inverse_sqrt_schedule, get_reduce_on_plateau_schedule
|
16 |
+
|
17 |
+
logger = logging.get_logger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
def _get_cosine_schedule_with_warmup_lr_lambda(current_step: int,
|
21 |
+
*,
|
22 |
+
num_warmup_steps: int,
|
23 |
+
num_training_steps: int,
|
24 |
+
num_cycles: float,
|
25 |
+
min_lr_ratio: float = 0.0):
|
26 |
+
if current_step < num_warmup_steps:
|
27 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
28 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
29 |
+
# return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
30 |
+
return max(0.0,
|
31 |
+
0.5 * ((1.0 + min_lr_ratio) + (1.0 - min_lr_ratio) * math.cos(
|
32 |
+
math.pi * float(num_cycles) * 2.0 * progress)))
|
33 |
+
|
34 |
+
|
35 |
+
def get_cosine_schedule_with_warmup(optimizer: Optimizer,
|
36 |
+
num_warmup_steps: int,
|
37 |
+
num_training_steps: int,
|
38 |
+
num_cycles: float = 0.5,
|
39 |
+
last_epoch: int = -1,
|
40 |
+
min_lr_ratio: float = 0.0):
|
41 |
+
"""
|
42 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
43 |
+
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
44 |
+
initial lr set in the optimizer.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
48 |
+
The optimizer for which to schedule the learning rate.
|
49 |
+
num_warmup_steps (`int`):
|
50 |
+
The number of steps for the warmup phase.
|
51 |
+
num_training_steps (`int`):
|
52 |
+
The total number of training steps.
|
53 |
+
num_cycles (`float`, *optional*, defaults to 0.5):
|
54 |
+
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
55 |
+
following a half-cosine).
|
56 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
57 |
+
The index of the last epoch when resuming training.
|
58 |
+
|
59 |
+
Return:
|
60 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
61 |
+
"""
|
62 |
+
|
63 |
+
lr_lambda = partial(
|
64 |
+
_get_cosine_schedule_with_warmup_lr_lambda,
|
65 |
+
num_warmup_steps=num_warmup_steps,
|
66 |
+
num_training_steps=num_training_steps,
|
67 |
+
num_cycles=num_cycles,
|
68 |
+
min_lr_ratio=min_lr_ratio,
|
69 |
+
)
|
70 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
71 |
+
|
72 |
+
|
73 |
+
TYPE_TO_SCHEDULER_FUNCTION = {
|
74 |
+
SchedulerType.LINEAR: get_linear_schedule_with_warmup,
|
75 |
+
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
|
76 |
+
SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
|
77 |
+
SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
|
78 |
+
SchedulerType.CONSTANT: get_constant_schedule,
|
79 |
+
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
|
80 |
+
SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
|
81 |
+
SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
|
82 |
+
}
|
83 |
+
|
84 |
+
|
85 |
+
def get_scheduler(
|
86 |
+
name: Union[str, SchedulerType],
|
87 |
+
optimizer: Optimizer,
|
88 |
+
num_warmup_steps: Optional[int] = None,
|
89 |
+
num_training_steps: Optional[int] = None,
|
90 |
+
min_lr_ratio: Optional[float] = 0.0,
|
91 |
+
):
|
92 |
+
"""
|
93 |
+
Unified API to get any scheduler from its name.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
name (`str` or `SchedulerType`):
|
97 |
+
The name of the scheduler to use.
|
98 |
+
optimizer (`torch.optim.Optimizer`):
|
99 |
+
The optimizer that will be used during training.
|
100 |
+
num_warmup_steps (`int`, *optional*):
|
101 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
102 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
103 |
+
num_training_steps (`int``, *optional*):
|
104 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
105 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
106 |
+
"""
|
107 |
+
name = SchedulerType(name)
|
108 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
109 |
+
if name == SchedulerType.CONSTANT or name == SchedulerType.REDUCE_ON_PLATEAU:
|
110 |
+
return schedule_func(optimizer)
|
111 |
+
|
112 |
+
# All other schedulers require `num_warmup_steps`
|
113 |
+
if num_warmup_steps is None:
|
114 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
115 |
+
|
116 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
117 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
118 |
+
|
119 |
+
if name == SchedulerType.INVERSE_SQRT:
|
120 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
121 |
+
|
122 |
+
# All other schedulers require `num_training_steps`
|
123 |
+
if num_training_steps is None:
|
124 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
125 |
+
|
126 |
+
logger.info(f'Initialize lr scheduler with min_lr_ratio: {min_lr_ratio}')
|
127 |
+
return schedule_func(optimizer,
|
128 |
+
num_warmup_steps=num_warmup_steps,
|
129 |
+
num_training_steps=num_training_steps,
|
130 |
+
min_lr_ratio=min_lr_ratio)
|
src/train/train.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
|
2 |
+
import hydra
|
3 |
+
|
4 |
+
import pyrootutils
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
from accelerate import Accelerator
|
8 |
+
from accelerate.logging import get_logger
|
9 |
+
from accelerate.utils import ProjectConfiguration
|
10 |
+
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from omegaconf.dictconfig import DictConfig
|
14 |
+
import argparse
|
15 |
+
from flask import Flask, request
|
16 |
+
from typing import List, Union
|
17 |
+
import json
|
18 |
+
from typing import Optional
|
19 |
+
import transformers
|
20 |
+
from dataclasses import dataclass, field, asdict, is_dataclass
|
21 |
+
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, DistributedReadingService, \
|
22 |
+
SequentialReadingService
|
23 |
+
import logging
|
24 |
+
|
25 |
+
pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
|
26 |
+
from src.train.schedular import get_scheduler
|
27 |
+
from src.train.dist_utils import all_gather
|
28 |
+
|
29 |
+
# logger = get_logger(__name__, log_level='info')
|
30 |
+
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
31 |
+
logging.basicConfig(level=logging.INFO, format=log_format)
|
32 |
+
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
os.environ["WANDB_MODE"] = "offline"
|
35 |
+
|
36 |
+
|
37 |
+
@dataclass
|
38 |
+
class ConfigPathArguments:
|
39 |
+
image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
|
40 |
+
tokenizer: Optional[str] = field(default=None,
|
41 |
+
metadata={"help": "config path of tokenizer used to initialize tokenizer"})
|
42 |
+
# model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
|
43 |
+
visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
|
44 |
+
text_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
|
45 |
+
discrete_model: Optional[str] = field(default=None, metadata={"help": "config path of discrete model"})
|
46 |
+
train_dataset: Optional[str] = field(default=None, metadata={"help": "config path of training dataset"})
|
47 |
+
|
48 |
+
|
49 |
+
@dataclass
|
50 |
+
class TrainingArguments:
|
51 |
+
output_dir: str = field(
|
52 |
+
metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, )
|
53 |
+
resume_from_checkpoint: Optional[str] = field(
|
54 |
+
default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."})
|
55 |
+
resume_steps: Optional[int] = field(default=None, metadata={"help": "The training sterps of saved checkpoint"})
|
56 |
+
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
57 |
+
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
58 |
+
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
59 |
+
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
60 |
+
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
61 |
+
max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
|
62 |
+
gradient_accumulation_steps: int = field(
|
63 |
+
default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."})
|
64 |
+
mixed_precision: Optional[str] = field(
|
65 |
+
default='no',
|
66 |
+
metadata={
|
67 |
+
"help":
|
68 |
+
"Whether to use mixed precision. \
|
69 |
+
Choose between fp16 and bf16 (bfloat16). \
|
70 |
+
Bf16 requires PyTorch >=1.10.and an Nvidia Ampere GPU."
|
71 |
+
})
|
72 |
+
num_train_epochs: int = field(default=3, metadata={"help": "Total number of training epochs to perform."})
|
73 |
+
max_steps: int = field(default=-1, metadata={"help": "Total number of training steps to perform. "})
|
74 |
+
save_steps: int = field(default=10000, metadata={"help": "Number of updates steps before two checkpoint saves."})
|
75 |
+
lr_scheduler_type: str = field(default="cosine", metadata={"help": "The scheduler type to use."})
|
76 |
+
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
77 |
+
min_lr_ratio: float = field(default=0.01, metadata={"help": "Minimal learning rate ratio."})
|
78 |
+
dataloader_num_workers: int = field(default=8, metadata={"help": "The number of workers to use for data loading."})
|
79 |
+
project_name: str = field(default="DiscreteLearning", metadata={"help": "The name of experiment"})
|
80 |
+
expr_name: str = field(default="", metadata={"help": "The name of experiment"})
|
81 |
+
|
82 |
+
|
83 |
+
def build_dataloader(dataset_cfg, image_transform, tokenizer, dataloader_num_workers=4):
|
84 |
+
dataset = hydra.utils.instantiate(dataset_cfg, image_transform=image_transform, tokenizer=tokenizer)
|
85 |
+
mp_service = MultiProcessingReadingService(num_workers=dataloader_num_workers)
|
86 |
+
dist_service = DistributedReadingService()
|
87 |
+
reading_service = SequentialReadingService(dist_service, mp_service)
|
88 |
+
dataloader = DataLoader2(dataset, reading_service=reading_service)
|
89 |
+
return dataloader
|
90 |
+
|
91 |
+
|
92 |
+
def get_metric(output):
|
93 |
+
metric = {}
|
94 |
+
for key, value in output.items():
|
95 |
+
if 'loss' in key:
|
96 |
+
metric[key] = value.item()
|
97 |
+
return metric
|
98 |
+
|
99 |
+
|
100 |
+
def get_code_usage(indices):
|
101 |
+
indices_list = all_gather(indices)
|
102 |
+
indices = torch.cat(indices_list, dim=0)
|
103 |
+
code_usage = indices.unique().numel()
|
104 |
+
return code_usage
|
105 |
+
|
106 |
+
|
107 |
+
def merge_config(**kwargs):
|
108 |
+
config = {}
|
109 |
+
for key, value in kwargs.items():
|
110 |
+
if isinstance(value, argparse.Namespace):
|
111 |
+
config[key] = vars(value)
|
112 |
+
elif isinstance(value, DictConfig):
|
113 |
+
config[key] = OmegaConf.to_object(value)
|
114 |
+
elif is_dataclass(value):
|
115 |
+
config[key] = asdict(value)
|
116 |
+
elif isinstance(value, dict):
|
117 |
+
config[key] = value
|
118 |
+
else:
|
119 |
+
logger.error(f'key: {key}, value: {value} will not be merged.')
|
120 |
+
return config
|
121 |
+
|
122 |
+
|
123 |
+
def trainable_params(model):
|
124 |
+
count = 0
|
125 |
+
for name, param in model.named_parameters():
|
126 |
+
count += param.numel()
|
127 |
+
return count
|
128 |
+
|
129 |
+
|
130 |
+
def train():
|
131 |
+
parser = transformers.HfArgumentParser((ConfigPathArguments, TrainingArguments))
|
132 |
+
cfg_path, args = parser.parse_args_into_dataclasses()
|
133 |
+
|
134 |
+
project_config = ProjectConfiguration(project_dir=args.output_dir,
|
135 |
+
logging_dir=os.path.join(args.output_dir, 'logs'))
|
136 |
+
|
137 |
+
accelerator = Accelerator(
|
138 |
+
mixed_precision=args.mixed_precision,
|
139 |
+
log_with=['tensorboard', 'wandb'],
|
140 |
+
project_config=project_config,
|
141 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
142 |
+
step_scheduler_with_optimizer=False,
|
143 |
+
)
|
144 |
+
logger.info('Init accelerator done.')
|
145 |
+
|
146 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
147 |
+
|
148 |
+
visual_encoder_cfg = OmegaConf.load(cfg_path.visual_encoder)
|
149 |
+
visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
|
150 |
+
logger.info('Load visual encoder done.')
|
151 |
+
|
152 |
+
discrete_model_cfg = OmegaConf.load(cfg_path.discrete_model)
|
153 |
+
discrete_model = hydra.utils.instantiate(discrete_model_cfg)
|
154 |
+
logger.info('Load discrete model done.')
|
155 |
+
|
156 |
+
train_dataset_cfg = OmegaConf.load(cfg_path.train_dataset)
|
157 |
+
|
158 |
+
if cfg_path.text_encoder is not None:
|
159 |
+
text_encoder_cfg = OmegaConf.load(cfg_path.text_encoder)
|
160 |
+
text_encoder = hydra.utils.instantiate(text_encoder_cfg)
|
161 |
+
else:
|
162 |
+
text_encoder_cfg = None
|
163 |
+
text_encoder = None
|
164 |
+
|
165 |
+
if cfg_path.image_transform is not None:
|
166 |
+
image_transform_cfg = OmegaConf.load(cfg_path.image_transform)
|
167 |
+
image_transform = hydra.utils.instantiate(image_transform_cfg)
|
168 |
+
else:
|
169 |
+
image_transform_cfg = None
|
170 |
+
image_transform = None
|
171 |
+
|
172 |
+
if cfg_path.tokenizer is not None:
|
173 |
+
tokenizer_cfg = OmegaConf.load(cfg_path.tokenizer)
|
174 |
+
tokenizer = hydra.utils.instantiate(tokenizer_cfg)
|
175 |
+
else:
|
176 |
+
tokenizer_cfg = None
|
177 |
+
tokenizer = None
|
178 |
+
|
179 |
+
weight_dtype = torch.float32
|
180 |
+
if accelerator.mixed_precision == "fp16":
|
181 |
+
weight_dtype = torch.float16
|
182 |
+
elif accelerator.mixed_precision == "bf16":
|
183 |
+
weight_dtype = torch.bfloat16
|
184 |
+
|
185 |
+
visual_encoder.to(accelerator.device, dtype=weight_dtype)
|
186 |
+
logger.info('Freeze visual encoder...')
|
187 |
+
visual_encoder.requires_grad_(False)
|
188 |
+
if text_encoder is not None:
|
189 |
+
logger.info('Freeze text encoder...')
|
190 |
+
text_encoder.requires_grad_(False)
|
191 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
192 |
+
discrete_model.to(accelerator.device, dtype=weight_dtype)
|
193 |
+
|
194 |
+
discrete_model = accelerator.prepare(discrete_model)
|
195 |
+
optimizer = torch.optim.AdamW(discrete_model.parameters(),
|
196 |
+
lr=args.learning_rate,
|
197 |
+
betas=[args.adam_beta1, args.adam_beta2],
|
198 |
+
eps=args.adam_epsilon,
|
199 |
+
weight_decay=args.weight_decay)
|
200 |
+
logger.info('Init optimizer done.')
|
201 |
+
scheduler = get_scheduler(name=args.lr_scheduler_type,
|
202 |
+
optimizer=optimizer,
|
203 |
+
num_warmup_steps=args.warmup_steps,
|
204 |
+
num_training_steps=args.max_steps,
|
205 |
+
min_lr_ratio=args.min_lr_ratio)
|
206 |
+
# accelerator.register_for_checkpointing(scheduler)
|
207 |
+
|
208 |
+
optimizer, scheduler = accelerator.prepare(optimizer, scheduler)
|
209 |
+
logger.info('Prepare accelerator done.')
|
210 |
+
|
211 |
+
config_record = merge_config(discrete_model=discrete_model_cfg,
|
212 |
+
visual_encoder=visual_encoder_cfg,
|
213 |
+
text_encoder=text_encoder_cfg,
|
214 |
+
image_transform=image_transform_cfg,
|
215 |
+
tokenizer=tokenizer_cfg,
|
216 |
+
train_dataset=train_dataset_cfg,
|
217 |
+
train_args=args)
|
218 |
+
accelerator.init_trackers(project_name=args.project_name,
|
219 |
+
init_kwargs={"wandb": {
|
220 |
+
"config": config_record,
|
221 |
+
"name": args.expr_name,
|
222 |
+
"dir": args.output_dir
|
223 |
+
}})
|
224 |
+
if args.resume_from_checkpoint is not None:
|
225 |
+
logger.info(f'Load checkpoint from {args.resume_from_checkpoint}')
|
226 |
+
accelerator.load_state(args.resume_from_checkpoint)
|
227 |
+
|
228 |
+
num_params = trainable_params(discrete_model)
|
229 |
+
logger.info("***** Running training *****")
|
230 |
+
logger.info(f" Total optimization steps = {args.max_steps}")
|
231 |
+
logger.info(f" Total trainable params = {num_params}")
|
232 |
+
# Only show the progress bar once on each machine.
|
233 |
+
progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_main_process)
|
234 |
+
progress_bar.set_description("Steps")
|
235 |
+
global_step = 0
|
236 |
+
if args.resume_steps is not None:
|
237 |
+
global_step = args.resume_steps
|
238 |
+
progress_bar.update(args.resume_steps)
|
239 |
+
|
240 |
+
train_dataloader = build_dataloader(dataset_cfg=train_dataset_cfg,
|
241 |
+
image_transform=image_transform,
|
242 |
+
tokenizer=tokenizer,
|
243 |
+
dataloader_num_workers=args.dataloader_num_workers)
|
244 |
+
for epoch in range(args.num_train_epochs):
|
245 |
+
discrete_model.train()
|
246 |
+
logger.info('Start new epoch')
|
247 |
+
|
248 |
+
for step, batch in enumerate(train_dataloader):
|
249 |
+
with accelerator.accumulate(discrete_model):
|
250 |
+
with torch.no_grad():
|
251 |
+
image_embeds = visual_encoder(batch['images'].to(accelerator.device, dtype=weight_dtype))
|
252 |
+
if text_encoder is not None:
|
253 |
+
text_embeds = text_encoder(batch['text_input_ids'].to(accelerator.device))
|
254 |
+
else:
|
255 |
+
text_embeds = None
|
256 |
+
|
257 |
+
output = discrete_model(image_embeds=image_embeds, text_embeds=text_embeds)
|
258 |
+
|
259 |
+
loss = output['total_loss']
|
260 |
+
accelerator.backward(loss)
|
261 |
+
if accelerator.sync_gradients:
|
262 |
+
accelerator.clip_grad_norm_(discrete_model.parameters(), max_norm=args.max_grad_norm)
|
263 |
+
optimizer.step()
|
264 |
+
scheduler.step()
|
265 |
+
optimizer.zero_grad()
|
266 |
+
|
267 |
+
if accelerator.sync_gradients:
|
268 |
+
progress_bar.update(1)
|
269 |
+
global_step += 1
|
270 |
+
|
271 |
+
if global_step % args.save_steps == 0:
|
272 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
273 |
+
accelerator.save_state(save_path)
|
274 |
+
|
275 |
+
metric = get_metric(output)
|
276 |
+
metric['lr'] = optimizer.param_groups[0]['lr']
|
277 |
+
metric['code_usage'] = get_code_usage(output['indices'])
|
278 |
+
metric = {key: (format(value, ".6f") if isinstance(value, float) else value) for key, value in
|
279 |
+
metric.items()}
|
280 |
+
accelerator.log(metric, step=global_step)
|
281 |
+
if accelerator.is_main_process:
|
282 |
+
tqdm.write(str(metric))
|
283 |
+
# print(metric)
|
284 |
+
if global_step >= args.max_steps:
|
285 |
+
break
|
286 |
+
|
287 |
+
accelerator.end_training()
|
288 |
+
|
289 |
+
|
290 |
+
if __name__ == '__main__':
|
291 |
+
train()
|
src/train/train_clm_sft.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
|
2 |
+
import hydra
|
3 |
+
|
4 |
+
import pyrootutils
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
from accelerate import Accelerator
|
8 |
+
from accelerate.logging import get_logger
|
9 |
+
from accelerate.utils import ProjectConfiguration
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
|
12 |
+
from deepspeed.runtime.engine import DummyOptim
|
13 |
+
from tqdm.auto import tqdm
|
14 |
+
from omegaconf import OmegaConf
|
15 |
+
from omegaconf.dictconfig import DictConfig
|
16 |
+
import argparse
|
17 |
+
from flask import Flask, request
|
18 |
+
from typing import List, Union
|
19 |
+
import json
|
20 |
+
from typing import Optional
|
21 |
+
import transformers
|
22 |
+
from dataclasses import dataclass, field, asdict, is_dataclass
|
23 |
+
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, DistributedReadingService, \
|
24 |
+
SequentialReadingService
|
25 |
+
import gc
|
26 |
+
import logging
|
27 |
+
from accelerate import FullyShardedDataParallelPlugin, DistributedDataParallelKwargs
|
28 |
+
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
|
29 |
+
|
30 |
+
pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
|
31 |
+
from src.train.schedular import get_scheduler
|
32 |
+
from src.train.dist_utils import all_gather
|
33 |
+
|
34 |
+
# logger = get_logger(__name__, log_level='info')
|
35 |
+
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
36 |
+
logging.basicConfig(level=logging.INFO, format=log_format)
|
37 |
+
|
38 |
+
logger = logging.getLogger(__name__)
|
39 |
+
os.environ["WANDB_MODE"] = "offline"
|
40 |
+
|
41 |
+
|
42 |
+
@dataclass
|
43 |
+
class ConfigPathArguments:
|
44 |
+
image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
|
45 |
+
tokenizer: Optional[str] = field(default=None,
|
46 |
+
metadata={"help": "config path of tokenizer used to initialize tokenizer"})
|
47 |
+
# model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
|
48 |
+
visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
|
49 |
+
llm_model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
|
50 |
+
agent_model: Optional[str] = field(default=None, metadata={"help": "config path of agent"})
|
51 |
+
train_dataset: Optional[str] = field(default=None, metadata={"help": "config path of training dataset"})
|
52 |
+
fsdp_plugin: Optional[str] = field(default=None, metadata={"help": "config path of fsdp plugin"})
|
53 |
+
deepspeed_plugin: Optional[str] = field(default=None, metadata={"help": "config path of deepspeed plugin"})
|
54 |
+
|
55 |
+
|
56 |
+
@dataclass
|
57 |
+
class TrainingArguments:
|
58 |
+
output_dir: str = field(
|
59 |
+
metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, )
|
60 |
+
resume_from_checkpoint: Optional[str] = field(
|
61 |
+
default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."})
|
62 |
+
resume_steps: Optional[int] = field(default=None, metadata={"help": "The training sterps of saved checkpoint"})
|
63 |
+
batch_size: Optional[int] = field(default=60, metadata={"help": "The training batch size"})
|
64 |
+
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
65 |
+
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
66 |
+
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
67 |
+
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
68 |
+
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
69 |
+
max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
|
70 |
+
gradient_accumulation_steps: int = field(
|
71 |
+
default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."})
|
72 |
+
mixed_precision: Optional[str] = field(
|
73 |
+
default='no',
|
74 |
+
metadata={
|
75 |
+
"help":
|
76 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=1.10.and an Nvidia Ampere GPU."
|
77 |
+
})
|
78 |
+
num_train_epochs: int = field(default=3, metadata={"help": "Total number of training epochs to perform."})
|
79 |
+
max_steps: int = field(default=-1, metadata={"help": "Total number of training steps to perform. "})
|
80 |
+
save_steps: int = field(default=10000, metadata={"help": "Number of updates steps before two checkpoint saves."})
|
81 |
+
lr_scheduler_type: str = field(default="cosine", metadata={"help": "The scheduler type to use."})
|
82 |
+
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
83 |
+
min_lr_ratio: float = field(default=0.01, metadata={"help": "Minimal learning rate ratio."})
|
84 |
+
dataloader_num_workers: int = field(default=8, metadata={"help": "The number of workers to use for data loading."})
|
85 |
+
project_name: str = field(default="ContinuousVLM", metadata={"help": "The name of experiment"})
|
86 |
+
expr_name: str = field(default="", metadata={"help": "The name of experiment"})
|
87 |
+
|
88 |
+
|
89 |
+
def build_dataloader(dataset_cfg, image_transform, tokenizer, batch_size, dataloader_num_workers=4):
|
90 |
+
dataset = hydra.utils.instantiate(dataset_cfg, image_transform=image_transform, tokenizer=tokenizer)
|
91 |
+
mp_service = MultiProcessingReadingService(num_workers=dataloader_num_workers)
|
92 |
+
dist_service = DistributedReadingService()
|
93 |
+
reading_service = SequentialReadingService(dist_service, mp_service)
|
94 |
+
dataloader = DataLoader2(dataset, reading_service=reading_service)
|
95 |
+
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=dataloader_num_workers)
|
96 |
+
return dataloader
|
97 |
+
|
98 |
+
|
99 |
+
def get_metric(output):
|
100 |
+
metric = {}
|
101 |
+
for key, value in output.items():
|
102 |
+
if 'loss' in key:
|
103 |
+
gathered_metric = torch.stack(all_gather(value)).mean()
|
104 |
+
# metric[key] = value.item()
|
105 |
+
metric[key] = gathered_metric.item()
|
106 |
+
if 'acc' in key:
|
107 |
+
metric[key] = value.item()
|
108 |
+
return metric
|
109 |
+
|
110 |
+
|
111 |
+
def merge_config(**kwargs):
|
112 |
+
config = {}
|
113 |
+
for key, value in kwargs.items():
|
114 |
+
if isinstance(value, argparse.Namespace):
|
115 |
+
config[key] = vars(value)
|
116 |
+
elif isinstance(value, DictConfig):
|
117 |
+
config[key] = OmegaConf.to_object(value)
|
118 |
+
elif is_dataclass(value):
|
119 |
+
config[key] = asdict(value)
|
120 |
+
elif isinstance(value, (int, str, float, dict)) or value is None:
|
121 |
+
config[key] = value
|
122 |
+
else:
|
123 |
+
logger.error(f'key: {key}, value: {value} will not be merged.')
|
124 |
+
return config
|
125 |
+
|
126 |
+
|
127 |
+
def trainable_params(model):
|
128 |
+
count = 0
|
129 |
+
for name, param in model.named_parameters():
|
130 |
+
if param.requires_grad:
|
131 |
+
count += param.numel()
|
132 |
+
return count
|
133 |
+
|
134 |
+
|
135 |
+
def train():
|
136 |
+
parser = transformers.HfArgumentParser((ConfigPathArguments, TrainingArguments))
|
137 |
+
cfg_path, args = parser.parse_args_into_dataclasses()
|
138 |
+
|
139 |
+
project_config = ProjectConfiguration(project_dir=args.output_dir,
|
140 |
+
logging_dir=os.path.join(args.output_dir, 'logs'))
|
141 |
+
|
142 |
+
assert int(cfg_path.fsdp_plugin is not None) + int(cfg_path.deepspeed_plugin is not None) <= 1
|
143 |
+
if cfg_path.fsdp_plugin is not None:
|
144 |
+
fsdp_plugin_cfg = OmegaConf.load(cfg_path.fsdp_plugin)
|
145 |
+
fsdp_plugin = hydra.utils.instantiate(fsdp_plugin_cfg)
|
146 |
+
logger.info('Use FSDP plugin')
|
147 |
+
else:
|
148 |
+
fsdp_plugin = None
|
149 |
+
|
150 |
+
if cfg_path.deepspeed_plugin is not None:
|
151 |
+
deepspeed_plugin_cfg = OmegaConf.load(cfg_path.deepspeed_plugin)
|
152 |
+
deepspeed_plugin = hydra.utils.instantiate(deepspeed_plugin_cfg)
|
153 |
+
logger.info('Use deepspeed plugin')
|
154 |
+
else:
|
155 |
+
deepspeed_plugin = None
|
156 |
+
|
157 |
+
# ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
158 |
+
accelerator = Accelerator(
|
159 |
+
mixed_precision=args.mixed_precision,
|
160 |
+
log_with=['tensorboard', 'wandb'],
|
161 |
+
project_config=project_config,
|
162 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
163 |
+
step_scheduler_with_optimizer=False,
|
164 |
+
fsdp_plugin=fsdp_plugin,
|
165 |
+
deepspeed_plugin=deepspeed_plugin,
|
166 |
+
# kwargs_handlers=[ddp_kwargs],
|
167 |
+
)
|
168 |
+
accelerator.wait_for_everyone()
|
169 |
+
logger.info('Init accelerator done.')
|
170 |
+
|
171 |
+
if cfg_path.deepspeed_plugin is not None:
|
172 |
+
accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 8
|
173 |
+
|
174 |
+
# print('deepspeed config: ', accelerator.state.deepspeed_plugin.deepspeed_config)
|
175 |
+
|
176 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
177 |
+
|
178 |
+
# if cfg_path.image_transform is not None:
|
179 |
+
image_transform_cfg = OmegaConf.load(cfg_path.image_transform)
|
180 |
+
image_transform = hydra.utils.instantiate(image_transform_cfg)
|
181 |
+
# else:
|
182 |
+
# image_transform_cfg = None
|
183 |
+
# image_transform = None
|
184 |
+
|
185 |
+
# if cfg_path.tokenizer is not None:
|
186 |
+
tokenizer_cfg = OmegaConf.load(cfg_path.tokenizer)
|
187 |
+
tokenizer = hydra.utils.instantiate(tokenizer_cfg)
|
188 |
+
# else:
|
189 |
+
# tokenizer_cfg = None
|
190 |
+
# tokenizer = None
|
191 |
+
train_dataset_cfg = OmegaConf.load(cfg_path.train_dataset)
|
192 |
+
|
193 |
+
visual_encoder_cfg = OmegaConf.load(cfg_path.visual_encoder)
|
194 |
+
visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
|
195 |
+
logger.info('Load visual encoder done.')
|
196 |
+
|
197 |
+
llm_model_cfg = OmegaConf.load(cfg_path.llm_model)
|
198 |
+
llm_model = hydra.utils.instantiate(llm_model_cfg)
|
199 |
+
llm_model.gradient_checkpointing_enable()
|
200 |
+
llm_model.config.use_cache = False
|
201 |
+
logger.info('Load llm model done.')
|
202 |
+
|
203 |
+
agent_model_cfg = OmegaConf.load(cfg_path.agent_model)
|
204 |
+
agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm_model)
|
205 |
+
logger.info('Load agent model done.')
|
206 |
+
|
207 |
+
weight_dtype = torch.float32
|
208 |
+
if accelerator.mixed_precision == "fp16":
|
209 |
+
weight_dtype = torch.float16
|
210 |
+
elif accelerator.mixed_precision == "bf16":
|
211 |
+
weight_dtype = torch.bfloat16
|
212 |
+
|
213 |
+
visual_encoder.to(accelerator.device, dtype=weight_dtype)
|
214 |
+
logger.info('Freeze visual encoder...')
|
215 |
+
visual_encoder.requires_grad_(False)
|
216 |
+
|
217 |
+
if cfg_path.fsdp_plugin is not None:
|
218 |
+
agent_model = accelerator.prepare(agent_model)
|
219 |
+
|
220 |
+
optimizer = torch.optim.AdamW(agent_model.parameters(),
|
221 |
+
lr=args.learning_rate,
|
222 |
+
betas=[args.adam_beta1, args.adam_beta2],
|
223 |
+
eps=args.adam_epsilon,
|
224 |
+
weight_decay=args.weight_decay)
|
225 |
+
logger.info('Init optimizer done.')
|
226 |
+
scheduler = get_scheduler(name=args.lr_scheduler_type,
|
227 |
+
optimizer=optimizer,
|
228 |
+
num_warmup_steps=args.warmup_steps,
|
229 |
+
num_training_steps=args.max_steps,
|
230 |
+
min_lr_ratio=args.min_lr_ratio)
|
231 |
+
# accelerator.register_for_checkpointing(scheduler)
|
232 |
+
train_dataloader = build_dataloader(dataset_cfg=train_dataset_cfg,
|
233 |
+
image_transform=image_transform,
|
234 |
+
tokenizer=tokenizer,
|
235 |
+
batch_size=args.batch_size,
|
236 |
+
dataloader_num_workers=args.dataloader_num_workers)
|
237 |
+
if cfg_path.fsdp_plugin is not None:
|
238 |
+
optimizer, scheduler = accelerator.prepare(optimizer, scheduler)
|
239 |
+
else:
|
240 |
+
agent_model, optimizer, scheduler = accelerator.prepare(agent_model, optimizer, scheduler)
|
241 |
+
logger.info('Prepare accelerator done.')
|
242 |
+
|
243 |
+
config_record = merge_config(agent_model=agent_model_cfg,
|
244 |
+
llm_model=llm_model,
|
245 |
+
visual_encoder=visual_encoder_cfg,
|
246 |
+
image_transform=image_transform_cfg,
|
247 |
+
tokenizer=tokenizer_cfg,
|
248 |
+
train_dataset=train_dataset_cfg,
|
249 |
+
train_args=args)
|
250 |
+
accelerator.init_trackers(project_name=args.project_name,
|
251 |
+
init_kwargs={"wandb": {
|
252 |
+
"config": config_record,
|
253 |
+
"name": args.expr_name,
|
254 |
+
"dir": args.output_dir
|
255 |
+
}})
|
256 |
+
if args.resume_from_checkpoint is not None:
|
257 |
+
logger.info(f'Load checkpoint from {args.resume_from_checkpoint}')
|
258 |
+
accelerator.load_state(args.resume_from_checkpoint)
|
259 |
+
torch.cuda.empty_cache()
|
260 |
+
gc.collect()
|
261 |
+
|
262 |
+
num_params = trainable_params(agent_model)
|
263 |
+
logger.info("***** Running training *****")
|
264 |
+
logger.info(f" Total optimization steps = {args.max_steps}")
|
265 |
+
logger.info(f" Total trainable params = {num_params}")
|
266 |
+
# Only show the progress bar once on each machine.
|
267 |
+
progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_main_process)
|
268 |
+
progress_bar.set_description("Steps")
|
269 |
+
global_step = 0
|
270 |
+
if args.resume_steps is not None:
|
271 |
+
global_step = args.resume_steps
|
272 |
+
progress_bar.update(args.resume_steps)
|
273 |
+
|
274 |
+
for epoch in range(args.num_train_epochs):
|
275 |
+
agent_model.train()
|
276 |
+
logger.info('Start new epoch')
|
277 |
+
|
278 |
+
for step, batch in enumerate(train_dataloader):
|
279 |
+
with accelerator.accumulate(agent_model):
|
280 |
+
# accelerator.wait_for_everyone()
|
281 |
+
# print('1')
|
282 |
+
with torch.no_grad():
|
283 |
+
if batch['images'] is not None:
|
284 |
+
image_embeds = visual_encoder(batch['images'].to(accelerator.device, dtype=weight_dtype))
|
285 |
+
# image_embeds = visual_encoder(batch['images'])
|
286 |
+
else:
|
287 |
+
image_embeds = None
|
288 |
+
# accelerator.wait_for_everyone()
|
289 |
+
# print('2')
|
290 |
+
output = agent_model(input_ids=batch['input_ids'].to(accelerator.device),
|
291 |
+
attention_mask=batch['attention_mask'].to(accelerator.device),
|
292 |
+
labels=batch['labels'].to(accelerator.device),
|
293 |
+
image_embeds=image_embeds,
|
294 |
+
embeds_gen_mask=batch['embeds_gen_mask'].to(accelerator.device)
|
295 |
+
if batch['embeds_gen_mask'] is not None else None,
|
296 |
+
embeds_cmp_mask=batch['embeds_cmp_mask'].to(accelerator.device)
|
297 |
+
if batch['embeds_cmp_mask'] is not None else None,
|
298 |
+
ids_gen_mask=batch['ids_gen_mask'].to(accelerator.device),
|
299 |
+
ids_cmp_mask=batch['ids_cmp_mask'].to(accelerator.device))
|
300 |
+
# output = agent_model(
|
301 |
+
# input_ids=batch['input_ids'], #.squeeze(0),
|
302 |
+
# attention_mask=batch['attention_mask'], # .squeeze(0),
|
303 |
+
# labels=batch['labels'], # .squeeze(0),
|
304 |
+
# image_embeds=image_embeds,
|
305 |
+
# embeds_gen_mask=batch['embeds_gen_mask'], #.squeeze(0),
|
306 |
+
# embeds_cmp_mask=batch['embeds_cmp_mask'], #.squeeze(0),
|
307 |
+
# ids_gen_mask=batch['ids_gen_mask'], #.squeeze(0),
|
308 |
+
# ids_cmp_mask=batch['ids_cmp_mask']) #.squeeze(0))
|
309 |
+
loss = output['total_loss']
|
310 |
+
# accelerator.wait_for_everyone()
|
311 |
+
# print('3')
|
312 |
+
accelerator.backward(loss)
|
313 |
+
# accelerator.wait_for_everyone()
|
314 |
+
# print('4')
|
315 |
+
if accelerator.sync_gradients:
|
316 |
+
accelerator.clip_grad_norm_(agent_model.parameters(), max_norm=args.max_grad_norm)
|
317 |
+
|
318 |
+
optimizer.step()
|
319 |
+
scheduler.step()
|
320 |
+
optimizer.zero_grad()
|
321 |
+
# accelerator.wait_for_everyone()
|
322 |
+
# print('5')
|
323 |
+
|
324 |
+
if accelerator.sync_gradients:
|
325 |
+
progress_bar.update(1)
|
326 |
+
global_step += 1
|
327 |
+
|
328 |
+
if global_step % args.save_steps == 0:
|
329 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
330 |
+
accelerator.save_state(save_path)
|
331 |
+
|
332 |
+
metric = get_metric(output)
|
333 |
+
metric['lr'] = optimizer.param_groups[0]['lr']
|
334 |
+
accelerator.log(metric, step=global_step)
|
335 |
+
metric = {key: (format(value, ".6f") if isinstance(value, float) else value) for key, value in
|
336 |
+
metric.items()}
|
337 |
+
if accelerator.is_main_process:
|
338 |
+
tqdm.write(str(metric))
|
339 |
+
# print(metric)
|
340 |
+
if global_step >= args.max_steps:
|
341 |
+
break
|
342 |
+
|
343 |
+
accelerator.end_training()
|
344 |
+
|
345 |
+
|
346 |
+
if __name__ == '__main__':
|
347 |
+
train()
|
src/train/train_sdxl_img2img_llm.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
|
2 |
+
import hydra
|
3 |
+
|
4 |
+
import pyrootutils
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
from accelerate import Accelerator
|
8 |
+
from accelerate.logging import get_logger
|
9 |
+
from accelerate.utils import ProjectConfiguration
|
10 |
+
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from omegaconf.dictconfig import DictConfig
|
14 |
+
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler, \
|
15 |
+
Transformer2DModel
|
16 |
+
|
17 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
18 |
+
import argparse
|
19 |
+
from flask import Flask, request
|
20 |
+
from typing import List, Union
|
21 |
+
import json
|
22 |
+
from typing import Optional
|
23 |
+
import transformers
|
24 |
+
from dataclasses import dataclass, field, asdict, is_dataclass
|
25 |
+
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, DistributedReadingService, \
|
26 |
+
SequentialReadingService
|
27 |
+
import logging
|
28 |
+
|
29 |
+
pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)
|
30 |
+
from src.train.schedular import get_scheduler
|
31 |
+
from src.train.dist_utils import all_gather
|
32 |
+
|
33 |
+
# logger = get_logger(__name__, log_level='info')
|
34 |
+
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
35 |
+
logging.basicConfig(level=logging.INFO, format=log_format)
|
36 |
+
|
37 |
+
logger = logging.getLogger(__name__)
|
38 |
+
|
39 |
+
|
40 |
+
# os.environ["WANDB_MODE"] = "offline"
|
41 |
+
|
42 |
+
|
43 |
+
@dataclass
|
44 |
+
class ConfigPathArguments:
|
45 |
+
image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
|
46 |
+
sd_image_transform: Optional[str] = field(default=None,
|
47 |
+
metadata={"help": "config path of stable diffusion image transform"})
|
48 |
+
# tokenizer: Optional[str] = field(default=None, metadata={"help": "config path of tokenizer used to initialize tokenizer"})
|
49 |
+
visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
|
50 |
+
# text_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
|
51 |
+
discrete_model: Optional[str] = field(default=None, metadata={"help": "config path of discrete model"})
|
52 |
+
# noise_scheduler: Optional[str] = field(default=None, metadata={"help": "config path of noise scheduler"})
|
53 |
+
# vae: Optional[str] = field(default=None, metadata={"help": "config path of vae"})
|
54 |
+
adapter: Optional[str] = field(default=None, metadata={"help": "config path of adapter"})
|
55 |
+
train_dataset: Optional[str] = field(default=None, metadata={"help": "config path of training dataset"})
|
56 |
+
fsdp_plugin: Optional[str] = field(default=None, metadata={"help": "config path of fsdp plugin"})
|
57 |
+
deepspeed_plugin: Optional[str] = field(default=None, metadata={"help": "config path of deepspeed plugin"})
|
58 |
+
tokenizer: Optional[str] = field(default=None,
|
59 |
+
metadata={"help": "config path of tokenizer used to initialize tokenizer"})
|
60 |
+
llm_model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
|
61 |
+
agent_model: Optional[str] = field(default=None, metadata={"help": "config path of agent"})
|
62 |
+
|
63 |
+
|
64 |
+
@dataclass
|
65 |
+
class TrainingArguments:
|
66 |
+
output_dir: str = field(
|
67 |
+
metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, )
|
68 |
+
diffusion_model_path: Optional[str] = field(default=None, metadata={"help": "config path of training dataset"})
|
69 |
+
resume_from_checkpoint: Optional[str] = field(
|
70 |
+
default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."})
|
71 |
+
resume_steps: Optional[int] = field(default=None, metadata={"help": "The training sterps of saved checkpoint"})
|
72 |
+
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
|
73 |
+
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
|
74 |
+
# adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
|
75 |
+
# adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
|
76 |
+
# adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
|
77 |
+
max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
|
78 |
+
gradient_accumulation_steps: int = field(
|
79 |
+
default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."})
|
80 |
+
mixed_precision: Optional[str] = field(
|
81 |
+
default='no',
|
82 |
+
metadata={
|
83 |
+
"help":
|
84 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=1.10.and an Nvidia Ampere GPU."
|
85 |
+
})
|
86 |
+
num_train_epochs: int = field(default=3, metadata={"help": "Total number of training epochs to perform."})
|
87 |
+
max_steps: int = field(default=-1, metadata={"help": "Total number of training steps to perform. "})
|
88 |
+
save_steps: int = field(default=10000, metadata={"help": "Number of updates steps before two checkpoint saves."})
|
89 |
+
lr_scheduler_type: str = field(default="cosine", metadata={"help": "The scheduler type to use."})
|
90 |
+
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
91 |
+
min_lr_ratio: float = field(default=0.01, metadata={"help": "Minimal learning rate ratio."})
|
92 |
+
dataloader_num_workers: int = field(default=8, metadata={"help": "The number of workers to use for data loading."})
|
93 |
+
project_name: str = field(default="IPAdapter", metadata={"help": "The name of experiment"})
|
94 |
+
expr_name: str = field(default="", metadata={"help": "The name of experiment"})
|
95 |
+
|
96 |
+
|
97 |
+
def build_dataloader(dataset_cfg, image_transform, sd_image_transform, tokenizer, dataloader_num_workers=4):
|
98 |
+
dataset = hydra.utils.instantiate(dataset_cfg,
|
99 |
+
image_transform=image_transform,
|
100 |
+
sd_image_transform=sd_image_transform,
|
101 |
+
tokenizer=tokenizer)
|
102 |
+
mp_service = MultiProcessingReadingService(num_workers=dataloader_num_workers)
|
103 |
+
dist_service = DistributedReadingService()
|
104 |
+
reading_service = SequentialReadingService(dist_service, mp_service)
|
105 |
+
dataloader = DataLoader2(dataset, reading_service=reading_service)
|
106 |
+
return dataloader
|
107 |
+
|
108 |
+
|
109 |
+
def get_metric(output):
|
110 |
+
metric = {}
|
111 |
+
for key, value in output.items():
|
112 |
+
if 'loss' in key:
|
113 |
+
metric[key] = value.item()
|
114 |
+
return metric
|
115 |
+
|
116 |
+
|
117 |
+
def merge_config(**kwargs):
|
118 |
+
config = {}
|
119 |
+
for key, value in kwargs.items():
|
120 |
+
if isinstance(value, argparse.Namespace):
|
121 |
+
config[key] = vars(value)
|
122 |
+
elif isinstance(value, DictConfig):
|
123 |
+
config[key] = OmegaConf.to_object(value)
|
124 |
+
elif is_dataclass(value):
|
125 |
+
config[key] = asdict(value)
|
126 |
+
elif isinstance(value, dict):
|
127 |
+
config[key] = value
|
128 |
+
else:
|
129 |
+
logger.error(f'key: {key}, value: {value} will not be merged.')
|
130 |
+
return config
|
131 |
+
|
132 |
+
|
133 |
+
def trainable_params(model):
|
134 |
+
count = 0
|
135 |
+
for name, param in model.named_parameters():
|
136 |
+
if param.requires_grad:
|
137 |
+
count += param.numel()
|
138 |
+
return count
|
139 |
+
|
140 |
+
|
141 |
+
def train():
|
142 |
+
parser = transformers.HfArgumentParser((ConfigPathArguments, TrainingArguments))
|
143 |
+
cfg_path, args = parser.parse_args_into_dataclasses()
|
144 |
+
|
145 |
+
project_config = ProjectConfiguration(project_dir=args.output_dir,
|
146 |
+
logging_dir=os.path.join(args.output_dir, 'logs'))
|
147 |
+
|
148 |
+
assert int(cfg_path.fsdp_plugin is not None) + int(cfg_path.deepspeed_plugin is not None) <= 1
|
149 |
+
if cfg_path.fsdp_plugin is not None:
|
150 |
+
fsdp_plugin_cfg = OmegaConf.load(cfg_path.fsdp_plugin)
|
151 |
+
fsdp_plugin = hydra.utils.instantiate(fsdp_plugin_cfg)
|
152 |
+
logger.info('Use FSDP plugin')
|
153 |
+
else:
|
154 |
+
fsdp_plugin = None
|
155 |
+
|
156 |
+
if cfg_path.deepspeed_plugin is not None:
|
157 |
+
deepspeed_plugin_cfg = OmegaConf.load(cfg_path.deepspeed_plugin)
|
158 |
+
deepspeed_plugin = hydra.utils.instantiate(deepspeed_plugin_cfg)
|
159 |
+
logger.info('Use deepspeed plugin')
|
160 |
+
else:
|
161 |
+
deepspeed_plugin = None
|
162 |
+
|
163 |
+
accelerator = Accelerator(
|
164 |
+
mixed_precision=args.mixed_precision,
|
165 |
+
log_with=['tensorboard', 'wandb'],
|
166 |
+
project_config=project_config,
|
167 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
168 |
+
step_scheduler_with_optimizer=False,
|
169 |
+
fsdp_plugin=fsdp_plugin,
|
170 |
+
deepspeed_plugin=deepspeed_plugin,
|
171 |
+
)
|
172 |
+
logger.info('Init accelerator done.')
|
173 |
+
|
174 |
+
if cfg_path.deepspeed_plugin is not None:
|
175 |
+
accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 100
|
176 |
+
|
177 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
178 |
+
|
179 |
+
image_transform_cfg = OmegaConf.load(cfg_path.image_transform)
|
180 |
+
image_transform = hydra.utils.instantiate(image_transform_cfg)
|
181 |
+
sd_image_transform_cfg = OmegaConf.load(cfg_path.sd_image_transform)
|
182 |
+
sd_image_transform = hydra.utils.instantiate(sd_image_transform_cfg)
|
183 |
+
|
184 |
+
tokenizer_cfg = OmegaConf.load(cfg_path.tokenizer)
|
185 |
+
tokenizer = hydra.utils.instantiate(tokenizer_cfg)
|
186 |
+
|
187 |
+
visual_encoder_cfg = OmegaConf.load(cfg_path.visual_encoder)
|
188 |
+
visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
|
189 |
+
logger.info('Load visual encoder done.')
|
190 |
+
|
191 |
+
discrete_model_cfg = OmegaConf.load(cfg_path.discrete_model)
|
192 |
+
discrete_model = hydra.utils.instantiate(discrete_model_cfg)
|
193 |
+
logger.info('Load discrete model done.')
|
194 |
+
|
195 |
+
# noise_scheduler_cfg = OmegaConf.load(cfg_path.noise_scheduler)
|
196 |
+
# noise_scheduler = hydra.utils.instantiate(noise_scheduler_cfg)
|
197 |
+
|
198 |
+
# if cfg_path.tokenizer is not None:
|
199 |
+
# tokenizer_cfg = OmegaConf.load(cfg_path.tokenizer)
|
200 |
+
# tokenizer = hydra.utils.instantiate(tokenizer_cfg)
|
201 |
+
# else:
|
202 |
+
# tokenizer_cfg = None
|
203 |
+
# tokenizer = None
|
204 |
+
|
205 |
+
# if cfg_path.text_encoder is not None:
|
206 |
+
# text_encoder_cfg = OmegaConf.load(cfg_path.text_encoder)
|
207 |
+
# text_encoder = hydra.utils.instantiate(text_encoder_cfg)
|
208 |
+
# logger.info('Load text encoder done.')
|
209 |
+
# else:
|
210 |
+
# text_encoder_cfg = None
|
211 |
+
# text_encoder = None
|
212 |
+
|
213 |
+
# vae_cfg = OmegaConf.load(cfg_path.vae)
|
214 |
+
# vae = hydra.utils.instantiate(vae_cfg)
|
215 |
+
# logger.info('Load vae done.')
|
216 |
+
|
217 |
+
# noise_scheduler = DDPMScheduler.from_pretrained(args.diffusion_model_path, subfolder="scheduler")
|
218 |
+
# tokenizer = CLIPTokenizer.from_pretrained(args.diffusion_model_path, subfolder="tokenizer")
|
219 |
+
# text_encoder = CLIPTextModel.from_pretrained(args.diffusion_model_path, subfolder="text_encoder")
|
220 |
+
# vae = AutoencoderKL.from_pretrained(args.diffusion_model_path, subfolder="vae")
|
221 |
+
# unet = UNet2DConditionModel.from_pretrained(args.diffusion_model_path, subfolder="unet")
|
222 |
+
# print('load diffusion model done')
|
223 |
+
|
224 |
+
# noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(args.diffusion_model_path, subfolder="scheduler")
|
225 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.diffusion_model_path, subfolder="scheduler")
|
226 |
+
text_encoder = None
|
227 |
+
vae = AutoencoderKL.from_pretrained(args.diffusion_model_path, subfolder="vae")
|
228 |
+
unet = UNet2DConditionModel.from_pretrained(args.diffusion_model_path, subfolder="unet")
|
229 |
+
|
230 |
+
unet.enable_xformers_memory_efficient_attention()
|
231 |
+
unet.enable_gradient_checkpointing()
|
232 |
+
|
233 |
+
vae.requires_grad_(False)
|
234 |
+
visual_encoder.requires_grad_(False)
|
235 |
+
discrete_model.requires_grad_(False)
|
236 |
+
|
237 |
+
adapter_cfg = OmegaConf.load(cfg_path.adapter)
|
238 |
+
adapter = hydra.utils.instantiate(adapter_cfg, unet=unet)
|
239 |
+
logger.info('Load adapter done.')
|
240 |
+
|
241 |
+
weight_dtype = torch.float32
|
242 |
+
if accelerator.mixed_precision == "fp16":
|
243 |
+
weight_dtype = torch.float16
|
244 |
+
elif accelerator.mixed_precision == "bf16":
|
245 |
+
weight_dtype = torch.bfloat16
|
246 |
+
|
247 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
248 |
+
visual_encoder.to(accelerator.device, dtype=weight_dtype)
|
249 |
+
discrete_model.to(accelerator.device, dtype=weight_dtype)
|
250 |
+
if text_encoder is not None:
|
251 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
252 |
+
|
253 |
+
train_dataset_cfg = OmegaConf.load(cfg_path.train_dataset)
|
254 |
+
train_dataloader = build_dataloader(dataset_cfg=train_dataset_cfg,
|
255 |
+
image_transform=image_transform,
|
256 |
+
sd_image_transform=sd_image_transform,
|
257 |
+
tokenizer=tokenizer,
|
258 |
+
dataloader_num_workers=args.dataloader_num_workers)
|
259 |
+
|
260 |
+
llm_model_cfg = OmegaConf.load(cfg_path.llm_model)
|
261 |
+
llm_model = hydra.utils.instantiate(llm_model_cfg)
|
262 |
+
llm_model.gradient_checkpointing_enable()
|
263 |
+
llm_model.config.use_cache = False
|
264 |
+
logger.info('Load llm model done.')
|
265 |
+
|
266 |
+
agent_model_cfg = OmegaConf.load(cfg_path.agent_model)
|
267 |
+
agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm_model).to(accelerator.device, dtype=weight_dtype)
|
268 |
+
agent_model.requires_grad_(False)
|
269 |
+
agent_model.llm.base_model.model.use_kv_cache_head = False
|
270 |
+
logger.info('Load agent model done.')
|
271 |
+
|
272 |
+
if cfg_path.fsdp_plugin is not None:
|
273 |
+
adapter = accelerator.prepare(adapter)
|
274 |
+
|
275 |
+
optimizer = torch.optim.AdamW(adapter.params_to_opt(), lr=args.learning_rate, weight_decay=args.weight_decay)
|
276 |
+
logger.info('Init optimizer done.')
|
277 |
+
scheduler = get_scheduler(name=args.lr_scheduler_type,
|
278 |
+
optimizer=optimizer,
|
279 |
+
num_warmup_steps=args.warmup_steps,
|
280 |
+
num_training_steps=args.max_steps,
|
281 |
+
min_lr_ratio=args.min_lr_ratio)
|
282 |
+
# accelerator.register_for_checkpointing(scheduler)
|
283 |
+
|
284 |
+
# adapter.adapter, adapter.resampler, optimizer, scheduler = accelerator.prepare(
|
285 |
+
# adapter.adapter,
|
286 |
+
# adapter.resampler,
|
287 |
+
# optimizer,
|
288 |
+
# scheduler,
|
289 |
+
# )
|
290 |
+
|
291 |
+
# adapter, optimizer, scheduler = accelerator.prepare(
|
292 |
+
# adapter,
|
293 |
+
# optimizer,
|
294 |
+
# scheduler,
|
295 |
+
# )
|
296 |
+
if cfg_path.fsdp_plugin is not None:
|
297 |
+
optimizer, scheduler = accelerator.prepare(optimizer, scheduler)
|
298 |
+
else:
|
299 |
+
adapter, optimizer, scheduler = accelerator.prepare(adapter, optimizer, scheduler)
|
300 |
+
logger.info('Prepare accelerator done.')
|
301 |
+
|
302 |
+
# config_record = merge_config(discrete_model=discrete_model_cfg,
|
303 |
+
# visual_encoder=visual_encoder_cfg,
|
304 |
+
# text_encoder=text_encoder_cfg,
|
305 |
+
# image_transform=image_transform_cfg,
|
306 |
+
# sd_image_transform=sd_image_transform_cfg,
|
307 |
+
# tokenizer=tokenizer_cfg,
|
308 |
+
# train_dataset=train_dataset_cfg,
|
309 |
+
# vae=vae_cfg,
|
310 |
+
# adapter=adapter_cfg,
|
311 |
+
# train_args=args)
|
312 |
+
config_record = merge_config(discrete_model=discrete_model_cfg,
|
313 |
+
visual_encoder=visual_encoder_cfg,
|
314 |
+
image_transform=image_transform_cfg,
|
315 |
+
sd_image_transform=sd_image_transform_cfg,
|
316 |
+
train_dataset=train_dataset_cfg,
|
317 |
+
adapter=adapter_cfg,
|
318 |
+
train_args=args,
|
319 |
+
agent_model=agent_model_cfg,
|
320 |
+
llm_model=llm_model,
|
321 |
+
tokenizer=tokenizer_cfg)
|
322 |
+
accelerator.init_trackers(project_name=args.project_name,
|
323 |
+
init_kwargs={"wandb": {
|
324 |
+
"config": config_record,
|
325 |
+
"name": args.expr_name,
|
326 |
+
"dir": args.output_dir
|
327 |
+
}})
|
328 |
+
if args.resume_from_checkpoint is not None:
|
329 |
+
logger.info(f'Load checkpoint from {args.resume_from_checkpoint}')
|
330 |
+
accelerator.load_state(args.resume_from_checkpoint)
|
331 |
+
|
332 |
+
num_params = trainable_params(adapter)
|
333 |
+
logger.info("***** Running training *****")
|
334 |
+
logger.info(f" Total optimization steps = {args.max_steps}")
|
335 |
+
logger.info(f" Total trainable params = {num_params}")
|
336 |
+
for name, param in adapter.named_parameters():
|
337 |
+
if param.requires_grad:
|
338 |
+
print(name)
|
339 |
+
# print(f'adapter: {trainable_params(adapter.adapter)}')
|
340 |
+
# print(f'resampler: {trainable_params(adapter.resampler)}')
|
341 |
+
# Only show the progress bar once on each machine.
|
342 |
+
progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_main_process)
|
343 |
+
progress_bar.set_description("Steps")
|
344 |
+
global_step = 0
|
345 |
+
if args.resume_steps is not None:
|
346 |
+
global_step = args.resume_steps
|
347 |
+
progress_bar.update(args.resume_steps)
|
348 |
+
|
349 |
+
for epoch in range(args.num_train_epochs):
|
350 |
+
logger.info('Start new epoch')
|
351 |
+
for step, batch in enumerate(train_dataloader):
|
352 |
+
with accelerator.accumulate(adapter):
|
353 |
+
with torch.no_grad():
|
354 |
+
image_embeds = visual_encoder(batch['images'].to(accelerator.device, dtype=weight_dtype))
|
355 |
+
image_embeds = discrete_model.encode_image_embeds(image_embeds)
|
356 |
+
if text_encoder is not None:
|
357 |
+
text_embeds = text_encoder(batch['text_input_ids'].to(accelerator.device))[0]
|
358 |
+
else:
|
359 |
+
text_embeds = None
|
360 |
+
latents = vae.encode(
|
361 |
+
batch["sd_images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
|
362 |
+
latents = latents * vae.config.scaling_factor
|
363 |
+
llm_output = agent_model(input_ids=batch['input_ids'].to(accelerator.device),
|
364 |
+
attention_mask=batch['attention_mask'].to(accelerator.device),
|
365 |
+
labels=batch['labels'].to(accelerator.device),
|
366 |
+
image_embeds=image_embeds,
|
367 |
+
embeds_gen_mask=batch['embeds_gen_mask'].to(accelerator.device)
|
368 |
+
if batch['embeds_gen_mask'] is not None else None,
|
369 |
+
embeds_cmp_mask=batch['embeds_cmp_mask'].to(accelerator.device)
|
370 |
+
if batch['embeds_cmp_mask'] is not None else None,
|
371 |
+
ids_gen_mask=batch['ids_gen_mask'].to(accelerator.device),
|
372 |
+
ids_cmp_mask=batch['ids_cmp_mask'].to(accelerator.device),
|
373 |
+
return_recon_image_embeds=True)
|
374 |
+
|
375 |
+
time_ids = batch['time_ids'].to(accelerator.device)
|
376 |
+
|
377 |
+
# Sample noise that we'll add to the latents
|
378 |
+
noise = torch.randn_like(latents)
|
379 |
+
bsz = latents.shape[0]
|
380 |
+
# Sample a random timestep for each image
|
381 |
+
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
|
382 |
+
timesteps = timesteps.long()
|
383 |
+
|
384 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
385 |
+
# (this is the forward diffusion process)
|
386 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
387 |
+
|
388 |
+
output = adapter(noisy_latents=noisy_latents,
|
389 |
+
timesteps=timesteps,
|
390 |
+
image_embeds=llm_output['recon_image_embeds'],
|
391 |
+
text_embeds=None,
|
392 |
+
noise=noise,
|
393 |
+
time_ids=time_ids)
|
394 |
+
|
395 |
+
loss = output['total_loss']
|
396 |
+
accelerator.backward(loss)
|
397 |
+
if accelerator.sync_gradients:
|
398 |
+
accelerator.clip_grad_norm_(adapter.parameters(), max_norm=args.max_grad_norm)
|
399 |
+
optimizer.step()
|
400 |
+
scheduler.step()
|
401 |
+
optimizer.zero_grad()
|
402 |
+
|
403 |
+
if accelerator.sync_gradients:
|
404 |
+
progress_bar.update(1)
|
405 |
+
global_step += 1
|
406 |
+
|
407 |
+
if global_step % args.save_steps == 0:
|
408 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
409 |
+
accelerator.save_state(save_path)
|
410 |
+
|
411 |
+
metric = get_metric(output)
|
412 |
+
metric['lr'] = optimizer.param_groups[0]['lr']
|
413 |
+
accelerator.log(metric, step=global_step)
|
414 |
+
metric = {key: (format(value, ".6f") if isinstance(value, float) else value) for key, value in
|
415 |
+
metric.items()}
|
416 |
+
|
417 |
+
# if accelerator.is_local_main_process:
|
418 |
+
if accelerator.is_main_process:
|
419 |
+
tqdm.write(str(metric))
|
420 |
+
# print(metric)
|
421 |
+
if global_step >= args.max_steps:
|
422 |
+
break
|
423 |
+
|
424 |
+
accelerator.end_training()
|
425 |
+
|
426 |
+
|
427 |
+
if __name__ == '__main__':
|
428 |
+
train()
|
utils.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import logging
|
3 |
+
import logging.handlers
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
handler = None
|
8 |
+
|
9 |
+
|
10 |
+
def build_logger(logger_name, logger_dir):
|
11 |
+
global handler
|
12 |
+
|
13 |
+
formatter = logging.Formatter(
|
14 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
15 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
16 |
+
)
|
17 |
+
|
18 |
+
# Set the format of root handlers
|
19 |
+
if not logging.getLogger().handlers:
|
20 |
+
logging.basicConfig(level=logging.INFO)
|
21 |
+
logging.getLogger().handlers[0].setFormatter(formatter)
|
22 |
+
|
23 |
+
# Redirect stdout and stderr to loggers
|
24 |
+
stdout_logger = logging.getLogger("stdout")
|
25 |
+
stdout_logger.setLevel(logging.INFO)
|
26 |
+
sl = StreamToLogger(stdout_logger, logging.INFO)
|
27 |
+
sys.stdout = sl
|
28 |
+
|
29 |
+
stderr_logger = logging.getLogger("stderr")
|
30 |
+
stderr_logger.setLevel(logging.ERROR)
|
31 |
+
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
32 |
+
sys.stderr = sl
|
33 |
+
|
34 |
+
# Get logger
|
35 |
+
logger = logging.getLogger(logger_name)
|
36 |
+
logger.setLevel(logging.INFO)
|
37 |
+
|
38 |
+
# Add a file handler for all loggers
|
39 |
+
if handler is None:
|
40 |
+
os.makedirs(logger_dir, exist_ok=True)
|
41 |
+
filename = os.path.join(logger_dir, logger_name + '.log')
|
42 |
+
handler = logging.handlers.TimedRotatingFileHandler(filename, when='D', utc=True)
|
43 |
+
handler.setFormatter(formatter)
|
44 |
+
|
45 |
+
for name, item in logging.root.manager.loggerDict.items():
|
46 |
+
if isinstance(item, logging.Logger):
|
47 |
+
item.addHandler(handler)
|
48 |
+
|
49 |
+
return logger
|
50 |
+
|
51 |
+
|
52 |
+
class StreamToLogger(object):
|
53 |
+
"""
|
54 |
+
Fake file-like stream object that redirects writes to a logger instance.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self, logger, log_level=logging.INFO):
|
58 |
+
self.terminal = sys.stdout
|
59 |
+
self.logger = logger
|
60 |
+
self.log_level = log_level
|
61 |
+
self.linebuf = ''
|
62 |
+
|
63 |
+
def __getattr__(self, attr):
|
64 |
+
return getattr(self.terminal, attr)
|
65 |
+
|
66 |
+
def write(self, buf):
|
67 |
+
temp_linebuf = self.linebuf + buf
|
68 |
+
self.linebuf = ''
|
69 |
+
for line in temp_linebuf.splitlines(True):
|
70 |
+
# From the io.TextIOWrapper docs:
|
71 |
+
# On output, if newline is None, any '\n' characters written
|
72 |
+
# are translated to the system default line separator.
|
73 |
+
# By default sys.stdout.write() expects '\n' newlines and then
|
74 |
+
# translates them so this is still cross platform.
|
75 |
+
if line[-1] == '\n':
|
76 |
+
self.logger.log(self.log_level, line.rstrip())
|
77 |
+
else:
|
78 |
+
self.linebuf += line
|
79 |
+
|
80 |
+
def flush(self):
|
81 |
+
if self.linebuf != '':
|
82 |
+
self.logger.log(self.log_level, self.linebuf.rstrip())
|
83 |
+
self.linebuf = ''
|