zR committed
Commit ca14f13
Parent(s): d2415e6

update

Files changed: modeling_cogvlm.py (+34 -37)

modeling_cogvlm.py CHANGED
@@ -8,26 +8,17 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 from torchvision import transforms
 from einops import rearrange
-
-from decord import VideoReader, cpu
-import decord
-import io
-import numpy as np
-
 from transformers import PreTrainedModel, PreTrainedTokenizer
 from transformers.utils.logging import get_logger
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from torchvision.transforms.functional import InterpolationMode
 from torchvision.transforms import Lambda
-from torchvision.transforms._transforms_video import NormalizeVideo,
-from pytorchvideo.transforms import
+from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo
+from pytorchvideo.transforms import ShortSideScale
 from .configuration_cogvlm import CogVLMConfig
 from .util import FastRotaryEmbedding
 from .visual import EVA2CLIPModel
 
-
-
 if TYPE_CHECKING:
     from transformers.utils import ModelOutput
 
@@ -101,7 +92,8 @@ class MLP(nn.Module):
 
 def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
     vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
-    vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
+    vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (
+            token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
     language_token_mask = ~vision_token_mask
     return vision_token_mask, language_token_mask
 
@@ -117,7 +109,7 @@ class VisionExpertMLP(nn.Module):
         # vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
         # output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
         # output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
-
+
         output = self.language_mlp(hidden_states)
         return output
 
@@ -177,7 +169,7 @@ class VisionExpertAttention(nn.Module):
     def _transpose_for_scores(self, tensor):
         """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
         new_tensor_shape = tensor.size()[:-1] + \
-                           (-1,
+                           (-1,  # flexible for multi-query
                             self.hidden_size_per_attention_head)
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
@@ -214,7 +206,8 @@ class VisionExpertAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
 
-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
+        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids,
+                                                   max_seqlen=position_ids.max() + 1)
 
         if past_key_value is not None:
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
@@ -222,10 +215,13 @@ class VisionExpertAttention(nn.Module):
|
|
222 |
|
223 |
past_key_value = (key_states, value_states) if use_cache else None
|
224 |
|
225 |
-
key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads, -1,
|
|
|
226 |
bsz, self.num_attention_heads, *key_states.shape[2:])
|
227 |
-
value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_attention_heads // self.num_multi_query_heads,
|
228 |
-
-1
|
|
|
|
|
229 |
|
230 |
context_layer = attention_fn(
|
231 |
query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
|
@@ -240,7 +236,7 @@ class VisionExpertAttention(nn.Module):
         # attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
         # attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
         # attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
-
+
         attn_output = self.language_expert_dense(context_layer)
 
         if output_attentions:
@@ -329,7 +325,8 @@ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
     return True
 
 
-def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
+def build_position_ids(x: "torch.BoolTensor(B, L)",
+                       attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
     if attention_mask is not None:
         tmp = x.clone()
         tmp[~(attention_mask.bool())] = -1
@@ -344,7 +341,8 @@ def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
     tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
     # final position ids
     y = torch.zeros_like(x, dtype=torch.long)
-    y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
+    y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
+            (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
     y = y.cumsum(dim=-1)
     return y
 
@@ -407,7 +405,8 @@ class CogVLMVideoModel(CogVLMPreTrainedModel):
             inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
         else:  # single-modality
             if token_type_ids is None:
-                token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
+                token_type_ids = torch.ones_like(input_ids, dtype=torch.long,
+                                                 device=input_ids.device) * LANGUAGE_TOKEN_TYPE
             assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -588,7 +587,7 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         self.model = CogVLMVideoModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.video_downsample = 1
+        self.video_downsample = 1  # TODO: change this to config
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -685,7 +684,8 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)  # type: ignore
 
     def prepare_inputs_for_generation(
-            self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+            self, input_ids, token_type_ids, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None,
+            **kwargs
     ):
         # build position_ids if needed
         position_ids = kwargs.get("position_ids", None)
@@ -732,7 +732,8 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         # update token_type_ids with last value
         if "token_type_ids" in model_kwargs:
             token_type_ids = model_kwargs["token_type_ids"]
-            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
+            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
+                                            device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
             model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
 
         if not is_encoder_decoder:
@@ -761,8 +762,6 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
             )
         return reordered_past
 
-
-
     def build_conversation_input_ids(
             self,
            tokenizer: "PreTrainedTokenizer",
@@ -780,7 +779,7 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
         text = _history_to_prompt(template_version, history, query)
         input_ids = [tokenizer.bos_token_id]
         token_type_ids = [LANGUAGE_TOKEN_TYPE]
-        add_time_indices = False
+        add_time_indices = True if template_version == 'chat' else False
         if images is not None and len(images) == 1:
             # vision
             transform = transforms.Compose(
@@ -793,18 +792,19 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
                     # RandomHorizontalFlipVideo(p=0.5),
                 ]
             )
-            images = [transform(images[0]).transpose(0, 1)]
+            images = [transform(images[0]).transpose(0, 1)]  # (T, C, H, W)
             num_eois = len(images[0])
             tokenizer.pad_token_id = 128002
-            vision_token_num = (64 + 2) * num_eois
             if not add_time_indices:
-                input_ids += [tokenizer.pad_token_id] * vision_token_num
+                vision_token_num = (64 + 2) * num_eois
+                input_ids += [tokenizer.pad_token_id] * vision_token_num # add spetial token
                 token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
             else:
                 video_ids, video_type_ids = [], []
+                sing_vision_token_num = (64 + 2)
                 for _time_idx in range(num_eois):
-                    video_ids += [tokenizer.pad_token_id] *
-                    video_type_ids += [VISION_TOKEN_TYPE] *
+                    video_ids += [tokenizer.pad_token_id] * sing_vision_token_num
+                    video_type_ids += [VISION_TOKEN_TYPE] * sing_vision_token_num
                     # add time indices
                     time_indices = tokenizer.encode(str(_time_idx), add_special_tokens=False)
                     video_ids += time_indices
@@ -812,7 +812,7 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
                 # llama3 adapt for cogvlm
                 input_ids += video_ids
                 token_type_ids += video_type_ids
-
+
         text_ids = tokenizer.encode(text, add_special_tokens=False)
 
         if answer is not None:
@@ -820,7 +820,6 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
             answer_ids += [tokenizer.eos_token_id]
             text_ids += answer_ids
 
-
         input_ids += text_ids
         token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
         attention_mask = [1] * len(input_ids)
@@ -837,5 +836,3 @@ class CogVLMVideoForCausalLM(CogVLMPreTrainedModel):
             'images': images,
             'labels': labels,
         }
-
-
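Note on the import changes: CenterCropVideo, NormalizeVideo, and ShortSideScale feed the transforms.Compose pipeline that build_conversation_input_ids applies to the input clip, but only the tail of that Compose list appears in the hunks above. The sketch below is a minimal, hypothetical composition of those transforms in isolation; the crop size, normalization statistics, and the dummy clip are illustrative assumptions, not values taken from this commit.

import torch
from torchvision import transforms
from torchvision.transforms import Lambda
from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo
from pytorchvideo.transforms import ShortSideScale

# All numeric values here are placeholders: the crop size, normalization
# statistics, and frame count are not visible in the hunks above.
transform = transforms.Compose(
    [
        Lambda(lambda x: x / 255.0),                   # scale uint8 frames into [0, 1]
        NormalizeVideo(mean=[0.4815, 0.4578, 0.4082],  # assumed per-channel statistics
                       std=[0.2686, 0.2613, 0.2758]),
        ShortSideScale(size=224),                      # resize so the short side is 224 px
        CenterCropVideo(224),                          # crop to 224x224
        # RandomHorizontalFlipVideo(p=0.5),            # left commented out, as in the diff
    ]
)

# A dummy clip laid out as (C, T, H, W): 3 channels, 8 frames, 360x640 pixels.
clip = torch.randint(0, 256, (3, 8, 360, 640)).float()
frames = transform(clip).transpose(0, 1)  # -> (T, C, H, W), as in build_conversation_input_ids
print(frames.shape)                       # torch.Size([8, 3, 224, 224])

The .transpose(0, 1) at the end mirrors the "(T, C, H, W)" comment added in this commit, so each frame can then be consumed as an ordinary image tensor by the vision tower.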