HachiML committed on
Commit 5436666
1 Parent(s): 2b5af3d

Delete modeling_mists.py

Files changed (1)
  1. modeling_mists.py +0 -405
modeling_mists.py DELETED
@@ -1,405 +0,0 @@
- from dataclasses import dataclass
- from typing import List, Optional, Tuple, Union
-
- import torch
- import torch.utils.checkpoint
- from torch import nn
-
- from transformers import PreTrainedModel
- from transformers.activations import ACT2FN
- from transformers import Cache
- from transformers.modeling_outputs import ModelOutput
- from transformers.utils import (
-     add_start_docstrings,
-     add_start_docstrings_to_model_forward,
-     logging,
-     replace_return_docstrings,
- )
- from transformers import AutoModel, AutoModelForCausalLM
-
- from .modeling_moment import MomentEmbeddingModel
- from .configuration_mists import MistsConfig
-
-
- @dataclass
- # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Mists
- class MistsCausalLMOutputWithPast(ModelOutput):
-     loss: Optional[torch.FloatTensor] = None
-     logits: torch.FloatTensor = None
-     past_key_values: Optional[List[torch.FloatTensor]] = None
-     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-     attentions: Optional[Tuple[torch.FloatTensor]] = None
-     time_series_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-
-
- class MistsMultiModalProjector(nn.Module):
-     def __init__(self, config: MistsConfig):
-         super().__init__()
-
-         # The output of the time series tower does not have a fixed shape. A learnable padding vector is applied according to input_mask to give the input coming from the time series tower a fixed shape.
-         self.mask_embedding = nn.Parameter(torch.randn(1, 1, config.time_series_hidden_size))
-
-         # mlp
-         self.linear_1 = nn.Linear(config.time_series_hidden_size, config.text_config.hidden_size, bias=True)
-         self.act = ACT2FN[config.projector_hidden_act]
-         self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
-
-     def forward(self, time_series_features, input_mask):
-         masked_features = time_series_features * input_mask.unsqueeze(-1) + self.mask_embedding * (1 - input_mask.unsqueeze(-1))
-         hidden_states = self.linear_1(masked_features)
-         hidden_states = self.act(hidden_states)
-         hidden_states = self.linear_2(hidden_states)
-         return hidden_states
-
-
- class MistsPreTrainedModel(PreTrainedModel):
-     config_class = MistsConfig
-     base_model_prefix = "model"
-     supports_gradient_checkpointing = True
-     _no_split_modules = ["T5Block"]
-     _skip_keys_device_placement = "past_key_values"
-     _supports_flash_attn_2 = True
-     _supports_sdpa = True
-     _supports_cache_class = True
-     _supports_static_cache = True
-
-     def _init_weights(self, module):
-         # important: the initialization code below is currently ported as-is from Mistral.
-         # refers: https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/mistral/modeling_mistral.py#L762
-         # Pre-training with this initialization as-is is not desirable; only fine-tuning and inference are possible.
-         std = self.config.text_config.initializer_range
-         if isinstance(module, nn.Linear):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.bias is not None:
-                 module.bias.data.zero_()
-         elif isinstance(module, nn.Embedding):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.padding_idx is not None:
-                 module.weight.data[module.padding_idx].zero_()
-
-
- class MistsForConditionalGeneration(MistsPreTrainedModel):
-     def __init__(self, config: MistsConfig):
-         super().__init__(config)
-
-         self.time_series_tower = MomentEmbeddingModel(config.time_series_config)
-         self.multi_modal_projector = MistsMultiModalProjector(config)
-         self.vocab_size = config.text_config.vocab_size
-         self.language_model = AutoModelForCausalLM.from_config(
-             config.text_config, attn_implementation=config._attn_implementation
-         )
-         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
-         self.post_init()
-
-     def get_time_series_tower(self):
-         time_series_tower = getattr(self, 'time_series_tower', None)
-         if type(time_series_tower) is list:
-             time_series_tower = time_series_tower[0]
-         return time_series_tower
-
-     def get_input_embeddings(self):
-         return self.language_model.get_input_embeddings()
-
-     def set_input_embeddings(self, value):
-         self.language_model.set_input_embeddings(value)
-
-     def get_output_embeddings(self):
-         return self.language_model.get_output_embeddings()
-
-     def set_output_embeddings(self, new_embeddings):
-         self.language_model.set_output_embeddings(new_embeddings)
-
-     def set_decoder(self, decoder):
-         self.language_model.set_decoder(decoder)
-
-     def get_decoder(self):
-         return self.language_model.get_decoder()
-
-     def tie_weights(self):
-         return self.language_model.tie_weights()
-
-     def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
-         model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
-         # update vocab size
-         self.config.text_config.vocab_size = model_embeds.num_embeddings
-         self.vocab_size = model_embeds.num_embeddings
-         return model_embeds
-
-     # copied _merge_input_ids_with_image_features from LlavaForConditionalGeneration
-     # refers: https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/llava/modeling_llava.py#L277C9-L277C45
-     def _merge_input_ids_with_time_series_features(self, time_series_features, inputs_embeds, input_ids, attention_mask, labels):
-         num_time_series, num_time_series_patches, embed_dim = time_series_features.shape  # num_time_series_patches = n_channels x n_patches
-         batch_size, sequence_length = input_ids.shape
-         left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
-         # 1. Create a mask to know where special time_series tokens are
-         special_time_series_token_mask = input_ids == self.config.time_series_token_index
-         num_special_time_series_tokens = torch.sum(special_time_series_token_mask, dim=-1)
-         # Compute the maximum embed dimension
-         max_embed_dim = (num_special_time_series_tokens.max() * (num_time_series_patches - 1)) + sequence_length
-         max_embed_dim = int(max_embed_dim.item())  # get a plain integer out of the tensor
-         if max_embed_dim is None:
-             print(f"num_special_time_series_tokens.max(): {num_special_time_series_tokens.max()}")
-             print(f"num_time_series_patches: {num_time_series_patches}")
-             print(f"sequence_length: {sequence_length}")
-         else:
-             print(f"max_embed_dim 0: {max_embed_dim}")
-         batch_indices, non_time_series_indices = torch.where(input_ids != self.config.time_series_token_index)
-
-         # 2. Compute the positions where text should be written
-         # Calculate new positions for text tokens in merged time_series-text sequence.
-         # `special_time_series_token_mask` identifies time_series tokens. Each time_series token will be replaced by `nb_text_tokens_per_time_series - 1` text tokens.
-         # `torch.cumsum` computes how each time_series token shifts subsequent text token positions.
-         # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
-         new_token_positions = torch.cumsum((special_time_series_token_mask * (num_time_series_patches - 1) + 1), -1) - 1
-         nb_time_series_pad = max_embed_dim - 1 - new_token_positions[:, -1]
-         if left_padding:
-             new_token_positions += nb_time_series_pad[:, None]  # offset for left padding
-         text_to_overwrite = new_token_positions[batch_indices, non_time_series_indices]
-
-         # 3. Create the full embedding, already padded to the maximum position
-         final_embedding = torch.zeros(
-             batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
-         )
-         final_attention_mask = torch.zeros(
-             batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
-         )
-         if labels is not None:
-             final_labels = torch.full(
-                 (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
-             )
-         # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
-         # set the corresponding tensors into their correct target device.
-         target_device = inputs_embeds.device
-         batch_indices, non_time_series_indices, text_to_overwrite = (
-             batch_indices.to(target_device),
-             non_time_series_indices.to(target_device),
-             text_to_overwrite.to(target_device),
-         )
-         attention_mask = attention_mask.to(target_device)
-
-         # 4. Fill the embeddings based on the mask. If we have ["hey" "<time_series>", "how", "are"]
-         # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the time_series features
-         final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_time_series_indices]
-         final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_time_series_indices]
-         print("max_embed_dim is None: ", (max_embed_dim is None))
-         print("max_embed_dim: ", max_embed_dim)
-         if labels is not None:
-             final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_time_series_indices]
-             print("max_embed_dim is None: ", (max_embed_dim is None))
-             print("max_embed_dim: ", max_embed_dim)
-
-         # 5. Fill the embeddings corresponding to the time_series. Anything that is not `text_positions` needs filling (#29835)
-         print("inputs_embeds.device: ", inputs_embeds.device)
-         print("max_embed_dim: ", max_embed_dim, " is None: ", (max_embed_dim is None))
-         time_series_to_overwrite = torch.full(
-             (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
-         )
-         time_series_to_overwrite[batch_indices, text_to_overwrite] = False
-         time_series_to_overwrite &= time_series_to_overwrite.cumsum(-1) - 1 >= nb_time_series_pad[:, None].to(target_device)
-
-         if time_series_to_overwrite.sum() != time_series_features.shape[:-1].numel():
-             raise ValueError(
-                 f"The inputs provided to the model are wrong. The number of time series tokens is {torch.sum(special_time_series_token_mask)} while"
-                 f" the number of time series given to the model is {num_time_series}. This prevents correct indexing and breaks batch generation."
-             )
-
-         final_embedding[time_series_to_overwrite] = time_series_features.contiguous().reshape(-1, embed_dim).to(target_device)
-         final_attention_mask |= time_series_to_overwrite
-         position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
-         # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
-         batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
-         indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
-         final_embedding[batch_indices, indices_to_mask] = 0
-
-         if labels is None:
-             final_labels = None
-
-         return final_embedding, final_attention_mask, final_labels, position_ids
-
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         time_series_values: torch.FloatTensor = None,
-         time_series_input_mask: torch.FloatTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         # time_series_feature_layer: Optional[int] = None,
-         # time_series_feature_select_strategy: Optional[str] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, MistsCausalLMOutputWithPast]:
-
-         # These depend on the arguments accepted by language_model.
-         # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         # output_hidden_states = (
-         #     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         # )
-         # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-         # vision_feature_layer = (
-         #     vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
-         # )
-         # vision_feature_select_strategy = (
-         #     vision_feature_select_strategy
-         #     if vision_feature_select_strategy is not None
-         #     else self.config.vision_feature_select_strategy
-         # )
-
-         if inputs_embeds is None:
-             # 1. Extract the input embeddings
-             inputs_embeds = self.get_input_embeddings()(input_ids)
-
-             # 2. Merge text and time_series
-             if time_series_values is not None and input_ids.shape[1] != 1:
-                 time_series_outputs = self.time_series_tower(time_series_values, time_series_input_mask)
-                 time_series_features = self.multi_modal_projector(
-                     time_series_features=time_series_outputs.hidden_states,  # [batch_size, n_patches, d_model]
-                     input_mask=time_series_outputs.input_mask_patch_view,  # [batch_size, n_patches]
-                 )
-
-                 inputs_embeds = inputs_embeds.to(time_series_features.dtype)
-                 inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_time_series_features(
-                     time_series_features, inputs_embeds, input_ids, attention_mask, labels
-                 )
-
-             # In case input_ids.shape[1] == 1 & time_series_values != None & past_key_values != None, we are in the case of
-             # generation with cache
-             elif past_key_values is not None and time_series_values is not None and input_ids.shape[1] == 1:
-                 # Retrieve the first layer to inspect the logits and mask out the hidden states
-                 # that are set to 0
-                 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                 # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                 batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                 # Get the target length
-                 target_length = input_ids.shape[1]
-                 past_length = first_layer_past_key_value.shape[-1]
-
-                 extended_attention_mask = torch.ones(
-                     (attention_mask.shape[0], past_length),
-                     dtype=attention_mask.dtype,
-                     device=attention_mask.device,
-                 )
-
-                 # Filter out only the tokens that can be un-attended, this can happen
-                 # if one uses Llava + Fused modules where the cache on the
-                 # first iteration is already big enough, or if one passes custom cache
-                 valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                 new_batch_index = batch_index[valid_indices]
-                 new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                 # Zero-out the places where we don't need to attend
-                 extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                 attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                 position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-
-         print("inputs_embeds: ", inputs_embeds.shape)
-
-         outputs = self.language_model(
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds.to(self.language_model.dtype),
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         logits = outputs[0]
-
-         loss = None
-         if labels is not None:
-             # Shift so that tokens < n predict n
-             if attention_mask is not None:
-                 shift_attention_mask = attention_mask[..., 1:]
-                 shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
-                 shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
-             else:
-                 shift_logits = logits[..., :-1, :].contiguous()
-                 shift_labels = labels[..., 1:].contiguous()
-             # Flatten the tokens
-             loss_fct = nn.CrossEntropyLoss()
-             loss = loss_fct(
-                 shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
-             )
-
-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return (loss,) + output if loss is not None else output
-
-         return MistsCausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
-     def prepare_inputs_for_generation(
-         self, input_ids, past_key_values=None, inputs_embeds=None, time_series_values=None, attention_mask=None, **kwargs
-     ):
-         if past_key_values is not None:
-             if isinstance(past_key_values, Cache):
-                 cache_length = past_key_values.get_seq_length()
-                 past_length = past_key_values.seen_tokens
-             else:
-                 cache_length = past_length = past_key_values[0][0].shape[2]
-
-             # Keep only the unprocessed tokens:
-             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-             # input)
-             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-             # input_ids based on the past_length.
-             elif past_length < input_ids.shape[1]:
-                 input_ids = input_ids[:, past_length:]
-             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-             elif self.config.time_series_token_index in input_ids:
-                 input_ids = input_ids[:, input_ids.shape[1] - 1 :]
-             # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-             # older attention values, as their corresponding values are not part of the input.
-             if cache_length < past_length and attention_mask is not None:
-                 attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
-
-         position_ids = kwargs.get("position_ids", None)
-         if attention_mask is not None and position_ids is None:
-             # create position_ids on the fly for batch generation
-             position_ids = attention_mask.long().cumsum(-1) - 1
-             position_ids.masked_fill_(attention_mask == 0, 1)
-             if past_key_values:
-                 position_ids = position_ids[:, -input_ids.shape[1] :]
-
-         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-         if inputs_embeds is not None and past_key_values is None:
-             model_inputs = {"inputs_embeds": inputs_embeds}
-         else:
-             model_inputs = {"input_ids": input_ids}
-
-         model_inputs.update(
-             {
-                 "position_ids": position_ids,
-                 "past_key_values": past_key_values,
-                 "use_cache": kwargs.get("use_cache"),
-                 "attention_mask": attention_mask,
-                 "time_series_values": time_series_values,
-             }
-         )
-         return model_inputs
-
-     def _reorder_cache(self, *args, **kwargs):
-         return self.language_model._reorder_cache(*args, **kwargs)
-
-
-
-
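
The central piece of the deleted file is MistsMultiModalProjector: the time-series tower yields a variable number of valid patches per series, so the projector swaps in a learnable mask_embedding wherever input_mask marks padding and then maps the result into the language model's hidden size through a two-layer MLP. The snippet below is a minimal, self-contained sketch of that mechanism, not the deleted code itself: the hidden sizes are illustrative, and GELU stands in for the activation that the real module looks up from config.projector_hidden_act via ACT2FN.

import torch
from torch import nn

class MaskedProjectorSketch(nn.Module):
    """Illustrative stand-in for MistsMultiModalProjector."""

    def __init__(self, ts_hidden_size: int, text_hidden_size: int):
        super().__init__()
        # Learnable vector substituted wherever the time-series tower produced padding.
        self.mask_embedding = nn.Parameter(torch.randn(1, 1, ts_hidden_size))
        self.linear_1 = nn.Linear(ts_hidden_size, text_hidden_size)
        self.act = nn.GELU()  # assumption: the real model takes this from config
        self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size)

    def forward(self, time_series_features: torch.Tensor, input_mask: torch.Tensor) -> torch.Tensor:
        # input_mask is 1 where a patch holds real data and 0 where it is padding.
        mask = input_mask.unsqueeze(-1)  # [batch, n_patches, 1]
        padded = time_series_features * mask + self.mask_embedding * (1 - mask)
        return self.linear_2(self.act(self.linear_1(padded)))

# Toy shapes (illustrative): 2 series, 64 patches, tower dim 1024, text dim 4096.
feats = torch.randn(2, 64, 1024)
mask = torch.ones(2, 64)
mask[:, 48:] = 0  # pretend the last 16 patches are padding
print(MaskedProjectorSketch(1024, 4096)(feats, mask).shape)  # torch.Size([2, 64, 4096])

The fixed-shape output of this projector is what _merge_input_ids_with_time_series_features then splices into the text embedding sequence in place of the special time-series token, mirroring how LlavaForConditionalGeneration merges image features.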