Spaces: Runtime error
use mg-llava instead of llava in AutoConfig.register
Browse files
ml_mgie/mgie_llava.py  CHANGED  (+91 -47)
@@ -12,12 +12,12 @@ import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
+    LlamaConfig, LlamaModel, LlamaForCausalLM, \
+    CLIPVisionModel, CLIPImageProcessor

from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

+import os

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
@@ -26,7 +26,7 @@ DEFAULT_IM_END_TOKEN = "<im_end>"


class LlavaConfig(LlamaConfig):
-    model_type = "llava"
+    model_type = "mg-llava"


class LlavaLlamaModel(LlamaModel):
@@ -37,11 +37,13 @@ class LlavaLlamaModel(LlamaModel):

        if hasattr(config, "mm_vision_tower"):
            # HACK: for FSDP
+            self.vision_tower = [
+                CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
            # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)

        if hasattr(config, "use_mm_proj"):
+            self.mm_projector = nn.Linear(
+                config.mm_hidden_size, config.hidden_size)

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
@@ -67,18 +69,22 @@ class LlavaLlamaModel(LlamaModel):
        self.vision_tower = vision_tower

        vision_config = vision_tower.config
+        num_patches = (vision_config.image_size //
+                       vision_config.patch_size) ** 2

        self.config.use_mm_proj = True
        self.config.mm_hidden_size = vision_config.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer

        if not hasattr(self, 'mm_projector'):
+            self.mm_projector = nn.Linear(
+                vision_config.hidden_size, self.config.hidden_size)

        if pretrain_mm_mlp_adapter is not None:
+            mm_projector_weights = torch.load(
+                pretrain_mm_mlp_adapter, map_location='cpu')
+            self.mm_projector.load_state_dict(
+                {k.split('.')[-1]: v for k, v in mm_projector_weights.items()})

        return dict(
            image_processor=image_processor,
@@ -117,21 +123,28 @@ class LlavaLlamaModel(LlamaModel):
                    # variable length images
                    image_features = []
                    for image in images:
+                        image_forward_out = vision_tower(
+                            image.unsqueeze(0), output_hidden_states=True)
+                        select_hidden_state_layer = getattr(
+                            self.config, "mm_vision_select_layer", -1)
                        select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
                        image_feature = select_hidden_state[:, 1:]
                        image_features.append(image_feature)
                else:
+                    image_forward_outs = vision_tower(
+                        images.to(vision_tower.dtype), output_hidden_states=True)
+                    select_hidden_state_layer = getattr(
+                        self.config, "mm_vision_select_layer", -1)
                    select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
+                    image_features = select_hidden_state[:, 1:].to(
+                        images.dtype)
            if type(images) is list:
+                image_features = [self.mm_projector(
+                    image_feature)[0] for image_feature in image_features]
            else:
                image_features = self.mm_projector(image_features)
+            dummy_image_features = torch.zeros(
+                256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
            dummy_image_features = self.mm_projector(dummy_image_features)

            new_input_embeds = []
@@ -139,7 +152,8 @@ class LlavaLlamaModel(LlamaModel):
            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
                    # multimodal LLM, but the current sample is not multimodal
+                    cur_input_embeds = cur_input_embeds + \
+                        (0. * dummy_image_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    cur_image_idx += 1
                    continue
@@ -147,32 +161,43 @@ class LlavaLlamaModel(LlamaModel):
                    cur_image_features = image_features[cur_image_idx]
                    num_patches = cur_image_features.shape[0]
                    if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
+                        raise ValueError(
+                            "The number of image start tokens and image end tokens should be the same.")
+                    image_start_tokens = torch.where(
+                        cur_input_ids == vision_tower.config.im_start_token)[0]
                    for image_start_token_pos in image_start_tokens:
+                        cur_image_features = image_features[cur_image_idx].to(
+                            device=cur_input_embeds.device)
                        num_patches = cur_image_features.shape[0]
                        if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
+                            raise ValueError(
+                                "The image end token should follow the image start token.")
                        if orig_embeds_params is not None:
+                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
+                                                              cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
                        else:
+                            cur_new_input_embeds = torch.cat(
+                                (cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
                        cur_image_idx += 1
                    new_input_embeds.append(cur_new_input_embeds)
                else:
                    cur_image_features = image_features[cur_image_idx]
                    num_patches = cur_image_features.shape[0]
                    if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
+                        raise ValueError(
+                            "The number of image patch tokens should be the same as the number of image patches.")
+                    masked_indices = torch.where(
+                        cur_input_ids == vision_tower.config.im_patch_token)[0]
                    mask_index_start = masked_indices[0]
                    if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
+                        raise ValueError(
+                            "The image patch tokens should be consecutive.")
                    if orig_embeds_params is not None:
+                        cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(
+                        ), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
                    else:
+                        cur_new_input_embeds = torch.cat(
+                            (cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
                    new_input_embeds.append(cur_new_input_embeds)
                    cur_image_idx += 1
            inputs_embeds = torch.stack(new_input_embeds, dim=0)
@@ -184,6 +209,7 @@ class LlavaLlamaModel(LlamaModel):
            return_dict=return_dict
        )

+
class EditMapper(nn.Module):
    def __init__(self):
        super().__init__()
@@ -202,6 +228,7 @@ class EditMapper(nn.Module):

        return feat

+
class LlavaLlamaForCausalLM(LlamaForCausalLM):
    config_class = LlavaConfig

@@ -209,7 +236,8 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)

+        self.lm_head = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=False)

        self.edit_head = EditMapper()

@@ -292,12 +320,15 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
        if labels is not None:
            llm = []
            for i in range(labels.shape[0]):
+                try:
+                    p = labels[i].data.cpu().tolist().index(32003)-1
+                except:
+                    p = len(labels[i])-9
                p = min(len(hidden_states[i])-9, p)
                llm.append(hidden_states[i][p:p+8].unsqueeze(0))
            llm = torch.cat(llm, dim=0)
+            hid_edit = self.edit_head(
+                llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))

            B, DROP = labels.shape[0], 0.05

@@ -305,24 +336,30 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                                      self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))

            with torch.no_grad():
+                lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample(
+                )*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
                lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
                                    torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]

            noise = torch.randn_like(lat_ans)
+            ts = torch.randint(
+                0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long()
            lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)

            prob = torch.rand(B, device=lat_ans.device)
+            mask = (prob < (DROP*2)).reshape(B, 1, 1)
            hid_edit = torch.where(mask, hid_null, hid_edit)
+            mask = (1.0-((prob >= DROP).to(lat_inp.dtype) *
+                         (prob < (DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
            lat_inp *= mask

+            out = self.unet(
+                torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample

+            loss_ce, loss_edit = loss, nn.functional.mse_loss(
+                out, noise, reduction='mean')
+            if int(os.environ['LOCAL_RANK']) == 0:
+                print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
            loss = loss_ce+loss_edit*0.5

            if not return_dict:
@@ -367,9 +404,11 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
        self.resize_token_embeddings(len(tokenizer))

        if mm_use_im_start_end:
+            num_new_tokens = tokenizer.add_tokens(
+                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))
+            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
+                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
@@ -384,14 +423,16 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if tune_mm_mlp_adapter:
+                self.get_model().orig_embeds_params = [
+                    self.get_input_embeddings().weight.data.clone().to(device=device)]
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if pretrain_mm_mlp_adapter:
+                mm_projector_weights = torch.load(
+                    pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
@@ -399,9 +440,12 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
+                    raise ValueError(
+                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
+
+        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
+            [DEFAULT_IMAGE_PATCH_TOKEN])[0]


-AutoConfig.register("llava", LlavaConfig)
+AutoConfig.register("mg-llava", LlavaConfig)
|