jaymie23 committed on
Commit f2f2c7f
1 Parent(s): 0c31170

Upload custom_llama.py

Files changed (1)
  1. custom_llama.py +328 -0
custom_llama.py ADDED
@@ -0,0 +1,328 @@
+ from transformers.models.llama.modeling_llama import *  # brings in LlamaModel, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, LlamaForCausalLM, nn, torch, ...
+ from typing import List, Optional, Tuple, Union
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPast,
+     CausalLMOutputWithPast,
+     QuestionAnsweringModelOutput,
+     SequenceClassifierOutputWithPast,
+     TokenClassifierOutput,
+ )
+ from transformers.utils import (
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     is_flash_attn_greater_or_equal_2_10,
+     is_torchdynamo_compiling,
+     logging,
+     replace_return_docstrings,
+ )
+ from transformers.models.llama.configuration_llama import LlamaConfig
+
+ logger = logging.get_logger(__name__)
+
+ LLAMA_INPUTS_DOCSTRING = r"""
+     Args:
+         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+             it.
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+
+             [What are input IDs?](../glossary#input-ids)
+         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+
+             [What are attention masks?](../glossary#attention-mask)
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+
+             If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+             `past_key_values`).
+
+             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+             and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+             information on the default strategy.
+
+             - 1 indicates the head is **not masked**,
+             - 0 indicates the head is **masked**.
+         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+             config.n_positions - 1]`.
+
+             [What are position IDs?](../glossary#position-ids)
+         past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+             Two formats are allowed:
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
+             - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+               shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+               cache format.
+
+             The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+             legacy cache format will be returned.
+
+             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+             of shape `(batch_size, sequence_length)`.
+         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+             model's internal embedding lookup matrix.
+         use_cache (`bool`, *optional*):
+             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+             `past_key_values`).
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+             Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+             this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+             the complete sequence length.
+ """
+
+ class CustomLLamaModel(LlamaModel):
+     def __init__(self, config: LlamaConfig):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         self.layers = nn.ModuleList(
+             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.rotary_emb = LlamaRotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+
+         # Initialize weights and apply final processing
+         self.post_init()
+         self.num_head = 4  # not used in this module's forward pass
+         self.split_idx = config.split_idx  # index of the first decoder layer that runs in fp16
+         self.set_quant = True
+         self.quant = config.quant  # quantization mode, e.g. "fp16"
+         if self.quant == "fp16":
+             self.set_quant_16()
+
+     def set_quant_16(self):
+         # Cast the upper part of the stack (layers split_idx..end) and the final norm to fp16, once.
+         if self.set_quant:
+             for idx in range(self.split_idx, len(self.layers)):
+                 self.layers[idx] = self.layers[idx].half()
+             self.norm = self.norm.half()
+
+             self.set_quant = False
+
+
+     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Earlier hard-coded fp16 cast, now handled by `set_quant_16`:
+         # if self.set_fp16 == True:
+         #     for idx in range(16, 32):
+         #         self.layers[idx] = self.layers[idx].half()
+         #     self.norm = self.norm.half()
+         #     self.set_fp16 = False
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+             )
+
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
+             use_cache = False
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         # kept for BC (non `Cache` `past_key_values` inputs)
+         return_legacy_cache = False
+         if use_cache and not isinstance(past_key_values, Cache):
+             return_legacy_cache = True
+             if past_key_values is None:
+                 past_key_values = DynamicCache()
+             else:
+                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                 logger.warning_once(
+                     "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                     "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                     "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                 )
+
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(
+                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+             )
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+
+         causal_mask = self._update_causal_mask(
+             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+         )
+         hidden_states = inputs_embeds
+
+         # create position embeddings to be shared across the decoder layers
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+         # print(hidden_states.shape)
+         # print(attention_mask.shape)
+         # try:
+         #     print(output_attentions.shape)
+         # except Exception as e:
+         #     print(e)
+
+         # Lower part of the stack: layers [0, split_idx) run in the model's original dtype.
+         for decoder_layer in self.layers[0:self.split_idx]:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     causal_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                     cache_position,
+                     position_embeddings,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=causal_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     cache_position=cache_position,
+                     position_embeddings=position_embeddings,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         #################################################################
+         # Hand-off to the fp16 part of the stack: cast the activations and rotary
+         # embeddings so their dtypes match the layers converted by `set_quant_16`.
+         if self.quant == "fp16":
+             hidden_states = hidden_states.half()
+             position_embeddings = (position_embeddings[0].half(), position_embeddings[1].half())
+             # causal_mask, use_cache, cache_position and past_key_values are left as-is (not cast)
+         #################################################################
+         # Upper part of the stack: layers [split_idx, num_hidden_layers) run in fp16.
+         for decoder_layer in self.layers[self.split_idx:]:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     causal_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                     cache_position,
+                     position_embeddings,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=causal_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     cache_position=cache_position,
+                     position_embeddings=position_embeddings,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+         if return_legacy_cache:
+             next_cache = next_cache.to_legacy_cache()
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+ class CustomLlamaForCausalLM(LlamaForCausalLM):
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = CustomLLamaModel(config)
+         self.vocab_size = config.vocab_size
+         # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+         self.quant = config.quant
+         self.set_quant = True
+         if self.quant == "fp16":
+             self.set_quant_16()
+
+     def set_quant_16(self):
+         # Cast the LM head to fp16 so it matches the fp16 hidden states produced by the model.
+         if self.set_quant:
+             self.lm_head = self.lm_head.half()
+             self.set_quant = False
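A minimal usage sketch (not part of the uploaded file): it assumes the accompanying checkpoint's config.json carries the extra `split_idx` and `quant` fields that these classes read. The repository id and the fallback values (`16`, `"fp16"`) below are placeholders, not values taken from this commit.

import torch
from transformers import AutoTokenizer, LlamaConfig
from custom_llama import CustomLlamaForCausalLM

repo_id = "path/to/checkpoint"  # placeholder: the checkpoint this file is meant to accompany

config = LlamaConfig.from_pretrained(repo_id)
config.split_idx = getattr(config, "split_idx", 16)  # assumed default: fp16 starts at layer 16
config.quant = getattr(config, "quant", "fp16")      # assumed default: enable the fp16 upper half

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = CustomLlamaForCausalLM.from_pretrained(repo_id, config=config)
model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))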