zR committed
Commit ca14f13
1 Parent(s): d2415e6
Files changed (1)
  1. modeling_cogvlm.py +34 -37
modeling_cogvlm.py CHANGED
@@ -8,26 +8,17 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 from torchvision import transforms
 from einops import rearrange
-
-from decord import VideoReader, cpu
-import decord
-import io
-import numpy as np
-
 from transformers import PreTrainedModel, PreTrainedTokenizer
 from transformers.utils.logging import get_logger
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from torchvision.transforms.functional import InterpolationMode
 from torchvision.transforms import Lambda
-from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo, CenterCropVideo
-from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale
+from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo
+from pytorchvideo.transforms import ShortSideScale
 from .configuration_cogvlm import CogVLMConfig
 from .util import FastRotaryEmbedding
 from .visual import EVA2CLIPModel

-
-
 if TYPE_CHECKING:
     from transformers.utils import ModelOutput

@@ -101,7 +92,8 @@ class MLP(nn.Module):

 def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
     vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
-    vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
+    vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (
+            token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
     language_token_mask = ~vision_token_mask
     return vision_token_mask, language_token_mask

@@ -117,7 +109,7 @@ class VisionExpertMLP(nn.Module):
         # vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
         # output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
         # output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
-
+
         output = self.language_mlp(hidden_states)
         return output

@@ -177,7 +169,7 @@ class VisionExpertAttention(nn.Module):
     def _transpose_for_scores(self, tensor):
         """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
         new_tensor_shape = tensor.size()[:-1] + \
-                           (-1, # flexible for multi-query
+                           (-1, # flexible for multi-query
                             self.hidden_size_per_attention_head)
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
@@ -214,7 +206,8 @@ class VisionExpertAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]

-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
+        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids,
+                                                   max_seqlen=position_ids.max() + 1)

         if past_key_value is not None:
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
@@ -222,10 +215,13 @@ class VisionExpertAttention(nn.Module):

         past_key_value = (key_states, value_states) if use_cache else None

-        key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads, -1, -1).contiguous().view(
+        key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads, -1,
+                                                    -1).contiguous().view(
             bsz, self.num_attention_heads, *key_states.shape[2:])
-        value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads, -1,
-                                                        -1).contiguous().view(bsz, self.num_attention_heads, *value_states.shape[2:])
+        value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads,
+                                                        -1,
+                                                        -1).contiguous().view(bsz, self.num_attention_heads,
+                                                                              *value_states.shape[2:])

         context_layer = attention_fn(
             query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
@@ -240,7 +236,7 @@ class VisionExpertAttention(nn.Module):
         # attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
         # attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
         # attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
-
+
         attn_output = self.language_expert_dense(context_layer)

         if output_attentions:
@@ -329,7 +325,8 @@ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
     return True


-def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
+def build_position_ids(x: "torch.BoolTensor(B, L)",
+                       attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
     if attention_mask is not None:
         tmp = x.clone()
         tmp[~(attention_mask.bool())] = -1
@@ -344,7 +341,8 @@ def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["to
     tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
     # final position ids
     y = torch.zeros_like(x, dtype=torch.long)
-    y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
+    y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
+            (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
     y = y.cumsum(dim=-1)
     return y

@@ -407,7 +405,8 @@ class CogVLMVideoModel(CogVLMPreTrainedModel):
                 inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
             else: # single-modality
                 if token_type_ids is None:
-                    token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
+                    token_type_ids = torch.ones_like(input_ids, dtype=torch.long,
+                                                     device=input_ids.device) * LANGUAGE_TOKEN_TYPE
                 assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
                 inputs_embeds = self.embed_tokens(input_ids)

@@ -588,7 +587,7 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         self.model = CogVLMVideoModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.video_downsample = 1 # TODO: change this to config
+        self.video_downsample = 1 # TODO: change this to config

         # Initialize weights and apply final processing
         self.post_init()
@@ -685,7 +684,8 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore

     def prepare_inputs_for_generation(
-            self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+            self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None,
+            **kwargs
     ):
         # build position_ids if needed
         position_ids = kwargs.get("position_ids", None)
@@ -732,7 +732,8 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         # update token_type_ids with last value
         if "token_type_ids" in model_kwargs:
             token_type_ids = model_kwargs["token_type_ids"]
-            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
+            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
+                                            device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
             model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)

         if not is_encoder_decoder:
@@ -761,8 +762,6 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
             )
         return reordered_past

-
-
     def build_conversation_input_ids(
             self,
             tokenizer: "PreTrainedTokenizer",
@@ -780,7 +779,7 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         text = _history_to_prompt(template_version, history, query)
         input_ids = [tokenizer.bos_token_id]
         token_type_ids = [LANGUAGE_TOKEN_TYPE]
-        add_time_indices = False
+        add_time_indices = True if template_version == 'chat' else False
         if images is not None and len(images) == 1:
             # vision
             transform = transforms.Compose(
@@ -793,18 +792,19 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
                     # RandomHorizontalFlipVideo(p=0.5),
                 ]
             )
-            images = [transform(images[0]).transpose(0, 1)] # (T, C, H, W)
+            images = [transform(images[0]).transpose(0, 1)] # (T, C, H, W)
             num_eois = len(images[0])
             tokenizer.pad_token_id = 128002
-            vision_token_num = (64 + 2) * num_eois
             if not add_time_indices:
-                input_ids += [tokenizer.pad_token_id] * vision_token_num # add spetial token
+                vision_token_num = (64 + 2) * num_eois
+                input_ids += [tokenizer.pad_token_id] * vision_token_num # add spetial token
                 token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
             else:
                 video_ids, video_type_ids = [], []
+                sing_vision_token_num = (64 + 2)
                 for _time_idx in range(num_eois):
-                    video_ids += [tokenizer.pad_token_id] * vision_token_num
-                    video_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
+                    video_ids += [tokenizer.pad_token_id] * sing_vision_token_num
+                    video_type_ids += [VISION_TOKEN_TYPE] * sing_vision_token_num
                     # add time indices
                     time_indices = tokenizer.encode(str(_time_idx), add_special_tokens=False)
                     video_ids += time_indices
@@ -812,7 +812,7 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
                 # llama3 adapt for cogvlm
                 input_ids += video_ids
                 token_type_ids += video_type_ids
-
+
         text_ids = tokenizer.encode(text, add_special_tokens=False)

         if answer is not None:
@@ -820,7 +820,6 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
             answer_ids += [tokenizer.eos_token_id]
             text_ids += answer_ids

-
         input_ids += text_ids
         token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
         attention_mask = [1] * len(input_ids)
@@ -837,5 +836,3 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
             'images': images,
             'labels': labels,
         }
-
-
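
Note on the functional change in this commit: besides line re-wrapping and removal of unused imports, build_conversation_input_ids now enables add_time_indices whenever template_version == 'chat', computes vision_token_num only in the no-time-index branch, and uses a per-frame constant sing_vision_token_num = 64 + 2 when interleaving time indices between frame blocks. The sketch below is a minimal reconstruction of that layout from the hunks above; it uses a stand-in encode function instead of a real PreTrainedTokenizer, and it assumes (the line falls outside the hunk shown) that time-index tokens are typed as language tokens.

# Minimal sketch of the per-frame layout produced by the add_time_indices branch.
# Assumptions: pad_token_id = 128002 and 64 + 2 vision tokens per frame, as in the diff;
# typing time-index tokens as LANGUAGE_TOKEN_TYPE is an assumption (not shown in the hunk).
VISION_TOKEN_TYPE, LANGUAGE_TOKEN_TYPE = 0, 1
PAD_TOKEN_ID = 128002
SING_VISION_TOKEN_NUM = 64 + 2  # placeholder vision tokens inserted per video frame

def sketch_video_ids(num_frames, encode):
    video_ids, video_type_ids = [], []
    for t in range(num_frames):
        # one block of placeholder vision tokens per frame
        video_ids += [PAD_TOKEN_ID] * SING_VISION_TOKEN_NUM
        video_type_ids += [VISION_TOKEN_TYPE] * SING_VISION_TOKEN_NUM
        # followed by the frame's time index, tokenized as text
        time_indices = encode(str(t))
        video_ids += time_indices
        video_type_ids += [LANGUAGE_TOKEN_TYPE] * len(time_indices)
    return video_ids, video_type_ids

# Toy check with a fake one-token-per-number encoder:
ids, types = sketch_video_ids(3, encode=lambda s: [1000 + int(s)])
assert len(ids) == 3 * (SING_VISION_TOKEN_NUM + 1)

In the not-add_time_indices branch the layout stays a single run of (64 + 2) * num_eois consecutive vision tokens, matching the vision_token_num computation that the diff moves inside that branch.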