awni committed
Commit 07a848a
1 Parent(s): 423f068

Upload folder using huggingface_hub

config.json CHANGED
@@ -29,7 +29,7 @@
   "rms_norm_eps": 1e-05,
   "rope_scaling": null,
   "rope_theta": 10000.0,
-  "sliding_window": 2047,
+  "sliding_window": 2048,
   "tie_word_embeddings": false,
   "torch_dtype": "bfloat16",
   "transformers_version": "4.39.3",
configuration_phi3.py CHANGED
@@ -83,10 +83,12 @@ class Phi3Config(PretrainedConfig):
         rope_theta (`float`, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings.
         rope_scaling (`dict`, *optional*):
-            The scaling factor for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
-            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
+            The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
+            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
             the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
             divided by the number of attention heads divided by 2.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
         eos_token_id (`int`, *optional*, defaults to 32000):
             The id of the "end-of-sequence" token.
         pad_token_id (`int`, *optional*, defaults to 32000):
@@ -132,6 +134,7 @@ class Phi3Config(PretrainedConfig):
         tie_word_embeddings=False,
         rope_theta=10000.0,
         rope_scaling=None,
+        bos_token_id=1,
         eos_token_id=32000,
         pad_token_id=32000,
         sliding_window=None,
@@ -158,9 +161,11 @@ class Phi3Config(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
+        self._rope_scaling_validation()
         self.sliding_window = sliding_window
 
         super().__init__(
+            bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
             pad_token_id=pad_token_id,
             tie_word_embeddings=tie_word_embeddings,
@@ -168,33 +173,41 @@ class Phi3Config(PretrainedConfig):
         )
 
     def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+        """
         if self.rope_scaling is None:
             return
 
-        assert (
-            (isinstance(self.rope_scaling, dict))
-            and ("type" in self.rope_scaling)
-            and ("short_factor" in self.rope_scaling)
-            and ("long_factor" in self.rope_scaling)
-        ), (
-            "`rope_scaling` must be a dictionary with three keys: `type`, `short_factor` and `long_factor`, "
-            f"got {self.rope_scaling}."
-        )
-
-        assert self.rope_scaling["type"].lower() == "longrope", "RoPE scaling type must be `longrope`."
-
-        short_factor = self.rope_scaling["short_factor"]
-        assert isinstance(short_factor, list) and all(
-            isinstance(x, (int, float)) for x in short_factor
-        ), f"RoPE scaling factor must be a list of numbers, got {short_factor}."
-        assert (
-            len(short_factor) == self.hidden_size // self.num_attention_heads // 2
-        ), f"Length of RoPE scaling factor must be half of the attention head, got {short_factor}."
-
-        long_factor = self.rope_scaling["long_factor"]
-        assert isinstance(long_factor, list) and all(
-            isinstance(x, (int, float)) for x in long_factor
-        ), f"RoPE scaling factor must be a list of numbers, got {long_factor}."
-        assert (
-            len(long_factor) == self.hidden_size // self.num_attention_heads // 2
-        ), f"Length of RoPE scaling factor must be half of the attention head, got {long_factor}."
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
+            raise ValueError(
+                "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
+                f"got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
+        rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
+            raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
+        if not (
+            isinstance(rope_scaling_short_factor, list)
+            and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
+        ):
+            raise ValueError(
+                f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
+            )
+        if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
+            raise ValueError(
+                f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
+            )
+        if not (
+            isinstance(rope_scaling_long_factor, list)
+            and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
+        ):
+            raise ValueError(
+                f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
+            )
+        if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
+            raise ValueError(
+                f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
+            )
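For reference, a sketch of a `rope_scaling` value that the rewritten validation accepts, assuming the default Phi-3-mini geometry (hidden_size=3072, num_attention_heads=32, so each factor list needs 3072 // 32 // 2 = 48 entries). The all-ones factors and the local import are illustrative assumptions, not values taken from this repository:

# Hypothetical example of a rope_scaling dict that passes the new _rope_scaling_validation.
from configuration_phi3 import Phi3Config  # assumes this repo's file is importable locally

rope_scaling = {
    "type": "su",                 # "yarn" is the other accepted value
    "short_factor": [1.0] * 48,   # length must equal hidden_size // num_attention_heads // 2
    "long_factor": [1.0] * 48,
}

config = Phi3Config(rope_scaling=rope_scaling)  # __init__ now calls _rope_scaling_validation()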
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:608877a3271792ffd6986f15cfb08c86008bd26e2562e344c57fd574213888d5
+oid sha256:e908f149cf6056b2f8fee5e1443cdae521be06558907eb952fbd5f383ad533b8
 size 2291290600
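Only the LFS pointer's sha256 changes; the file size is identical. A minimal sketch (not part of the commit) of verifying a downloaded `model.safetensors` against the new pointer; the local path is assumed:

import hashlib

# Stream the ~2.3 GB file in chunks rather than reading it into memory at once.
h = hashlib.sha256()
with open("model.safetensors", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)

print(h.hexdigest() == "e908f149cf6056b2f8fee5e1443cdae521be06558907eb952fbd5f383ad533b8")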
modeling_phi3.py CHANGED
@@ -40,6 +40,7 @@ from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -54,26 +55,17 @@ logger = logging.get_logger(__name__)
 _flash_supports_window_size = False
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
-
-    if not _flash_supports_window_size:
-        raise ValueError("Please update flash-attention to support window size.")
-
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-    from flash_attn.ops.activations import swiglu
-    from flash_attn.ops.rms_norm import RMSNorm as Phi3FlashRMSNorm
-    # else:
 except ImportError as error:
     logger.warning(
-        f"Flash Attention or Flash Attention Submodules not found, consider installing for better performance: {error}."
+        f"`flash-attention` package not found, consider installing for better performance: {error}."
     )
 if not _flash_supports_window_size:
     logger.warning(
-        "This version of flash does not support window size. Please use `attn_implementation='eager'` or upgrade flash-attn library."
+        "Current `flash-attenton` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
     )
-swiglu = None
-Phi3FlashRMSNorm = None
 
 _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
 _CONFIG_FOR_DOC = "Phi3Config"
@@ -103,9 +95,6 @@ class Phi3RMSNorm(nn.Module):
         return self.weight * hidden_states.to(input_dtype)
 
 
-PHI3_NORM_CLASS = Phi3RMSNorm if Phi3FlashRMSNorm is None else Phi3FlashRMSNorm
-
-
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -119,7 +108,7 @@ def _get_unpad_data(attention_mask):
     )
 
 
-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi3
+# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
 class Phi3RotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -127,98 +116,109 @@ class Phi3RotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
-
-class Phi3LongScaledRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        dim,
-        short_factor,
-        long_factor,
-        max_position_embeddings=4096,
-        original_max_position_embeddings=4096,
-        base=10000,
-        magnitude_scaling_policy="su",
-    ):
-        super().__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.original_max_position_embeddings = original_max_position_embeddings
-        self.base = base
-
-        if magnitude_scaling_policy == "su":
-            self._calc_mscale = self._calc_mscale_su
-        elif magnitude_scaling_policy == "yarn":
-            self._calc_mscale = self._calc_mscale_yarn
-        else:
-            self._calc_mscale = lambda scale: float(scale)
-
-        self.short_factor = short_factor
-        self.long_factor = long_factor
-
-    def _calc_mscale_su(self, scale):
-        if scale <= 1.0:
-            return 1.0
-        return math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
-
-    def _calc_mscale_yarn(self, scale):
-        if scale <= 1.0:
-            return 1.0
-        return 0.1 * math.log(scale) + 1.0
-
-    @torch.no_grad()
-    def forward(self, x, seq_len=None):
-        if seq_len is None:
-            seq_len = x.shape[-2]
-        t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
-
-        if seq_len > self.original_max_position_embeddings:
-            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
-            rescale_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
-        else:
-            t = torch.arange(self.original_max_position_embeddings, device=x.device, dtype=torch.float32)
-            rescale_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
-        assert rescale_factors.shape == (
-            self.dim // 2,
-        ), f"misaligned shape for LongRoPE rescale factors: {rescale_factors.shape}"
-
-        inv_freq = 1.0 / (
-            rescale_factors * (self.base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
-        )
-
-        freqs = torch.outer(t, inv_freq)
-        mscale = self._calc_mscale(self.max_position_embeddings / self.original_max_position_embeddings)
-        emb = torch.cat((freqs, freqs), dim=-1)
-
-        return (emb.cos() * mscale).to(x.dtype), (emb.sin() * mscale).to(x.dtype)
+        self.register_buffer("inv_freq", None, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if self.inv_freq is None:
+            self.inv_freq = 1.0 / (
+                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+            )
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
+    def __init__(self, dim, config, device=None):
+        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
+
+        self.short_factor = config.rope_scaling["short_factor"]
+        self.long_factor = config.rope_scaling["long_factor"]
+        self.original_max_position_embeddings = config.original_max_position_embeddings
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.original_max_position_embeddings:
+            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+        else:
+            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+
+        scale = self.max_position_embeddings / self.original_max_position_embeddings
+        if scale <= 1.0:
+            scaling_factor = 1.0
+        else:
+            scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
+        cos = emb.cos() * scaling_factor
+        sin = emb.sin() * scaling_factor
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
+    def __init__(self, dim, config, device=None):
+        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
+
+        self.short_factor = config.rope_scaling["short_factor"]
+        self.long_factor = config.rope_scaling["long_factor"]
+        self.original_max_position_embeddings = config.original_max_position_embeddings
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.original_max_position_embeddings:
+            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+        else:
+            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+
+        scale = self.max_position_embeddings / self.original_max_position_embeddings
+        if scale <= 1.0:
+            scaling_factor = 1.0
+        else:
+            scaling_factor = 0.1 * math.log(scale) + 1.0
+
+        cos = emb.cos() * scaling_factor
+        sin = emb.sin() * scaling_factor
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -229,7 +229,8 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -237,9 +238,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -250,27 +250,14 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-    # Need fp32 here to match logits
-    q_embed = (q.to(dtype=torch.float32) * cos.to(dtype=torch.float32)) + (
-        rotate_half(q).to(dtype=torch.float32) * sin.to(dtype=torch.float32)
-    )
-    k_embed = (k.to(dtype=torch.float32) * cos.to(dtype=torch.float32)) + (
-        rotate_half(k).to(dtype=torch.float32) * sin.to(dtype=torch.float32)
-    )
-    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
 
 
 class Phi3MLP(nn.Module):
-    """Gated Linear Unit.
-
-    Reference:
-        Language Modeling with Gated Convolutional Networks.
-        https://arxiv.org/pdf/1612.08083v3.pdf.
-
-    """
-
     def __init__(self, config):
         super().__init__()
 
@@ -281,17 +268,12 @@ class Phi3MLP(nn.Module):
         self.activation_fn = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
-        y = self.gate_up_proj(hidden_states)
+        up_states = self.gate_up_proj(hidden_states)
 
-        # Special case for SwiGLU
-        if self.config.hidden_act == "silu" and swiglu is not None:
-            gate, y = y.chunk(2, dim=-1)
-            y = swiglu(gate, y)
-        else:
-            gate, y = y.chunk(2, dim=-1)
-            y = y * self.activation_fn(gate)
+        gate, up_states = up_states.chunk(2, dim=-1)
+        up_states = up_states * self.activation_fn(gate)
 
-        return self.down_proj(y)
+        return self.down_proj(up_states)
 
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
@@ -341,9 +323,10 @@ class Phi3Attention(nn.Module):
 
         op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-
         self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
+        self._init_rope()
 
+    def _init_rope(self):
         if self.rope_scaling is None:
             self.rotary_emb = Phi3RotaryEmbedding(
                 self.head_dim,
@@ -351,17 +334,13 @@ class Phi3Attention(nn.Module):
                 base=self.rope_theta,
             )
         else:
-            self.rotary_emb = Phi3LongScaledRotaryEmbedding(
-                self.head_dim,
-                self.config.rope_scaling["short_factor"],
-                self.config.rope_scaling["long_factor"],
-                max_position_embeddings=self.config.max_position_embeddings,
-                original_max_position_embeddings=self.config.original_max_position_embeddings,
-                base=self.config.rope_theta,
-            )
-
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+            scaling_type = self.config.rope_scaling["type"]
+            if scaling_type == "su":
+                self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
+            elif scaling_type == "yarn":
+                self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
     def forward(
         self,
@@ -395,7 +374,8 @@ class Phi3Attention(nn.Module):
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
@@ -515,7 +495,7 @@ class Phi3FlashAttention2(Phi3Attention):
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
-        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -802,7 +782,7 @@ class Phi3SdpaAttention(Phi3Attention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -859,11 +839,11 @@ class Phi3DecoderLayer(nn.Module):
         self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
 
         self.mlp = Phi3MLP(config)
-        self.input_layernorm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
         self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
-        self.post_attention_layernorm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -1066,9 +1046,8 @@ class Phi3Model(Phi3PreTrainedModel):
         self.layers = nn.ModuleList(
            [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
-        self.norm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
-
         self._attn_implementation = config._attn_implementation
+        self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
@@ -1255,6 +1234,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
     def get_decoder(self):
         return self.model
 
+    # Ignore copy
     @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1284,8 +1264,8 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
         ```python
         >>> from transformers import AutoTokenizer, Phi3ForCausalLM
 
-        >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3")
-        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3")
+        >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
 
         >>> prompt = "This is an example script ."
        >>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -1293,7 +1273,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
+        'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
        ```"""
 
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
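The two new rotary classes differ only in the magnitude correction they apply to `cos`/`sin`. A small sketch (not part of the diff) of those two formulas, using assumed context lengths of 4096 original and 131072 extended (a 32x extension, as in a long-context variant); this repository's own config.json ships `rope_scaling: null`, so neither class is exercised here:

import math

original_max_position_embeddings = 4096   # assumed
max_position_embeddings = 131072          # assumed (32x extension)
scale = max_position_embeddings / original_max_position_embeddings  # 32.0

# "su" policy (Phi3SuScaledRotaryEmbedding)
su = 1.0 if scale <= 1.0 else math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))

# "yarn" policy (Phi3YarnScaledRotaryEmbedding)
yarn = 1.0 if scale <= 1.0 else 0.1 * math.log(scale) + 1.0

print(round(su, 4), round(yarn, 4))  # roughly 1.1902 and 1.3466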
sample_finetune.py CHANGED
@@ -25,7 +25,6 @@ check accelerate config:
 args = {
     "bf16": True,
     "do_eval": False,
-    "eval_strategy": "no",
     "learning_rate": 5.0e-06,
     "log_level": "info",
     "logging_steps": 20,
tokenizer_config.json CHANGED
@@ -335,7 +335,7 @@
     "<|/inst|>"
   ],
   "bos_token": "<s>",
-  "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
+  "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
   "clean_up_tokenization_spaces": false,
   "eos_token": "<|endoftext|>",
   "legacy": false,