zackli4ai committed on
Commit c516834
1 Parent(s): a808e2a

upload registration code

Files changed (2)
  1. configuration_dolphin.py +218 -0
  2. modeling_dolphin.py +735 -0
configuration_dolphin.py ADDED
@@ -0,0 +1,218 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2 model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

# We could also pass this encoder config dict to Qwen2Config directly (see `encoder_config` below).
encoder_config_dict = {
    "_name_or_path": "alexchen4ai/Qwen2-0.5B",
    "add_cross_attention": False,
    "architectures": ["Qwen2ForCausalLM"],
    "attention_dropout": 0.0,
    "bad_words_ids": None,
    "begin_suppress_tokens": None,
    "bos_token_id": 151643,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": None,
    "decoder_start_token_id": None,
    "diversity_penalty": 0.0,
    "do_sample": False,
    "early_stopping": False,
    "encoder_config": None,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 151643,
    "exponential_decay_length_penalty": None,
    "finetuning_task": None,
    "forced_bos_token_id": None,
    "forced_eos_token_id": None,
    "hidden_act": "silu",
    "hidden_size": 896,
    "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
    "initializer_range": 0.02,
    "intermediate_size": 4864,
    "is_decoder": False,
    "is_encoder_decoder": False,
    "label2id": {"LABEL_0": 0, "LABEL_1": 1},
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 131072,
    "max_window_layers": 24,
    "min_length": 0,
    "model_type": "qwen2",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 14,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 24,
    "num_key_value_heads": 2,
    "num_return_sequences": 1,
    "output_attentions": False,
    "output_hidden_states": False,
    "output_scores": False,
    "pad_token_id": None,
    "prefix": None,
    "problem_type": None,
    "pruned_heads": {},
    "remove_invalid_values": False,
    "repetition_penalty": 1.0,
    "return_dict": True,
    "return_dict_in_generate": False,
    "rms_norm_eps": 1e-06,
    "rope_theta": 1000000.0,
    "sep_token_id": None,
    "sliding_window": 131072,
    "suppress_tokens": None,
    "task_specific_params": None,
    "temperature": 1.0,
    "tf_legacy_loss": False,
    "tie_encoder_decoder": False,
    "tie_word_embeddings": True,
    "tokenizer_class": None,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": "bfloat16",
    "torchscript": False,
    "typical_p": 1.0,
    "use_bfloat16": False,
    "use_cache": True,
    "use_sliding_window": False,
    "vocab_size": 151936,
    "attn_implementation": None,
}


class Qwen2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
    Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen2Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        encoder_config (`dict`, *optional*):
            Configuration dict of the small context encoder (see `encoder_config_dict` above).

    ```python
    >>> from transformers import Qwen2Model, Qwen2Config

    >>> # Initializing a Qwen2 style configuration
    >>> configuration = Qwen2Config()

    >>> # Initializing a model from the Qwen2-7B style configuration
    >>> model = Qwen2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        encoder_config=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.encoder_config = encoder_config

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
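The comment above `encoder_config_dict` and the `encoder_config` argument come together as below; a minimal sketch, assuming `configuration_dolphin.py` is importable from the working directory (the `__main__` block of `modeling_dolphin.py` further down uses the same pattern):

```python
# Attach the small-encoder config dict to the decoder config (hypothetical standalone usage).
from configuration_dolphin import Qwen2Config, encoder_config_dict

config = Qwen2Config(encoder_config=encoder_config_dict)
print(config.encoder_config["hidden_size"])  # 896: hidden size of the Qwen2-0.5B encoder
```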
modeling_dolphin.py ADDED
@@ -0,0 +1,735 @@
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, AutoConfig, logging
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.utils import ModelOutput
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter  # needed by _update_causal_mask
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2PreTrainedModel, Qwen2Model, Qwen2RMSNorm
)
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import warnings
from dataclasses import dataclass
from torch.nn import CrossEntropyLoss
from .configuration_dolphin import encoder_config_dict, Qwen2Config

CONTEXT_EMB = 896  # Qwen2-0.5B has a hidden dimension of 896
HIDDEN_EMB = 3584  # Qwen2-7B has a hidden dimension of 3584
warnings.filterwarnings("ignore")
MEM_SIZE = 32
logger = logging.get_logger(__name__)


@dataclass
class DolphinMemoryOutput(ModelOutput):
    memory_states: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


# Encoder wrapper: returns the hidden states of the last MEMORY_SIZE context positions as memory slots.
class Qwen2ForMemoryOutput(Qwen2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Qwen2Model(config)
        self.model.config.pad_token_id = self.model.config.eos_token_id

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (
                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1)
                )
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(hidden_states.device)
            else:
                sequence_lengths = -1

        # if sequence_lengths != -1:
        #     assert (sequence_lengths > MEMORY_SIZE).all(), "All sequences must be longer than MEMORY_SIZE"

        MEMORY_SIZE = 32
        batch_range = torch.arange(batch_size, device=hidden_states.device)
        start_indices = sequence_lengths - MEMORY_SIZE
        # print(sequence_lengths)
        # print(torch.arange(MEMORY_SIZE, device=hidden_states.device)[None, :] + start_indices[:, None])
        memory_states = hidden_states[
            batch_range[:, None],
            torch.arange(MEMORY_SIZE, device=hidden_states.device)[None, :]
            + start_indices[:, None],
        ]

        return DolphinMemoryOutput(
            memory_states=memory_states,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


# Projects encoder memory states from context_dim into the decoder hidden size.
class Projector(nn.Module):
    def __init__(self, context_dim: int, hidden_dim: int, projection_cls="linear"):
        super().__init__()
        self.projection_cls = projection_cls
        if projection_cls == "linear":
            self.context_projection = nn.Linear(context_dim, hidden_dim)
        elif projection_cls == "mlp":
            dim_projection = hidden_dim
            depth = 2
            layers = [
                nn.Linear(context_dim, dim_projection),
            ]
            for _ in range(1, depth):
                layers.extend(
                    [
                        nn.GELU(),
                        nn.Linear(dim_projection, dim_projection),
                    ]
                )
            self.context_projection = nn.Sequential(*layers)
        else:
            raise ValueError(f"Projection class {projection_cls} not supported")

    def forward(self, x):
        if self.projection_cls == "linear":
            return self.context_projection(x)

        for layer in self.context_projection:
            x = layer(x)
        return x


class ContextEmbd(nn.Module):
    def __init__(
        self, config, context_dim, hidden_dim, MEM_SIZE=32, torch_dtype=torch.bfloat16
    ):
        super().__init__()
        self.encoder = Qwen2ForMemoryOutput(config).to(torch_dtype)
        self.projector = Projector(context_dim, hidden_dim).to(torch_dtype)
        self.MEM_SIZE = MEM_SIZE

    def forward(self, context_input_ids, context_attention_mask=None):
        memory_slot = self.encoder(
            context_input_ids, context_attention_mask, output_hidden_states=True
        ).memory_states

        # now project the memory slot into token space
        return self.projector(memory_slot)


class DolphinModel(Qwen2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]

    Args:
        config: Qwen2Config
    """

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                Qwen2DecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        if not config.encoder_config:
            raise ValueError("Please provide the encoder config")
        self.encoder_config = Qwen2Config.from_dict(config.encoder_config)
        self.context_encoder = ContextEmbd(
            config=self.encoder_config, context_dim=CONTEXT_EMB, hidden_dim=HIDDEN_EMB
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # We assume there is only one context; this function can only support one context
    def get_token_embebddings_context(
        self,
        input_ids: torch.LongTensor,
        context_input_ids: torch.LongTensor,
        context_attention_mask: torch.LongTensor,
    ) -> torch.FloatTensor:
        # The size is batch_size x memory_size x hidden_dim
        context_emb = self.context_encoder(context_input_ids, context_attention_mask)

        # Create embeddings for regular tokens
        embed_input_ids = input_ids.clone()
        embed_input_ids[embed_input_ids < 0] = (
            0  # Replace negative values with 0 for embedding
        )
        hidden_states = self.embed_tokens(embed_input_ids)

        batch_size, seq_len, hidden_dim = hidden_states.shape
        _, memory_size, _ = context_emb.shape

        # Find the start positions of -1 sequences
        mask = input_ids == -1
        starts = torch.where(mask[:, :-1] < mask[:, 1:])[1]

        # Replace -1 spans with context embeddings
        for i in range(batch_size):
            for start in starts:
                if start + memory_size <= seq_len:
                    hidden_states[i, start : start + memory_size] = context_emb[i]

        return hidden_states

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        use_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            use_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
            )

        if inputs_embeds is None:
            if context_input_ids is not None:
                assert (
                    context_attention_mask is not None
                ), "You have to provide the context_attention_mask"
                inputs_embeds = self.get_token_embebddings_context(
                    input_ids, context_input_ids, context_attention_mask
                )
            else:
                inputs_embeds = self.embed_tokens(input_ids)

        # We need to update the attention mask if the attention mask is provided
        # if attention_mask is not None:
        #     MEMORY_SIZE = 32
        #     batch_size = inputs_embeds.shape[0]
        #     attention_mask = torch.cat(
        #         (torch.ones(batch_size, MEMORY_SIZE, device=inputs_embeds.device), attention_mask),
        #         dim=1,
        #     ).to(attention_mask.dtype).to(attention_mask.device)

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if use_legacy_cache
                else next_decoder_cache
            )

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not using_static_cache
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError(
                    "Custom 4D attention mask should be passed in inverted form with max==0`"
                )
            causal_mask = attention_mask
        else:
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(
                input_tensor.shape[0], 1, -1, -1
            )
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask


class DolphinForCausalLM(Qwen2PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DolphinModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        context_input_ids: Optional[torch.LongTensor] = None,
        context_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        past_length = 0
        # Omit tokens covered by past_key_values
        if past_key_values is not None:
            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            past_length = (
                cache_position[0]
                if cache_position is not None
                else past_key_values.get_seq_length()
            )
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                if past_key_values.get_max_length() is not None
                else None
            )
            cache_length = (
                past_length
                if max_cache_length is None
                else torch.min(max_cache_length, past_length)
            )

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        input_length = (
            position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        )
        if cache_position is None:
            cache_position = torch.arange(
                past_length, past_length + input_length, device=input_ids.device
            )
        elif use_cache:
            cache_position = cache_position[-input_length:]

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past


if __name__ == "__main__":
    config = Qwen2Config(encoder_config=encoder_config_dict)
    dolphin_model = DolphinModel(config)
    # AutoConfig.register("dolphin", Qwen2Config)
    AutoModelForCausalLM.register(Qwen2Config, DolphinForCausalLM)
    tokenizer = AutoTokenizer.from_pretrained('nexa-collaboration/dolphin_instruct_1M_0805', trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained('nexa-collaboration/dolphin_instruct_1M_0805', trust_remote_code=True)
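For context, a minimal usage sketch of the registered model, following the loading calls in the `__main__` block above. The prompt layout is an assumption based on `DolphinModel.get_token_embebddings_context`, which overwrites a span of `MEM_SIZE` (32) placeholder ids equal to -1 with the projected context memory; the surrounding prompt text and chat format are illustrative, not prescribed by this file.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MEM_SIZE = 32  # must match MEM_SIZE in modeling_dolphin.py

# Same loading calls as the __main__ block above; the repo must be reachable.
repo = "nexa-collaboration/dolphin_instruct_1M_0805"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True).eval()

# The long context is tokenized separately and fed through the small encoder.
context = tokenizer("Some long background document ...", return_tensors="pt")

# The decoder prompt reserves a MEM_SIZE-long span of -1 placeholders that the model
# replaces with the projected context memory; the text around it is illustrative only.
prefix = tokenizer("Context:", return_tensors="pt").input_ids
question = tokenizer(" Question: what does the document say?", return_tensors="pt").input_ids
placeholder = torch.full((1, MEM_SIZE), -1, dtype=torch.long)
input_ids = torch.cat([prefix, placeholder, question], dim=1)
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        context_input_ids=context.input_ids,
        context_attention_mask=context.attention_mask,
    )
print(out.logits.shape)  # (batch, sequence_length, vocab_size)
```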