qingsonglv committed
Commit f1ed53e
1 Parent(s): 71b4373

remove triton dependency
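
The diff swaps the triton-backed `FastRotaryEmbedding` (imported from `.util`) for a plain-PyTorch rotary embedding, adding `RotaryEmbedding`, `rotate_half`, and `apply_rotary_pos_emb_index_bhs` to the model file. Below is a minimal standalone sketch of that rotation math with assumed toy shapes and helper names; it is not the full CogAgent module (the real class also caches the cos/sin tables per sequence length).

import torch
import torch.nn.functional as F

def rotate_half(x):
    # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin, position_ids):
    # cos/sin: [max_seq_len, head_dim]; look up the rows for each position id,
    # then broadcast over the head axis of q/k, which are [B, H, L, head_dim].
    cos = F.embedding(position_ids, cos).unsqueeze(1)
    sin = F.embedding(position_ids, sin).unsqueeze(1)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

# Toy usage with assumed sizes.
B, H, L, D = 1, 2, 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))
t = torch.arange(L, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, inv_freq)   # [L, D/2]
emb = torch.cat((freqs, freqs), dim=-1)        # [L, D]
q, k = torch.randn(B, H, L, D), torch.randn(B, H, L, D)
position_ids = torch.arange(L).unsqueeze(0)    # [B, L]
q_rot, k_rot = apply_rotary(q, k, emb.cos(), emb.sin(), position_ids)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 8, 16]) twice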

Files changed (1)
  1. modeling_cogagent.py +960 -910
modeling_cogagent.py CHANGED
@@ -1,910 +1,960 @@
1
- """largely copy from llama and adapt for CogAgent"""
2
- import warnings
3
- from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
-
5
- import math
6
- import torch
7
- from torch import nn
8
- from torch.nn import CrossEntropyLoss
9
- from torchvision import transforms
10
- from einops import rearrange
11
-
12
- from transformers import PreTrainedModel, PreTrainedTokenizer
13
- from transformers.utils.logging import get_logger
14
- from transformers.activations import ACT2FN
15
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
-
17
- from .configuration_cogagent import CogAgentConfig
18
- from .util import FastRotaryEmbedding
19
- from .visual import EVA2CLIPModel
20
- from .cross_visual import CrossVisionModel
21
-
22
- if TYPE_CHECKING:
23
- from transformers.utils import ModelOutput
24
-
25
- logger = get_logger(__name__)
26
-
27
- LANGUAGE_TOKEN_TYPE = 0
28
- VISION_TOKEN_TYPE = 1
29
-
30
-
31
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
32
- def _make_causal_mask(
33
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
34
- ):
35
- """
36
- Make causal mask used for bi-directional self-attention.
37
- """
38
- bsz, tgt_len = input_ids_shape
39
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
40
- mask_cond = torch.arange(mask.size(-1), device=device)
41
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
42
- mask = mask.to(dtype)
43
-
44
- if past_key_values_length > 0:
45
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
46
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
47
-
48
-
49
- # Copied from transformers.models.bart.modeling_bart._expand_mask
50
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
51
- """
52
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
53
- """
54
- bsz, src_len = mask.size()
55
- tgt_len = tgt_len if tgt_len is not None else src_len
56
-
57
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
58
-
59
- inverted_mask = 1.0 - expanded_mask
60
-
61
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
62
-
63
-
64
- class RMSNorm(nn.Module):
65
- def __init__(self, hidden_size, eps=1e-6):
66
- super().__init__()
67
- self.weight = nn.Parameter(torch.ones(hidden_size))
68
- self.variance_epsilon = eps
69
-
70
- def forward(self, hidden_states):
71
- input_dtype = hidden_states.dtype
72
- hidden_states = hidden_states.to(torch.float32)
73
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
74
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
75
- return (self.weight * hidden_states).to(input_dtype)
76
-
77
-
78
- class MLP(nn.Module):
79
- def __init__(self, config):
80
- super().__init__()
81
- self.hidden_size = config.hidden_size
82
- self.intermediate_size = config.intermediate_size
83
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
84
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
85
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
86
- self.act_fn = ACT2FN[config.hidden_act]
87
-
88
- def forward(self, x):
89
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
90
- return down_proj
91
-
92
-
93
- def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
94
- vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
95
- vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
96
- language_token_mask = ~vision_token_mask
97
- return vision_token_mask, language_token_mask
98
-
99
-
100
- class VisionExpertMLP(nn.Module):
101
- def __init__(self, config):
102
- super().__init__()
103
- self.language_mlp = MLP(config)
104
- self.vision_mlp = MLP(config)
105
-
106
- def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
107
- output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
108
- vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
109
- output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
110
- output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
111
- return output
112
-
113
-
114
- def attention_fn(
115
- query_layer: "torch.tensor(B, H, L, HD)",
116
- key_layer: "torch.tensor(B, H, L, HD)",
117
- value_layer: "torch.tensor(B, H, L, HD)",
118
- attention_mask: "torch.tensor(B, H, L, HD)",
119
- *,
120
- scaling_attention_score: bool = True,
121
- attention_dropout: nn.Module = None
122
- ):
123
- attention_mask_bool = (attention_mask == 0)
124
- is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
125
- is_full = (attention_mask_bool > 0).all()
126
- if not (int(torch.__version__.split('.')[0]) >= 2):
127
- warnings.warn("It's recommended to use torch2.0 or higher.")
128
- if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
129
- dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
130
- return torch.nn.functional.scaled_dot_product_attention(
131
- query_layer, key_layer, value_layer,
132
- attn_mask=None,
133
- dropout_p=dropout_p,
134
- is_causal=not is_full
135
- )
136
- else:
137
- if scaling_attention_score:
138
- query_layer = query_layer / math.sqrt(query_layer.shape[-1])
139
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
140
- attention_scores = attention_scores + attention_mask
141
- attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
142
- if attention_dropout is not None:
143
- attention_scores = attention_dropout(attention_scores)
144
- context_layer = torch.matmul(attention_scores, value_layer)
145
- return context_layer
146
-
147
-
148
- class VisionExpertAttention(nn.Module):
149
- def __init__(self, config):
150
- super().__init__()
151
- self.config = config
152
- self.hidden_size = config.hidden_size
153
- self.num_heads = config.num_attention_heads
154
- self.head_dim = self.hidden_size // self.num_heads
155
- self.max_position_embeddings = config.max_position_embeddings
156
-
157
- # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
158
- self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
159
- self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
160
- self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
161
- self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
162
- self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
163
-
164
- def _transpose_for_scores(self, tensor):
165
- """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
166
- new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.head_dim)
167
- tensor = tensor.view(*new_tensor_shape)
168
- return tensor.permute(0, 2, 1, 3)
169
-
170
- def forward(
171
- self,
172
- hidden_states: torch.Tensor,
173
- token_type_ids: torch.LongTensor,
174
- position_ids: torch.LongTensor,
175
- attention_mask: Optional[torch.Tensor] = None,
176
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
177
- output_attentions: bool = False,
178
- use_cache: bool = False,
179
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
180
- bsz, q_len, _ = hidden_states.size()
181
- vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
182
-
183
- shape = list(hidden_states.shape)
184
- shape[-1] = shape[-1] * 3
185
- mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
186
- mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
187
- mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
188
-
189
- query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
190
- query_states = self._transpose_for_scores(query_states) # B, H, L, HD
191
- key_states = self._transpose_for_scores(key_states) # B, H, L, HD
192
- value_states = self._transpose_for_scores(value_states) # B, H, L, HD
193
-
194
- kv_seq_len = key_states.shape[-2]
195
- if past_key_value is not None:
196
- kv_seq_len += past_key_value[0].shape[-2]
197
-
198
- query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
199
-
200
- if past_key_value is not None:
201
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
202
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
203
-
204
- past_key_value = (key_states, value_states) if use_cache else None
205
-
206
- context_layer = attention_fn(
207
- query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
208
- scaling_attention_score=True, attention_dropout=None)
209
- if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
210
- raise ValueError(
211
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
212
- f" {context_layer.size()}"
213
- )
214
- context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
215
-
216
- attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
217
- attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
218
- attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
219
-
220
- if output_attentions:
221
- warnings.warn("output_attentions is not implemented.")
222
-
223
- return attn_output, None, past_key_value
224
-
225
- class CrossAttention(nn.Module):
226
- def __init__(self, config):
227
- super().__init__()
228
- self.config = config
229
- self.hidden_size = config.hidden_size
230
- self.cross_hidden_size = config.cross_hidden_size
231
- self.cross_compute_hidden_size = config.cross_compute_hidden_size
232
- self.num_heads = config.num_attention_heads
233
- self.head_dim = self.hidden_size // self.num_heads
234
- self.cross_head_dim = self.cross_compute_hidden_size // self.num_heads
235
- self.max_position_embeddings = config.max_position_embeddings
236
-
237
- # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
238
- self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
239
- self.query = nn.Linear(self.hidden_size, self.cross_compute_hidden_size, bias=False)
240
- self.key_value = nn.Linear(self.cross_hidden_size, self.cross_compute_hidden_size * 2, bias=False)
241
- self.dense = nn.Linear(self.cross_compute_hidden_size, self.hidden_size, bias=False)
242
-
243
- def _transpose_for_scores(self, tensor):
244
- """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
245
- new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.cross_head_dim)
246
- tensor = tensor.view(*new_tensor_shape)
247
- return tensor.permute(0, 2, 1, 3)
248
-
249
- def forward(
250
- self,
251
- hidden_states: torch.Tensor,
252
- encoder_outputs: torch.LongTensor,
253
- attention_mask: Optional[torch.Tensor] = None,
254
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
255
- output_attentions: bool = False,
256
- use_cache: bool = False,
257
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
258
- bsz, q_len, _ = hidden_states.size()
259
-
260
- shape = list(hidden_states.shape)
261
- shape[-1] = shape[-1] * 3
262
-
263
- mixed_query_layer = self.query(hidden_states)
264
- if past_key_value is None:
265
- mixed_x_layer = self.key_value(encoder_outputs)
266
- mixed_key_layer, mixed_value_layer = torch.split(mixed_x_layer, self.cross_compute_hidden_size, dim=-1)
267
- key_states = self._transpose_for_scores(mixed_key_layer) # B, H, L, HD
268
- value_states = self._transpose_for_scores(mixed_value_layer) # B, H, L, HD
269
- else:
270
- key_states, value_states = past_key_value
271
-
272
- query_states = self._transpose_for_scores(mixed_query_layer) # B, H, L, HD
273
-
274
- past_key_value = (key_states, value_states) if use_cache else None
275
-
276
- context_layer = attention_fn(
277
- query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
278
- scaling_attention_score=True, attention_dropout=None)
279
- if context_layer.size() != (bsz, self.num_heads, q_len, self.cross_head_dim):
280
- raise ValueError(
281
- f"`cross_attn_output` should be of size {(bsz, self.num_heads, q_len, self.cross_head_dim)}, but is"
282
- f" {context_layer.size()}"
283
- )
284
- context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.cross_hidden_size)
285
-
286
- attn_output = self.dense(context_layer)
287
-
288
- if output_attentions:
289
- warnings.warn("output_attentions is not implemented.")
290
-
291
- return attn_output, None, past_key_value
292
-
293
- class CogAgentDecoderLayer(nn.Module):
294
- def __init__(self, config):
295
- super().__init__()
296
- self.hidden_size = config.hidden_size
297
- self.self_attn = VisionExpertAttention(config=config)
298
- self.cross_attn = CrossAttention(config=config)
299
- self.mlp = VisionExpertMLP(config)
300
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
301
- self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
302
- self.post_cross_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
303
-
304
- def forward(
305
- self,
306
- hidden_states: torch.Tensor,
307
- encoder_outputs: torch.Tensor,
308
- token_type_ids: torch.LongTensor,
309
- position_ids: torch.LongTensor,
310
- attention_mask: Optional[torch.Tensor] = None,
311
- cross_attention_mask: Optional[torch.Tensor] = None,
312
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
313
- output_attentions: Optional[bool] = False,
314
- use_cache: Optional[bool] = False,
315
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
316
- residual = hidden_states
317
-
318
- hidden_states = self.input_layernorm(hidden_states)
319
-
320
- # Self Attention
321
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
322
- hidden_states=hidden_states,
323
- token_type_ids=token_type_ids,
324
- position_ids=position_ids,
325
- attention_mask=attention_mask,
326
- past_key_value=past_key_value[:2] if past_key_value is not None else None,
327
- output_attentions=output_attentions,
328
- use_cache=use_cache,
329
- )
330
- hidden_states = residual + hidden_states
331
-
332
- cross_input = self.post_cross_attention_layernorm(hidden_states)
333
- # Fully Connected
334
- attention_output, self_cross_attn_weights, present_cross_key_value = self.cross_attn(
335
- hidden_states=cross_input,
336
- encoder_outputs=encoder_outputs,
337
- attention_mask=cross_attention_mask,
338
- past_key_value=past_key_value[-2:] if past_key_value is not None else None,
339
- output_attentions=output_attentions,
340
- use_cache=use_cache,
341
- )
342
- hidden_states = hidden_states + attention_output
343
- mlp_input = self.post_attention_layernorm(hidden_states)
344
- mlp_output = self.mlp(mlp_input, token_type_ids=token_type_ids)
345
- hidden_states = mlp_output + hidden_states
346
-
347
- outputs = (hidden_states,)
348
-
349
- if output_attentions:
350
- outputs += (self_attn_weights,)
351
-
352
- if use_cache:
353
- outputs += (present_key_value+present_cross_key_value,)
354
-
355
- return outputs # type: ignore
356
-
357
-
358
- class CogAgentPreTrainedModel(PreTrainedModel):
359
- config_class = CogAgentConfig
360
- base_model_prefix = "model"
361
- supports_gradient_checkpointing = False
362
- _no_split_modules = ["CogAgentDecoderLayer"]
363
- _skip_keys_device_placement = "past_key_values"
364
-
365
- def _init_weights(self, module):
366
- std = self.config.initializer_range
367
- if isinstance(module, nn.Linear):
368
- module.weight.data.normal_(mean=0.0, std=std)
369
- if module.bias is not None:
370
- module.bias.data.zero_()
371
- elif isinstance(module, nn.Embedding):
372
- module.weight.data.normal_(mean=0.0, std=std)
373
- if module.padding_idx is not None:
374
- module.weight.data[module.padding_idx].zero_()
375
-
376
-
377
- def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
378
- if images_list is None or len(images_list) == 0:
379
- return True
380
- for image_list in images_list:
381
- if len(image_list):
382
- return False
383
- return True
384
-
385
-
386
- def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
387
- if attention_mask is not None:
388
- tmp = x.clone()
389
- tmp[~(attention_mask.bool())] = -1
390
- else:
391
- tmp = x.clone()
392
- # image boi eoi token as LANGUAGE_TOKEN_TYPE
393
- is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
394
- is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
395
- is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
396
- is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
397
- is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
398
- tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
399
- # final position ids
400
- y = torch.zeros_like(x, dtype=torch.long)
401
- y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
402
- y = y.cumsum(dim=-1)
403
- return y
404
-
405
-
406
- class CogAgentModel(CogAgentPreTrainedModel):
407
- def __init__(self, config):
408
- super().__init__(config)
409
- self.padding_idx = config.pad_token_id
410
- self.vocab_size = config.vocab_size
411
-
412
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
413
- self.layers = nn.ModuleList([CogAgentDecoderLayer(config) for _ in range(config.num_hidden_layers)])
414
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
415
-
416
- self.vision = EVA2CLIPModel(config)
417
- self.cross_vision = CrossVisionModel(config)
418
-
419
- self.gradient_checkpointing = False
420
- # Initialize weights and apply final processing
421
- self.post_init()
422
-
423
- def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
424
- images_list, images = images, []
425
-
426
- images = []
427
- for image_list in images_list:
428
- for image in image_list:
429
- images.append(image)
430
-
431
- images = torch.stack(images)
432
- images_features = self.vision(images)
433
- return images_features
434
-
435
- def encode_cross_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
436
- images_list, images = images, []
437
-
438
- images = []
439
- for image_list in images_list:
440
- for image in image_list:
441
- images.append(image)
442
-
443
- images = torch.stack(images)
444
- encoder_outputs = self.cross_vision(images)
445
- return encoder_outputs
446
-
447
- def forward(
448
- self,
449
- input_ids: torch.LongTensor = None,
450
- images: List[List[torch.Tensor]] = None,
451
- cross_images: List[List[torch.Tensor]] = None,
452
- token_type_ids: Optional[torch.LongTensor] = None,
453
- attention_mask: Optional[torch.Tensor] = None,
454
- cross_attention_mask: Optional[torch.Tensor] = None,
455
- position_ids: Optional[torch.LongTensor] = None,
456
- past_key_values: Optional[List[torch.FloatTensor]] = None,
457
- inputs_embeds: Optional[torch.FloatTensor] = None,
458
- use_cache: Optional[bool] = None,
459
- output_attentions: Optional[bool] = None,
460
- output_hidden_states: Optional[bool] = None,
461
- return_dict: Optional[bool] = None,
462
- ) -> Union[Tuple, BaseModelOutputWithPast]:
463
- """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
464
-
465
- if past_key_values is not None:
466
- encoder_outputs = None
467
- # generate mode with past_key_values. the image features are already mapped
468
- else:
469
- # not allow for inputs_embeds, because we want to process image feature
470
- assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
471
- if not is_empty(images): # multi-modality
472
- assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
473
- assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
474
- inputs_embeds = self.embed_tokens(input_ids)
475
- images_features = self.encode_images(images)
476
- encoder_outputs = self.encode_cross_images(cross_images)
477
- images_features = rearrange(images_features, 'b n d -> (b n) d')
478
- images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
479
- inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
480
- else: # single-modality
481
- if token_type_ids is None:
482
- token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
483
- assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
484
- inputs_embeds = self.embed_tokens(input_ids)
485
- encoder_outputs = None
486
-
487
- if position_ids is None:
488
- position_ids = build_position_ids(token_type_ids, attention_mask)
489
- input_ids = None
490
-
491
- return self.llm_forward(
492
- input_ids=input_ids,
493
- encoder_outputs=encoder_outputs,
494
- token_type_ids=token_type_ids,
495
- attention_mask=attention_mask,
496
- cross_attention_mask=cross_attention_mask,
497
- position_ids=position_ids,
498
- past_key_values=past_key_values,
499
- inputs_embeds=inputs_embeds,
500
- use_cache=use_cache,
501
- output_attentions=output_attentions,
502
- output_hidden_states=output_hidden_states,
503
- return_dict=return_dict,
504
- )
505
-
506
- def llm_forward(
507
- self,
508
- input_ids: torch.LongTensor = None,
509
- encoder_outputs: torch.LongTensor = None,
510
- token_type_ids: torch.LongTensor = None,
511
- attention_mask: Optional[torch.Tensor] = None,
512
- cross_attention_mask: Optional[torch.Tensor] = None,
513
- position_ids: Optional[torch.LongTensor] = None,
514
- past_key_values: Optional[List[torch.FloatTensor]] = None,
515
- inputs_embeds: Optional[torch.FloatTensor] = None,
516
- use_cache: Optional[bool] = None,
517
- output_attentions: Optional[bool] = None,
518
- output_hidden_states: Optional[bool] = None,
519
- return_dict: Optional[bool] = None,
520
- ) -> Union[Tuple, BaseModelOutputWithPast]:
521
- """largely copy from llama forward and adapt for CogAgent with `token_type_ids`"""
522
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
523
- output_hidden_states = (
524
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
525
- )
526
- use_cache = use_cache if use_cache is not None else self.config.use_cache
527
-
528
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
529
-
530
- # retrieve input_ids and inputs_embeds
531
- if input_ids is not None and inputs_embeds is not None:
532
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
533
- elif input_ids is not None:
534
- batch_size, seq_length = input_ids.shape
535
- elif inputs_embeds is not None:
536
- batch_size, seq_length, _ = inputs_embeds.shape
537
- else:
538
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
539
-
540
- seq_length_with_past = seq_length
541
- past_key_values_length = 0
542
-
543
- if past_key_values is not None:
544
- past_key_values_length = past_key_values[0][0].shape[2]
545
- seq_length_with_past = seq_length_with_past + past_key_values_length
546
-
547
- if position_ids is None:
548
- device = input_ids.device if input_ids is not None else inputs_embeds.device
549
- position_ids = torch.arange(
550
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
551
- )
552
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
553
- else:
554
- position_ids = position_ids.view(-1, seq_length).long()
555
-
556
- if inputs_embeds is None:
557
- inputs_embeds = self.embed_tokens(input_ids)
558
- # embed positions
559
- if attention_mask is None:
560
- attention_mask = torch.ones(
561
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
562
- )
563
- if cross_attention_mask is None:
564
- cross_attention_mask = torch.ones(
565
- (batch_size, 1), dtype=torch.bool, device=inputs_embeds.device
566
- )
567
- attention_mask = self._prepare_decoder_attention_mask(
568
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
569
- )
570
-
571
- hidden_states = inputs_embeds
572
-
573
- # decoder layers
574
- all_hidden_states = () if output_hidden_states else None
575
- all_self_attns = () if output_attentions else None
576
- next_decoder_cache = () if use_cache else None
577
-
578
- for idx, decoder_layer in enumerate(self.layers):
579
- if output_hidden_states:
580
- all_hidden_states += (hidden_states,)
581
-
582
- past_key_value = past_key_values[idx] if past_key_values is not None else None
583
- layer_outputs = decoder_layer(
584
- hidden_states,
585
- encoder_outputs=encoder_outputs,
586
- token_type_ids=token_type_ids,
587
- attention_mask=attention_mask,
588
- cross_attention_mask=cross_attention_mask,
589
- position_ids=position_ids,
590
- past_key_value=past_key_value,
591
- output_attentions=output_attentions,
592
- use_cache=use_cache,
593
- )
594
- hidden_states = layer_outputs[0]
595
-
596
- if use_cache:
597
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
598
-
599
- if output_attentions:
600
- all_self_attns += (layer_outputs[1],)
601
-
602
- hidden_states = self.norm(hidden_states)
603
-
604
- # add hidden states from the last decoder layer
605
- if output_hidden_states:
606
- all_hidden_states += (hidden_states,)
607
-
608
- next_cache = next_decoder_cache if use_cache else None
609
- if not return_dict:
610
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
611
- return BaseModelOutputWithPast(
612
- last_hidden_state=hidden_states,
613
- past_key_values=next_cache,
614
- hidden_states=all_hidden_states,
615
- attentions=all_self_attns,
616
- )
617
-
618
- def get_input_embeddings(self):
619
- return self.embed_tokens
620
-
621
- def set_input_embeddings(self, value):
622
- self.embed_tokens = value
623
-
624
- # noinspection PyMethodMayBeStatic
625
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
626
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
627
- # create causal mask
628
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
629
- combined_attention_mask = None
630
- if input_shape[-1] > 1:
631
- combined_attention_mask = _make_causal_mask(
632
- input_shape,
633
- inputs_embeds.dtype,
634
- device=inputs_embeds.device,
635
- past_key_values_length=past_key_values_length,
636
- )
637
-
638
- if attention_mask is not None:
639
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
640
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
641
- inputs_embeds.device
642
- )
643
- combined_attention_mask = (
644
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
645
- )
646
-
647
- return combined_attention_mask
648
-
649
-
650
- def chat_history_to_prompt(history, query):
651
- prompt = " [INST] "
652
- for i, (old_query, response) in enumerate(history):
653
- prompt += old_query + " [/INST] " + response + " [INST] "
654
- prompt += query + " [/INST] "
655
- return prompt
656
-
657
-
658
- def base_history_to_prompt(history, query):
659
- prompt = query
660
- return prompt
661
-
662
-
663
- _history_to_prompt = {
664
- "base": base_history_to_prompt,
665
- "chat": chat_history_to_prompt
666
- }
667
-
668
-
669
- class CogAgentForCausalLM(CogAgentPreTrainedModel):
670
- _auto_class = "AutoModelForCausalLM"
671
-
672
- def __init__(self, config):
673
- super().__init__(config)
674
- self.model = CogAgentModel(config)
675
- self.vocab_size = config.vocab_size
676
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
677
-
678
- # Initialize weights and apply final processing
679
- self.post_init()
680
-
681
- def get_input_embeddings(self):
682
- return self.model.embed_tokens
683
-
684
- def set_input_embeddings(self, value):
685
- self.model.embed_tokens = value
686
-
687
- def get_output_embeddings(self):
688
- return self.lm_head
689
-
690
- def set_output_embeddings(self, new_embeddings):
691
- self.lm_head = new_embeddings
692
-
693
- def set_decoder(self, decoder):
694
- self.model = decoder
695
-
696
- def get_decoder(self):
697
- return self.model
698
-
699
- def forward(
700
- self,
701
- input_ids: torch.LongTensor = None,
702
- images: List[List[torch.Tensor]] = None,
703
- cross_images: List[List[torch.Tensor]] = None,
704
- token_type_ids: Optional[torch.LongTensor] = None,
705
- attention_mask: Optional[torch.Tensor] = None,
706
- position_ids: Optional[torch.LongTensor] = None,
707
- past_key_values: Optional[List[torch.FloatTensor]] = None,
708
- inputs_embeds: Optional[torch.FloatTensor] = None,
709
- use_cache: Optional[bool] = None,
710
- output_attentions: Optional[bool] = None,
711
- output_hidden_states: Optional[bool] = None,
712
- return_dict: Optional[bool] = None,
713
- labels: Optional[torch.LongTensor] = None,
714
- ) -> Union[Tuple, CausalLMOutputWithPast]:
715
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
- output_hidden_states = (
717
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
718
- )
719
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
720
-
721
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
722
- outputs = self.model(
723
- input_ids=input_ids,
724
- images=images,
725
- cross_images=cross_images,
726
- token_type_ids=token_type_ids,
727
- attention_mask=attention_mask,
728
- position_ids=position_ids,
729
- past_key_values=past_key_values,
730
- inputs_embeds=inputs_embeds,
731
- use_cache=use_cache,
732
- output_attentions=output_attentions,
733
- output_hidden_states=output_hidden_states,
734
- return_dict=return_dict,
735
- )
736
-
737
- hidden_states = outputs[0]
738
- logits = self.lm_head(hidden_states)
739
- logits = logits.float()
740
-
741
- loss = None
742
- if labels is not None:
743
- # Shift so that tokens < n predict n
744
- shift_logits = logits[..., :-1, :].contiguous()
745
- shift_labels = labels[..., 1:].contiguous()
746
- # Flatten the tokens
747
- loss_fct = CrossEntropyLoss()
748
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
749
- shift_labels = shift_labels.view(-1)
750
- # Enable model parallelism
751
- shift_labels = shift_labels.to(shift_logits.device)
752
- loss = loss_fct(shift_logits, shift_labels)
753
-
754
- if not return_dict:
755
- output = (logits,) + outputs[1:]
756
- return (loss,) + output if loss is not None else output
757
-
758
- return CausalLMOutputWithPast(
759
- loss=loss,
760
- logits=logits,
761
- past_key_values=outputs.past_key_values,
762
- hidden_states=outputs.hidden_states,
763
- attentions=outputs.attentions,
764
- )
765
-
766
- def _prepare_attention_mask_for_generation(
767
- self,
768
- inputs: torch.Tensor,
769
- pad_token_id: Optional[int],
770
- eos_token_id: Optional[Union[int, List[int]]],
771
- ) -> torch.LongTensor:
772
- return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
773
-
774
- def prepare_inputs_for_generation(
775
- self, input_ids, token_type_ids, images=None, cross_images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
776
- ):
777
- # build position_ids if needed
778
- position_ids = kwargs.get("position_ids", None)
779
- if position_ids is None:
780
- position_ids = build_position_ids(token_type_ids, attention_mask)
781
-
782
- if past_key_values:
783
- input_ids = input_ids[:, -1:]
784
- token_type_ids = token_type_ids[:, -1:]
785
- position_ids = position_ids[:, -1:]
786
-
787
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
788
- if inputs_embeds is not None and past_key_values is None:
789
- model_inputs = {"inputs_embeds": inputs_embeds}
790
- else:
791
- model_inputs = {"input_ids": input_ids}
792
-
793
- model_inputs.update(
794
- {
795
- "token_type_ids": token_type_ids,
796
- "images": images,
797
- "cross_images": cross_images,
798
- "position_ids": position_ids,
799
- "past_key_values": past_key_values,
800
- "use_cache": kwargs.get("use_cache"),
801
- "attention_mask": attention_mask,
802
- }
803
- )
804
- return model_inputs
805
-
806
- def _update_model_kwargs_for_generation(
807
- self,
808
- outputs: "ModelOutput",
809
- model_kwargs: Dict[str, Any],
810
- is_encoder_decoder: bool = False,
811
- standardize_cache_format: bool = False,
812
- ) -> Dict[str, Any]:
813
- # update past_key_values
814
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
815
- outputs, standardize_cache_format=standardize_cache_format
816
- )
817
- if getattr(outputs, "state", None) is not None:
818
- model_kwargs["state"] = outputs.state
819
-
820
- # update token_type_ids with last value
821
- if "token_type_ids" in model_kwargs:
822
- token_type_ids = model_kwargs["token_type_ids"]
823
- new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
824
- model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
825
-
826
- if not is_encoder_decoder:
827
- # update attention mask
828
- if "attention_mask" in model_kwargs:
829
- attention_mask = model_kwargs["attention_mask"]
830
- model_kwargs["attention_mask"] = torch.cat(
831
- [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
832
- )
833
- else:
834
- # update decoder attention mask
835
- if "decoder_attention_mask" in model_kwargs:
836
- decoder_attention_mask = model_kwargs["decoder_attention_mask"]
837
- model_kwargs["decoder_attention_mask"] = torch.cat(
838
- [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
839
- dim=-1,
840
- )
841
-
842
- return model_kwargs
843
-
844
- def _reorder_cache(self, past_key_values, beam_idx):
845
- reordered_past = ()
846
- for layer_past in past_key_values:
847
- reordered_past += (
848
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
849
- )
850
- return reordered_past
851
-
852
- def build_conversation_input_ids(
853
- self,
854
- tokenizer: "PreTrainedTokenizer",
855
- *,
856
- query: str,
857
- history: Optional[List[Tuple[str, str]]] = None,
858
- images: Optional[List["PIL.Image"]] = None,
859
- template_version: Optional[Literal["base", "chat", "vqa"]] = None,
860
- ):
861
- image_size: int = self.config.vision_config['image_size']
862
- cross_image_size: int = self.config.cross_image_size
863
- patch_size: int = self.config.vision_config['patch_size']
864
- template_version = template_version or self.config.template_version
865
- assert images is None or len(images) <= 1, f"not support multi images by now."
866
- history = history or []
867
- text = _history_to_prompt[template_version](history, query)
868
-
869
- input_ids = [tokenizer.bos_token_id]
870
- token_type_ids = [LANGUAGE_TOKEN_TYPE]
871
- if images is not None and len(images) == 1:
872
- ori = images
873
- # vision
874
- transform = transforms.Compose(
875
- [
876
- transforms.Resize(
877
- (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
878
- ),
879
- transforms.ToTensor(),
880
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
881
- ]
882
- )
883
- images = [transform(ori[0])]
884
- cross_transform = transforms.Compose(
885
- [
886
- transforms.Resize(
887
- (cross_image_size, cross_image_size), interpolation=transforms.InterpolationMode.BICUBIC
888
- ),
889
- transforms.ToTensor(),
890
- transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
891
- ]
892
- )
893
- cross_images = [cross_transform(ori[0])]
894
- # language
895
- vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
896
- input_ids += [tokenizer.pad_token_id] * vision_token_num
897
- token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
898
- text_ids = tokenizer.encode(text, add_special_tokens=False)
899
-
900
- input_ids += text_ids
901
- token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
902
- attention_mask = [1] * len(input_ids)
903
-
904
- return {
905
- 'input_ids': torch.tensor(input_ids, dtype=torch.long),
906
- 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
907
- 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
908
- 'images': images,
909
- 'cross_images': cross_images
910
- }
1
+ """largely copy from llama and adapt for CogAgent"""
2
+ import warnings
3
+ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
+
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+
12
+ from transformers import PreTrainedModel, PreTrainedTokenizer
13
+ from transformers.utils.logging import get_logger
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+
17
+ from .configuration_cogagent import CogAgentConfig
18
+ # from .util import FastRotaryEmbedding
19
+ from torch.nn import functional as F
20
+ from .visual import EVA2CLIPModel
21
+ from .cross_visual import CrossVisionModel
22
+
23
+ if TYPE_CHECKING:
24
+ from transformers.utils import ModelOutput
25
+
26
+ logger = get_logger(__name__)
27
+
28
+ LANGUAGE_TOKEN_TYPE = 0
29
+ VISION_TOKEN_TYPE = 1
30
+
31
+
32
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
33
+ def _make_causal_mask(
34
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
35
+ ):
36
+ """
37
+ Make causal mask used for bi-directional self-attention.
38
+ """
39
+ bsz, tgt_len = input_ids_shape
40
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
41
+ mask_cond = torch.arange(mask.size(-1), device=device)
42
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
43
+ mask = mask.to(dtype)
44
+
45
+ if past_key_values_length > 0:
46
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
47
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
48
+
49
+
50
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
51
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
52
+ """
53
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
54
+ """
55
+ bsz, src_len = mask.size()
56
+ tgt_len = tgt_len if tgt_len is not None else src_len
57
+
58
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
59
+
60
+ inverted_mask = 1.0 - expanded_mask
61
+
62
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
63
+
64
+
65
+ class RMSNorm(nn.Module):
66
+ def __init__(self, hidden_size, eps=1e-6):
67
+ super().__init__()
68
+ self.weight = nn.Parameter(torch.ones(hidden_size))
69
+ self.variance_epsilon = eps
70
+
71
+ def forward(self, hidden_states):
72
+ input_dtype = hidden_states.dtype
73
+ hidden_states = hidden_states.to(torch.float32)
74
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
75
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
76
+ return (self.weight * hidden_states).to(input_dtype)
77
+
78
+
79
+ class MLP(nn.Module):
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.hidden_size = config.hidden_size
83
+ self.intermediate_size = config.intermediate_size
84
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
85
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
86
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
87
+ self.act_fn = ACT2FN[config.hidden_act]
88
+
89
+ def forward(self, x):
90
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
91
+ return down_proj
92
+
93
+
94
+ def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
95
+ vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
96
+ vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
97
+ language_token_mask = ~vision_token_mask
98
+ return vision_token_mask, language_token_mask
99
+
100
+
101
+ class VisionExpertMLP(nn.Module):
102
+ def __init__(self, config):
103
+ super().__init__()
104
+ self.language_mlp = MLP(config)
105
+ self.vision_mlp = MLP(config)
106
+
107
+ def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
108
+ output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
109
+ vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
110
+ output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
111
+ output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
112
+ return output
113
+
114
+
115
+ def attention_fn(
116
+ query_layer: "torch.tensor(B, H, L, HD)",
117
+ key_layer: "torch.tensor(B, H, L, HD)",
118
+ value_layer: "torch.tensor(B, H, L, HD)",
119
+ attention_mask: "torch.tensor(B, H, L, HD)",
120
+ *,
121
+ scaling_attention_score: bool = True,
122
+ attention_dropout: nn.Module = None
123
+ ):
124
+ attention_mask_bool = (attention_mask == 0)
125
+ is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
126
+ is_full = (attention_mask_bool > 0).all()
127
+ if not (int(torch.__version__.split('.')[0]) >= 2):
128
+ warnings.warn("It's recommended to use torch2.0 or higher.")
129
+ if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
130
+ dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
131
+ return torch.nn.functional.scaled_dot_product_attention(
132
+ query_layer, key_layer, value_layer,
133
+ attn_mask=None,
134
+ dropout_p=dropout_p,
135
+ is_causal=not is_full
136
+ )
137
+ else:
138
+ if scaling_attention_score:
139
+ query_layer = query_layer / math.sqrt(query_layer.shape[-1])
140
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
141
+ attention_scores = attention_scores + attention_mask
142
+ attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
143
+ if attention_dropout is not None:
144
+ attention_scores = attention_dropout(attention_scores)
145
+ context_layer = torch.matmul(attention_scores, value_layer)
146
+ return context_layer
147
+
148
+ class RotaryEmbedding(torch.nn.Module):
149
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
150
+ super().__init__()
151
+
152
+ self.dim = dim
153
+ self.max_position_embeddings = max_position_embeddings
154
+ self.base = base
155
+ inv_freq = self._compute_inv_freq(device)
156
+ self.register_buffer("inv_freq", inv_freq)
157
+ self.max_seq_len_cached = 0
158
+
159
+ def _compute_inv_freq(self, device=None):
160
+ return 1.0 / (
161
+ self.base
162
+ ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
163
+ )
164
+
165
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
166
+ self.max_seq_len_cached = seq_len
167
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
168
+
169
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
170
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
171
+ emb = torch.cat((freqs, freqs), dim=-1)
172
+ self.register_buffer("cos_cached", emb.cos()[:, None, :].to(dtype), persistent=False)
173
+ self.register_buffer("sin_cached", emb.sin()[:, None, :].to(dtype), persistent=False)
174
+
175
+ def forward(self, x, seq_len):
176
+ # x: [bs, num_attention_heads, seq_len, head_size]
177
+ if seq_len > self.max_seq_len_cached:
178
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
179
+
180
+ return (
181
+ self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
182
+ self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
183
+ )
184
+
185
+
186
+ def rotate_half(x):
187
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
188
+ return torch.cat((-x2, x1), dim=x1.ndim - 1)
189
+
190
+
191
+ def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id):
192
+ # batch_size, num_head, seq_len, hidden_size
193
+ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(1), \
194
+ F.embedding(position_id, sin.squeeze(1)).unsqueeze(1)
195
+ q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
196
+ return q, k
197
+
198
+ class VisionExpertAttention(nn.Module):
199
+ def __init__(self, config):
200
+ super().__init__()
201
+ self.config = config
202
+ self.hidden_size = config.hidden_size
203
+ self.num_heads = config.num_attention_heads
204
+ self.head_dim = self.hidden_size // self.num_heads
205
+ self.max_position_embeddings = config.max_position_embeddings
206
+
207
+ self.rotary_emb = RotaryEmbedding(self.head_dim)
208
+ self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
209
+ self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
210
+ self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
211
+ self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
212
+
213
+ def _transpose_for_scores(self, tensor):
214
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
215
+ new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.head_dim)
216
+ tensor = tensor.view(*new_tensor_shape)
217
+ return tensor.permute(0, 2, 1, 3)
218
+
219
+ def forward(
220
+ self,
221
+ hidden_states: torch.Tensor,
222
+ token_type_ids: torch.LongTensor,
223
+ position_ids: torch.LongTensor,
224
+ attention_mask: Optional[torch.Tensor] = None,
225
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
226
+ output_attentions: bool = False,
227
+ use_cache: bool = False,
228
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
229
+ bsz, q_len, _ = hidden_states.size()
230
+ vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
231
+
232
+ shape = list(hidden_states.shape)
233
+ shape[-1] = shape[-1] * 3
234
+ mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
235
+ mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
236
+ mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
237
+
238
+ query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
239
+ query_states = self._transpose_for_scores(query_states) # B, H, L, HD
240
+ key_states = self._transpose_for_scores(key_states) # B, H, L, HD
241
+ value_states = self._transpose_for_scores(value_states) # B, H, L, HD
242
+
243
+ kv_seq_len = key_states.shape[-2]
244
+ if past_key_value is not None:
245
+ kv_seq_len += past_key_value[0].shape[-2]
246
+
247
+ cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
248
+ query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids)
249
+
250
+ if past_key_value is not None:
251
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
252
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
253
+
254
+ past_key_value = (key_states, value_states) if use_cache else None
255
+
256
+ context_layer = attention_fn(
257
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
258
+ scaling_attention_score=True, attention_dropout=None)
259
+ if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
260
+ raise ValueError(
261
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
262
+ f" {context_layer.size()}"
263
+ )
264
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
265
+
266
+ attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
267
+ attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
268
+ attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
269
+
270
+ if output_attentions:
271
+ warnings.warn("output_attentions is not implemented.")
272
+
273
+ return attn_output, None, past_key_value
274
+
275
+ class CrossAttention(nn.Module):
276
+ def __init__(self, config):
277
+ super().__init__()
278
+ self.config = config
279
+ self.hidden_size = config.hidden_size
280
+ self.cross_hidden_size = config.cross_hidden_size
281
+ self.cross_compute_hidden_size = config.cross_compute_hidden_size
282
+ self.num_heads = config.num_attention_heads
283
+ self.head_dim = self.hidden_size // self.num_heads
284
+ self.cross_head_dim = self.cross_compute_hidden_size // self.num_heads
285
+ self.max_position_embeddings = config.max_position_embeddings
286
+
287
+ # self.rotary_emb = RotaryEmbedding(self.hidden_size // self.num_heads)
288
+ self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
289
+ self.query = nn.Linear(self.hidden_size, self.cross_compute_hidden_size, bias=False)
290
+ self.key_value = nn.Linear(self.cross_hidden_size, self.cross_compute_hidden_size * 2, bias=False)
291
+ self.dense = nn.Linear(self.cross_compute_hidden_size, self.hidden_size, bias=False)
292
+
293
+ def _transpose_for_scores(self, tensor):
294
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
295
+ new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.cross_head_dim)
296
+ tensor = tensor.view(*new_tensor_shape)
297
+ return tensor.permute(0, 2, 1, 3)
298
+
299
+ def forward(
300
+ self,
301
+ hidden_states: torch.Tensor,
302
+ encoder_outputs: torch.LongTensor,
303
+ attention_mask: Optional[torch.Tensor] = None,
304
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
305
+ output_attentions: bool = False,
306
+ use_cache: bool = False,
307
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
308
+ bsz, q_len, _ = hidden_states.size()
309
+
310
+ shape = list(hidden_states.shape)
311
+ shape[-1] = shape[-1] * 3
312
+
313
+ mixed_query_layer = self.query(hidden_states)
314
+ if past_key_value is None:
315
+ mixed_x_layer = self.key_value(encoder_outputs)
316
+ mixed_key_layer, mixed_value_layer = torch.split(mixed_x_layer, self.cross_compute_hidden_size, dim=-1)
317
+ key_states = self._transpose_for_scores(mixed_key_layer) # B, H, L, HD
318
+ value_states = self._transpose_for_scores(mixed_value_layer) # B, H, L, HD
319
+ else:
320
+ key_states, value_states = past_key_value
321
+
322
+ query_states = self._transpose_for_scores(mixed_query_layer) # B, H, L, HD
323
+
324
+ past_key_value = (key_states, value_states) if use_cache else None
325
+
326
+ context_layer = attention_fn(
327
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
328
+ scaling_attention_score=True, attention_dropout=None)
329
+ if context_layer.size() != (bsz, self.num_heads, q_len, self.cross_head_dim):
330
+ raise ValueError(
331
+ f"`cross_attn_output` should be of size {(bsz, self.num_heads, q_len, self.cross_head_dim)}, but is"
332
+ f" {context_layer.size()}"
333
+ )
334
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.cross_hidden_size)
335
+
336
+ attn_output = self.dense(context_layer)
337
+
338
+ if output_attentions:
339
+ warnings.warn("output_attentions is not implemented.")
340
+
341
+ return attn_output, None, past_key_value
342
+
343
+ class CogAgentDecoderLayer(nn.Module):
344
+ def __init__(self, config):
345
+ super().__init__()
346
+ self.hidden_size = config.hidden_size
347
+ self.self_attn = VisionExpertAttention(config=config)
348
+ self.cross_attn = CrossAttention(config=config)
349
+ self.mlp = VisionExpertMLP(config)
350
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
351
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
352
+ self.post_cross_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
353
+
354
+ def forward(
355
+ self,
356
+ hidden_states: torch.Tensor,
357
+ encoder_outputs: torch.Tensor,
358
+ token_type_ids: torch.LongTensor,
359
+ position_ids: torch.LongTensor,
360
+ attention_mask: Optional[torch.Tensor] = None,
361
+ cross_attention_mask: Optional[torch.Tensor] = None,
362
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
363
+ output_attentions: Optional[bool] = False,
364
+ use_cache: Optional[bool] = False,
365
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
366
+ residual = hidden_states
367
+
368
+ hidden_states = self.input_layernorm(hidden_states)
369
+
370
+ # Self Attention
371
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
372
+ hidden_states=hidden_states,
373
+ token_type_ids=token_type_ids,
374
+ position_ids=position_ids,
375
+ attention_mask=attention_mask,
376
+ past_key_value=past_key_value[:2] if past_key_value is not None else None,
377
+ output_attentions=output_attentions,
378
+ use_cache=use_cache,
379
+ )
380
+ hidden_states = residual + hidden_states
381
+
382
+ cross_input = self.post_cross_attention_layernorm(hidden_states)
383
+ # Fully Connected
384
+ attention_output, self_cross_attn_weights, present_cross_key_value = self.cross_attn(
385
+ hidden_states=cross_input,
386
+ encoder_outputs=encoder_outputs,
387
+ attention_mask=cross_attention_mask,
388
+ past_key_value=past_key_value[-2:] if past_key_value is not None else None,
389
+ output_attentions=output_attentions,
390
+ use_cache=use_cache,
391
+ )
392
+ hidden_states = hidden_states + attention_output
393
+ mlp_input = self.post_attention_layernorm(hidden_states)
394
+ mlp_output = self.mlp(mlp_input, token_type_ids=token_type_ids)
395
+ hidden_states = mlp_output + hidden_states
396
+
397
+ outputs = (hidden_states,)
398
+
399
+ if output_attentions:
400
+ outputs += (self_attn_weights,)
401
+
402
+ if use_cache:
403
+ outputs += (present_key_value+present_cross_key_value,)
404
+
405
+ return outputs # type: ignore
406
+
407
+
408
+class CogAgentPreTrainedModel(PreTrainedModel):
+    config_class = CogAgentConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = False
+    _no_split_modules = ["CogAgentDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
+    if images_list is None or len(images_list) == 0:
+        return True
+    for image_list in images_list:
+        if len(image_list):
+            return False
+    return True
+
+
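Quick examples of what `is_empty` treats as "no images", derived directly from the function above (the tensor shape is just illustrative):

import torch

is_empty(None)                          # True: no images argument at all
is_empty([[], []])                      # True: every sample in the batch has an empty image list
is_empty([[torch.zeros(3, 224, 224)]])  # False: at least one sample carries an image tensor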
+def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
+    if attention_mask is not None:
+        tmp = x.clone()
+        tmp[~(attention_mask.bool())] = -1
+    else:
+        tmp = x.clone()
+    # treat the image BOI/EOI tokens as LANGUAGE_TOKEN_TYPE
+    is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
+    is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
+    is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
+    is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
+    is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
+    tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
+    # final position ids: language tokens advance the position, vision patch tokens share one
+    y = torch.zeros_like(x, dtype=torch.long)
+    y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
+    y = y.cumsum(dim=-1)
+    return y
+
+
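A worked example for `build_position_ids` (with LANGUAGE_TOKEN_TYPE = 0 and VISION_TOKEN_TYPE = 1): the BOI/EOI vision tokens bordering the image span keep their own positions, while the vision patch tokens in between collapse onto a single rotary position.

import torch

token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 1, 0, 0]])  # text, BOI, 3 patches, EOI, text, text
build_position_ids(token_type_ids)
# -> tensor([[0, 1, 2, 2, 2, 3, 4, 5]])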
+class CogAgentModel(CogAgentPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([CogAgentDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.vision = EVA2CLIPModel(config)
+        self.cross_vision = CrossVisionModel(config)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
+        images_list, images = images, []
+
+        images = []
+        for image_list in images_list:
+            for image in image_list:
+                images.append(image)
+
+        images = torch.stack(images)
+        images_features = self.vision(images)
+        return images_features
+
+    def encode_cross_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
+        images_list, images = images, []
+
+        images = []
+        for image_list in images_list:
+            for image in image_list:
+                images.append(image)
+
+        images = torch.stack(images)
+        encoder_outputs = self.cross_vision(images)
+        return encoder_outputs
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        images: List[List[torch.Tensor]] = None,
+        cross_images: List[List[torch.Tensor]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """Take care of image encoding, `token_type_ids` and `position_ids` (attention_mask=None is fine)."""
+
+        if past_key_values is not None:
+            encoder_outputs = None
+            # generation mode with past_key_values: the image features have already been mapped into the sequence
+        else:
+            # inputs_embeds is not allowed here, because the image features must be processed from input_ids
+            assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
+            if not is_empty(images):  # multi-modality
+                assert token_type_ids is not None, "multi-modality requires `token_type_ids`!"
+                assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
+                inputs_embeds = self.embed_tokens(input_ids)
+                images_features = self.encode_images(images)
+                encoder_outputs = self.encode_cross_images(cross_images)
+                images_features = rearrange(images_features, 'b n d -> (b n) d')
+                images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
+                inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
+            else:  # single-modality
+                if token_type_ids is None:
+                    token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
+                assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
+                inputs_embeds = self.embed_tokens(input_ids)
+                encoder_outputs = None
+
+        if position_ids is None:
+            position_ids = build_position_ids(token_type_ids, attention_mask)
+        input_ids = None
+
+        return self.llm_forward(
+            input_ids=input_ids,
+            encoder_outputs=encoder_outputs,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask,
+            cross_attention_mask=cross_attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+    def llm_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        encoder_outputs: torch.LongTensor = None,
+        token_type_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """Largely copied from the LLaMA forward pass and adapted for CogAgent with `token_type_ids`."""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+        if cross_attention_mask is None:
+            cross_attention_mask = torch.ones(
+                (batch_size, 1), dtype=torch.bool, device=inputs_embeds.device
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            layer_outputs = decoder_layer(
+                hidden_states,
+                encoder_outputs=encoder_outputs,
+                token_type_ids=token_type_ids,
+                attention_mask=attention_mask,
+                cross_attention_mask=cross_attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    # noinspection PyMethodMayBeStatic
+    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
+            combined_attention_mask = (
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            )
+
+        return combined_attention_mask
+
+
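A self-contained sketch of the additive mask layout that `_prepare_decoder_attention_mask` returns; the sizes and values below are illustrative assumptions and the snippet does not call the helpers in this file.

import torch

dtype = torch.float32
bsz, seq_len = 1, 4
# causal part: strictly upper-triangular entries are blocked
causal = torch.full((seq_len, seq_len), torch.finfo(dtype).min).triu(1)
# padding part: pretend the last key position is padding
padding = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
expanded = (1.0 - padding)[:, None, None, :] * torch.finfo(dtype).min
combined = causal[None, None] + expanded  # (bsz, 1, seq_len, seq_len); 0 = visible, finfo.min = blocked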
+def chat_history_to_prompt(history, query):
+    prompt = " [INST] "
+    for i, (old_query, response) in enumerate(history):
+        prompt += old_query + " [/INST] " + response + " [INST] "
+    prompt += query + " [/INST] "
+    return prompt
+
+
+def base_history_to_prompt(history, query):
+    prompt = query
+    return prompt
+
+
+_history_to_prompt = {
+    "base": base_history_to_prompt,
+    "chat": chat_history_to_prompt
+}
+
+
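Example of the prompt string the "chat" template produces, traced from `chat_history_to_prompt` above (the query and history strings are made up):

history = [("Describe the screen.", "It shows a login form.")]
query = "Where should I click to sign in?"
_history_to_prompt["chat"](history, query)
# -> " [INST] Describe the screen. [/INST] It shows a login form. [INST] Where should I click to sign in? [/INST] "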
+class CogAgentForCausalLM(CogAgentPreTrainedModel):
+    _auto_class = "AutoModelForCausalLM"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = CogAgentModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        images: List[List[torch.Tensor]] = None,
+        cross_images: List[List[torch.Tensor]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            images=images,
+            cross_images=cross_images,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def _prepare_attention_mask_for_generation(
+        self,
+        inputs: torch.Tensor,
+        pad_token_id: Optional[int],
+        eos_token_id: Optional[Union[int, List[int]]],
+    ) -> torch.LongTensor:
+        return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)  # type: ignore
+
+    def prepare_inputs_for_generation(
+        self, input_ids, token_type_ids, images=None, cross_images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        # build position_ids if needed
+        position_ids = kwargs.get("position_ids", None)
+        if position_ids is None:
+            position_ids = build_position_ids(token_type_ids, attention_mask)
+
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+            token_type_ids = token_type_ids[:, -1:]
+            position_ids = position_ids[:, -1:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "token_type_ids": token_type_ids,
+                "images": images,
+                "cross_images": cross_images,
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: "ModelOutput",
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
+    ) -> Dict[str, Any]:
+        # update past_key_values
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
+        if getattr(outputs, "state", None) is not None:
+            model_kwargs["state"] = outputs.state
+
+        # update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
+            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
+
+        if not is_encoder_decoder:
+            # update attention mask
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                )
+        else:
+            # update decoder attention mask
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.cat(
+                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+                    dim=-1,
+                )
+
+        return model_kwargs
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
+
+    def build_conversation_input_ids(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        *,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        images: Optional[List["PIL.Image"]] = None,
+        template_version: Optional[Literal["base", "chat", "vqa"]] = None,
+    ):
+        image_size: int = self.config.vision_config['image_size']
+        cross_image_size: int = self.config.cross_image_size
+        patch_size: int = self.config.vision_config['patch_size']
+        template_version = template_version or self.config.template_version
+        assert images is None or len(images) <= 1, "multiple images are not supported for now."
+        history = history or []
+        text = _history_to_prompt[template_version](history, query)
+
+        input_ids = [tokenizer.bos_token_id]
+        token_type_ids = [LANGUAGE_TOKEN_TYPE]
+        if images is not None and len(images) == 1:
+            ori = images
+            # vision
+            transform = transforms.Compose(
+                [
+                    transforms.Resize(
+                        (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
+                    ),
+                    transforms.ToTensor(),
+                    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+                ]
+            )
+            images = [transform(ori[0])]
+            cross_transform = transforms.Compose(
+                [
+                    transforms.Resize(
+                        (cross_image_size, cross_image_size), interpolation=transforms.InterpolationMode.BICUBIC
+                    ),
+                    transforms.ToTensor(),
+                    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+                ]
+            )
+            cross_images = [cross_transform(ori[0])]
+            # language
+            vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
+            input_ids += [tokenizer.pad_token_id] * vision_token_num
+            token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
+        text_ids = tokenizer.encode(text, add_special_tokens=False)
+
+        input_ids += text_ids
+        token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
+        attention_mask = [1] * len(input_ids)
+
+        return {
+            'input_ids': torch.tensor(input_ids, dtype=torch.long),
+            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
+            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
+            'images': images,
+            'cross_images': cross_images
+        }
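A minimal end-to-end usage sketch built on `build_conversation_input_ids` and `generate`. The checkpoint name, tokenizer repo, device, dtype and image path below are assumptions for illustration, not part of this file.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

# assumed tokenizer and checkpoint names
tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogagent-chat-hf", torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda").eval()

image = Image.open("screenshot.png").convert("RGB")
inputs = model.build_conversation_input_ids(
    tokenizer, query="What does this screen show?", history=[], images=[image]
)
batch = {
    "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
    "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
    "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
    "images": [[inputs["images"][0].to("cuda", dtype=torch.bfloat16)]],
    "cross_images": [[inputs["cross_images"][0].to("cuda", dtype=torch.bfloat16)]],
}
with torch.no_grad():
    outputs = model.generate(**batch, max_new_tokens=64, do_sample=False)
    outputs = outputs[:, batch["input_ids"].shape[1]:]  # drop the prompt tokens
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))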