Text Generation
Transformers
Safetensors
imp
custom_code
Oyoy1235 committed on
Commit 470ad57
1 Parent(s): 5f3a7bd

update new phi2 structure

README.md CHANGED
@@ -22,7 +22,7 @@ We release our model weights and provide an example below to run our model . Det
 
  **Install dependencies**
  ```bash
- pip install transformers # latest version is ok, but we recommend v4.36.0
+ pip install transformers # latest version is ok, but we recommend v4.37.0
  pip install -q pillow accelerate einops
  ```
 
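For context, a minimal loading sketch follows (not part of this commit). It assumes the dependencies pinned above are installed and that the checkpoint is loaded through the `auto_map` entries declared in config.json (`configuration_imp.ImpConfig` / `modeling_imp.ImpForCausalLM`), which requires `trust_remote_code=True`; the repo id is a placeholder, not taken from this diff.

```python
# Minimal sketch: load the checkpoint via its custom auto_map classes.
# The repo id below is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<path-or-repo-id-of-this-model>"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,  # matches "torch_dtype": "float16" in config.json
    trust_remote_code=True,     # needed for configuration_imp.py / modeling_imp.py
)
```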
config.json CHANGED
@@ -4,23 +4,26 @@
  "architectures": [
    "ImpForCausalLM"
  ],
+ "attention_dropout": 0.0,
  "attn_pdrop": 0.0,
  "auto_map": {
    "AutoConfig": "configuration_imp.ImpConfig",
    "AutoModelForCausalLM": "modeling_imp.ImpForCausalLM"
  },
+ "bos_token_id": 1,
  "embd_pdrop": 0.0,
  "eos_token_id": 50295,
- "flash_attn": false,
- "flash_rotary": false,
  "freeze_mm_mlp_adapter": false,
- "fused_dense": false,
+ "hidden_act": "gelu_new",
+ "hidden_size": 2560,
  "image_aspect_ratio": "square",
  "image_token": "<image>",
  "image_token_index": 50296,
  "img_processor": null,
  "initializer_range": 0.02,
- "layer_norm_epsilon": 1e-05,
+ "intermediate_size": 10240,
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 3072,
  "mm_hidden_size": 1152,
  "mm_projector_lr": 2e-05,
  "mm_projector_type": "mlp2x_gelu",
@@ -30,24 +33,25 @@
  "mm_vision_select_layer": -2,
  "mm_vision_tower": "google/siglip-so400m-patch14-384",
  "model_type": "imp",
- "n_embd": 2560,
- "n_head": 32,
- "n_head_kv": null,
- "n_inner": null,
- "n_layer": 32,
- "n_positions": 3072,
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 32,
  "pad_token_id": 50256,
+ "partial_rotary_factor": 0.4,
+ "qk_layernorm": false,
  "resid_pdrop": 0.1,
- "rotary_dim": 32,
+ "rope_scaling": null,
+ "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "tokenizer_model_max_length": 3072,
  "tokenizer_padding_side": "right",
  "torch_dtype": "float16",
- "transformers_version": "4.31.0",
+ "transformers_version": "4.37.0",
  "use_cache": true,
  "use_mm_proj": true,
  "vision_tower_config": {
    "attention_dropout": 0.0,
+   "attn_implementation": null,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_size": 384,
configuration_imp.py CHANGED
@@ -56,59 +56,169 @@ logger = logging.get_logger(__name__)
56
 
57
 
58
  class PhiConfig(PretrainedConfig):
59
- """Phi configuration."""
60
 
61
- model_type = "phi-msft"
62
- attribute_map = {
63
- "max_position_embeddings": "n_positions",
64
- "hidden_size": "n_embd",
65
- "num_attention_heads": "n_head",
66
- "num_hidden_layers": "n_layer",
67
- }
68
 
69
  def __init__(
70
  self,
71
- vocab_size: int = 50304,
72
- n_positions: int = 2048,
73
- n_embd: int = 1024,
74
- n_layer: int = 20,
75
- n_inner: Optional[int] = None,
76
- n_head: int = 16,
77
- n_head_kv: Optional[int] = None,
78
- rotary_dim: Optional[int] = 32,
79
- activation_function: Optional[str] = "gelu_new",
80
- flash_attn: bool = False,
81
- flash_rotary: bool = False,
82
- fused_dense: bool = False,
83
- attn_pdrop: float = 0.0,
84
- embd_pdrop: float = 0.0,
85
- resid_pdrop: float = 0.0,
86
- layer_norm_epsilon: float = 1e-5,
87
- initializer_range: float = 0.02,
88
- tie_word_embeddings: bool = False,
89
- pad_vocab_size_multiple: int = 64,
90
- **kwargs
91
- ) -> None:
92
- self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
93
- self.n_positions = n_positions
94
- self.n_embd = n_embd
95
- self.n_layer = n_layer
96
- self.n_inner = n_inner
97
- self.n_head = n_head
98
- self.n_head_kv = n_head_kv
99
- self.rotary_dim = min(rotary_dim, n_embd // n_head)
100
- self.activation_function = activation_function
101
- self.flash_attn = flash_attn
102
- self.flash_rotary = flash_rotary
103
- self.fused_dense = fused_dense
104
- self.attn_pdrop = attn_pdrop
105
- self.embd_pdrop = embd_pdrop
106
  self.resid_pdrop = resid_pdrop
107
- self.layer_norm_epsilon = layer_norm_epsilon
 
 
 
108
  self.initializer_range = initializer_range
109
 
110
- super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
111
112
 
113
 
114
  class SiglipVisionConfig(PretrainedConfig):
 
56
 
57
 
58
  class PhiConfig(PretrainedConfig):
59
+ r"""
60
+ This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate an Phi
61
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
62
+ defaults will yield a similar configuration to that of the Phi
63
+ [microsoft/phi-1](https://huggingface.co/microsoft/phi-1).
64
 
65
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
66
+ documentation from [`PretrainedConfig`] for more information.
67
+
68
+ Args:
69
+ vocab_size (`int`, *optional*, defaults to 51200):
70
+ Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the
71
+ `inputs_ids` passed when calling [`PhiModel`].
72
+ hidden_size (`int`, *optional*, defaults to 2048):
73
+ Dimension of the hidden representations.
74
+ intermediate_size (`int`, *optional*, defaults to 8192):
75
+ Dimension of the MLP representations.
76
+ num_hidden_layers (`int`, *optional*, defaults to 24):
77
+ Number of hidden layers in the Transformer decoder.
78
+ num_attention_heads (`int`, *optional*, defaults to 32):
79
+ Number of attention heads for each attention layer in the Transformer decoder.
80
+ num_key_value_heads (`int`, *optional*):
81
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
82
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
83
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
84
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
85
+ by meanpooling all the original heads within that group. For more details checkout [this
86
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
87
+ `num_attention_heads`.
88
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
89
+ Dropout probability for mlp outputs.
90
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
91
+ The dropout ratio for the embeddings.
92
+ attention_dropout (`float`, *optional*, defaults to 0.0):
93
+ The dropout ratio after computing the attention scores.
94
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
95
+ The non-linear activation function (function or string) in the decoder.
96
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
97
+ The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 supports up to 2048
98
+ tokens.
99
+ initializer_range (`float`, *optional*, defaults to 0.02):
100
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
101
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
102
+ The epsilon used by the rms normalization layers.
103
+ use_cache (`bool`, *optional*, defaults to `True`):
104
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
105
+ relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
106
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
107
+ Whether to tie weight embeddings
108
+ rope_theta (`float`, *optional*, defaults to 10000.0):
109
+ The base period of the RoPE embeddings.
110
+ rope_scaling (`Dict`, *optional*):
111
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
112
+ strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
113
+ is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
114
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
115
+ these scaling strategies behave:
116
+ https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
117
+ is an experimental feature, subject to breaking API changes in future versions.
118
+ partial_rotary_factor (`float`, *optional*, defaults to 0.5):
119
+ Percentage of the query and keys which will have rotary embedding.
120
+ qk_layernorm (`bool`, *optional*, defaults to `False`):
121
+ Whether or not to normalize the Queries and Keys after projecting the hidden states.
122
+ bos_token_id (`int`, *optional*, defaults to 1):
123
+ Denotes beginning of sequences token id.
124
+ eos_token_id (`int`, *optional*, defaults to 2):
125
+ Denotes end of sequences token id.
126
+
127
+ Example:
128
+
129
+ ```python
130
+ >>> from transformers import PhiModel, PhiConfig
131
+
132
+ >>> # Initializing a Phi-1 style configuration
133
+ >>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")
134
+
135
+ >>> # Initializing a model from the configuration
136
+ >>> model = PhiModel(configuration)
137
+
138
+ >>> # Accessing the model configuration
139
+ >>> configuration = model.config
140
+ ```"""
141
+
142
+ model_type = "phi"
143
+ keys_to_ignore_at_inference = ["past_key_values"]
144
 
145
  def __init__(
146
  self,
147
+ vocab_size=51200,
148
+ hidden_size=2048,
149
+ intermediate_size=8192,
150
+ num_hidden_layers=32, #24
151
+ num_attention_heads=32,
152
+ num_key_value_heads=None,
153
+ resid_pdrop=0.0,
154
+ embd_pdrop=0.0,
155
+ attention_dropout=0.0,
156
+ hidden_act="gelu_new",
157
+ max_position_embeddings=2048,
158
+ initializer_range=0.02,
159
+ layer_norm_eps=1e-5,
160
+ use_cache=True,
161
+ tie_word_embeddings=False,
162
+ rope_theta=10000.0,
163
+ rope_scaling=None,
164
+ partial_rotary_factor=0.5,
165
+ qk_layernorm=False,
166
+ bos_token_id=1,
167
+ eos_token_id=2,
168
+ **kwargs,
169
+ ):
170
+ self.vocab_size = vocab_size
171
+ self.hidden_size = hidden_size
172
+ self.intermediate_size = intermediate_size
173
+ self.num_hidden_layers = num_hidden_layers
174
+ self.num_attention_heads = num_attention_heads
175
+
176
+ if num_key_value_heads is None:
177
+ num_key_value_heads = num_attention_heads
178
+
179
+ self.num_key_value_heads = num_key_value_heads
 
 
180
  self.resid_pdrop = resid_pdrop
181
+ self.embd_pdrop = embd_pdrop
182
+ self.attention_dropout = attention_dropout
183
+ self.hidden_act = hidden_act
184
+ self.max_position_embeddings = max_position_embeddings
185
  self.initializer_range = initializer_range
186
+ self.layer_norm_eps = layer_norm_eps
187
+ self.use_cache = use_cache
188
+ self.rope_theta = rope_theta
189
+ self.rope_scaling = rope_scaling
190
+ self.partial_rotary_factor = partial_rotary_factor
191
+ self.qk_layernorm = qk_layernorm
192
+ self._rope_scaling_validation()
193
 
194
+ super().__init__(
195
+ bos_token_id=bos_token_id,
196
+ eos_token_id=eos_token_id,
197
+ tie_word_embeddings=tie_word_embeddings,
198
+ **kwargs,
199
+ )
200
 
201
+ # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
202
+ def _rope_scaling_validation(self):
203
+ """
204
+ Validate the `rope_scaling` configuration.
205
+ """
206
+ if self.rope_scaling is None:
207
+ return
208
+
209
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
210
+ raise ValueError(
211
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
212
+ f"got {self.rope_scaling}"
213
+ )
214
+ rope_scaling_type = self.rope_scaling.get("type", None)
215
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
216
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
217
+ raise ValueError(
218
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
219
+ )
220
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
221
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
222
 
223
 
224
  class SiglipVisionConfig(PretrainedConfig):
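The new `_rope_scaling_validation` hook is the main behavioural addition to `PhiConfig`. A small usage sketch, assuming the class is imported from this repo's updated `configuration_imp.py`:

```python
# Sketch: what the added _rope_scaling_validation accepts and rejects.
from configuration_imp import PhiConfig  # local import from this repo

# Accepted: a two-field dict with "type" in {"linear", "dynamic"} and a float factor > 1.
cfg = PhiConfig(rope_scaling={"type": "linear", "factor": 2.0})

# Each of these raises ValueError inside __init__:
# PhiConfig(rope_scaling={"type": "ntk", "factor": 2.0})     # unknown scaling type
# PhiConfig(rope_scaling={"factor": 2.0})                    # not a two-field dict
# PhiConfig(rope_scaling={"type": "linear", "factor": 1})    # factor must be a float > 1
```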
generation_config.json CHANGED
@@ -1,6 +1,6 @@
  {
- "eos_token_id":50295,
- "pad_token_id":50256,
  "_from_model_config": true,
- "transformers_version": "4.31.0"
+ "eos_token_id": 50295,
+ "pad_token_id": 50256,
+ "transformers_version": "4.37.0"
  }
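The reordered generation_config.json keeps the same generation defaults; only `transformers_version` changes. A quick check sketch, reusing the placeholder `repo_id` and `model` from the README sketch above:

```python
# Sketch: the defaults above are what generate() falls back to.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained(repo_id)  # repo_id: placeholder from the README sketch
assert gen_cfg.eos_token_id == 50295
assert gen_cfg.pad_token_id == 50256
# Per-call overrides still win, e.g.:
# model.generate(**inputs, max_new_tokens=64, eos_token_id=gen_cfg.eos_token_id)
```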
model-00001-of-00007.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:29f0e0859601ec012c410dc6471b8965c5d3d16875f84fa095782487cb83113f
- size 996428776
+ oid sha256:7bd37cefba2183e42125e333a6d76440cb3f1e1eec64529307d4298200eaa7bc
+ size 996420688
model-00002-of-00007.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:beda1bcee40de85570fd28e37eef166276cbe7fbf0835579f8f847e12494968a
- size 996507088
+ oid sha256:957f6ac0a8419184dc3329759317512be05b60fb32b631a7d844b5d73ceb52a4
+ size 1022735040
model-00003-of-00007.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:20f3b15038f8d7191a99835d3e21aa63927a9361ad35e6c778a63b3136830b82
- size 996512312
+ oid sha256:e3e903ca59bb5e4559bab71fbaa6596ac592f88fc2ec83f2a1f53cb40688cf1a
+ size 1022740016
model-00004-of-00007.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:28422cb2f997225cf1a3a1771cb37e3b49e41c0312b989f98096fd9b4a31338a
- size 996512088
+ oid sha256:9f18313320dc95f815c4b26f6cbe2741ccae8e6cc03dbbefc08550038fb9e916
+ size 1022735112
model-00005-of-00007.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:c672e017ca169da27170ab02267b9a9deedb2e3b240fc6cefdc734aa6e262a13
- size 996507152
+ oid sha256:d1416a37b0223dc43dd22b0e24825628eb12e1f17442fa23529d2565e1a503a5
+ size 1022740016
model-00006-of-00007.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:fa13bd6005ccb00d73a495f1c0d06009dbf0891ccaf659cca28af3a59a4e1722
- size 1021447256
+ oid sha256:79fe8384fb45ff1928edba88068009a8fad8b2c0d8d0785e4f4dc4bf09b355df
+ size 1011258320
model-00007-of-00007.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:7bc6ecad02a265ac62054118b43c5485b769f398963da4c9f73d2ece152fb027
- size 370061024
+ oid sha256:0ede8a43731129319ee4952693f82a49855b709b960b1b1c88e95f5cd7308739
+ size 275359120
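Each shard entry above is a Git LFS pointer (an `oid sha256:` plus a `size` in bytes), so a re-downloaded shard can be verified against the new pointers. A sketch for the first shard, using values copied from this diff:

```python
# Sketch: verify a downloaded shard against its updated LFS pointer.
import hashlib
import os

path = "model-00001-of-00007.safetensors"
expected_sha = "7bd37cefba2183e42125e333a6d76440cb3f1e1eec64529307d4298200eaa7bc"
expected_size = 996420688

assert os.path.getsize(path) == expected_size

h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        h.update(chunk)
assert h.hexdigest() == expected_sha
```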
model.safetensors.index.json CHANGED
The diff for this file is too large to render. See raw diff
 
modeling_imp.py CHANGED
@@ -4,7 +4,7 @@
4
  # Some code here is copied from the project Phi-2 (https://huggingface.co/microsoft/phi-2),
5
  # SigLIP@transformers==4.37.0.dev0 (https://huggingface.co/google/siglip-so400m-patch14-384),
6
  # and Llava (https://github.com/haotian-liu/LLaVA), and modified by
7
- # Zhenwei Shao ([email protected]) @ MILVLG. We thank them for their great works.
8
  # And their original licenses and copyright should be inherited (see the statements
9
  # in `configuration_imp.py` for more details).
10
 
@@ -16,13 +16,15 @@ from __future__ import annotations
16
  import os
17
  import math
18
  import re
19
- from dataclasses import dataclass, field
20
  from typing import Any, Dict, Optional, Tuple, Union, List
21
  from abc import ABC, abstractmethod
22
 
23
  import torch
24
- import torch.nn as nn
25
- from einops import rearrange, repeat
 
 
 
26
  from transformers import (
27
  PretrainedConfig,
28
  PreTrainedModel,
@@ -30,854 +32,744 @@ from transformers import (
30
  AutoModelForCausalLM
31
  )
32
  from transformers.activations import ACT2FN
33
- from transformers.modeling_outputs import CausalLMOutputWithPast
34
  import sys
35
  from .configuration_imp import PhiConfig, ImpConfig
36
  from .vision_encoder import VisionTower
37
 
38
  try:
39
- from flash_attn.bert_padding import pad_input, unpad_input
40
- from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
41
- from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
42
- from flash_attn.ops.fused_dense import FusedDense
43
  except:
44
- pad_input, unpad_input = None, None
45
- FlashRotaryEmbedding = None
46
- FlashSelfAttention, FlashCrossAttention = None, None
47
- FusedDense = None
48
 
 
49
 
50
- @dataclass
51
- class InferenceParams:
52
- """Inference parameters passed to model to efficiently calculate
53
- and store context during inference.
54
-
55
- Reference:
56
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
57
-
58
- Args:
59
- max_seqlen: Maximum sequence length.
60
- max_batch_size: Maximum batch size.
61
- seqlen_offset: Sequence length offset.
62
- batch_size_offset: Batch size offset.
63
- key_value_memory_dict: Key value memory dictionary.
64
- lengths_per_sample: Lengths per sample.
65
-
66
- """
67
-
68
- max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
69
-
70
- max_batch_size: int = field(metadata={"help": "Maximum batch size."})
71
-
72
- seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
73
-
74
- batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
75
-
76
- key_value_memory_dict: Dict[str, Any] = field(
77
- default_factory=dict, metadata={"help": "Key value memory dictionary."}
78
- )
79
-
80
- lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
81
-
82
-
83
- class Embedding(nn.Module):
84
- """Token embedding with dropout."""
85
-
86
- def __init__(self, config: PretrainedConfig) -> None:
87
  super().__init__()
88
 
89
- self.wte = nn.Embedding(config.vocab_size, config.n_embd)
90
- self.drop = nn.Dropout(config.embd_pdrop)
91
-
92
- def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
93
- input_shape = input_ids.size()
94
- input_ids = input_ids.view(-1, input_shape[-1])
95
-
96
- hidden_states = self.wte(input_ids)
97
- hidden_states = self.drop(hidden_states)
98
-
99
- return hidden_states
100
 
 
 
 
 
101
 
 
 
 
102
 
103
- def _apply_rotary_emb(
104
- x: torch.FloatTensor,
105
- cos: torch.FloatTensor,
106
- sin: torch.FloatTensor,
107
- ) -> torch.FloatTensor:
108
- _, seqlen, _, _ = x.shape
109
- _, rotary_dim = cos.shape
110
- rotary_dim *= 2
111
 
112
- x_rot = x[:, :, :, :rotary_dim]
113
- x_pass = x[:, :, :, rotary_dim:]
 
 
114
 
115
- x1, x2 = x_rot.chunk(2, dim=-1)
116
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
117
- x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
 
118
 
119
- x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
120
 
121
- return torch.cat([x_rot, x_pass], axis=-1)
 
 
122
 
 
 
 
123
 
124
- def _apply_rotary_emb_kv(
125
- kv: torch.FloatTensor,
126
- cos: torch.FloatTensor,
127
- sin: torch.FloatTensor,
128
- cos_k: Optional[torch.FloatTensor] = None,
129
- sin_k: Optional[torch.FloatTensor] = None,
130
- ) -> torch.FloatTensor:
131
- _, seqlen, _, _, _ = kv.shape
132
- _, rotary_dim = cos.shape
133
- rotary_dim *= 2
134
 
135
- k_rot = kv[:, :, 0, :, :rotary_dim]
136
- k_pass = kv[:, :, 0, :, rotary_dim:]
 
 
 
137
 
138
- k1, k2 = k_rot.chunk(2, dim=-1)
139
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
140
- k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
141
 
142
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
 
 
143
 
144
- return torch.cat(
145
- [
146
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
147
- kv[:, :, 1:2, :, :],
148
- ],
149
- axis=2,
150
- )
151
 
 
 
152
 
153
- def _apply_rotary_emb_qkv(
154
- qkv: torch.FloatTensor,
155
- cos: torch.FloatTensor,
156
- sin: torch.FloatTensor,
157
- cos_k: Optional[torch.FloatTensor] = None,
158
- sin_k: Optional[torch.FloatTensor] = None,
159
- ) -> torch.FloatTensor:
160
- _, seqlen, _, _, _ = qkv.shape
161
- _, rotary_dim = cos.shape
162
- rotary_dim *= 2
163
 
164
- q_rot = qkv[:, :, 0, :, :rotary_dim]
165
- q_pass = qkv[:, :, 0, :, rotary_dim:]
166
 
167
- k_rot = qkv[:, :, 1, :, :rotary_dim]
168
- k_pass = qkv[:, :, 1, :, rotary_dim:]
 
 
 
169
 
170
- q1, q2 = q_rot.chunk(2, dim=-1)
171
- k1, k2 = k_rot.chunk(2, dim=-1)
172
- c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
173
- q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
174
 
175
- q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
176
- k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
 
 
 
 
177
 
178
- return torch.cat(
179
- [
180
- torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
181
- torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
182
- qkv[:, :, 2:3, :, :],
183
- ],
184
- axis=2,
185
- )
186
 
 
 
 
187
 
188
- class RotaryEmbedding(nn.Module):
189
- """Rotary positional embedding (RoPE).
190
 
191
- Reference:
192
- RoFormer: Enhanced Transformer with Rotary Position Embedding.
193
- https://arxiv.org/pdf/2104.09864.pdf.
194
 
195
- """
196
 
197
- def __init__(
198
- self,
199
- dim: int,
200
- base: int = 10000,
201
- scale_base: Optional[float] = None,
202
- pos_idx_in_fp32: bool = True,
203
- max_position_embeddings: int = 2048,
204
- device: Optional[str] = None,
205
- **kwargs,
206
- ) -> None:
207
  super().__init__()
 
 
 
 
208
 
209
- if scale_base is not None:
210
- raise NotImplementedError
 
 
 
211
 
212
- self.dim = dim
213
- self.base = float(base)
214
- self.scale_base = scale_base
215
- self.pos_idx_in_fp32 = pos_idx_in_fp32
216
- self.max_position_embeddings = max_position_embeddings
217
- self.device = device
218
 
219
- # Generate and save the inverse frequency buffer (non-trainable)
220
- inv_freq = self._compute_inv_freq(device)
221
- self.register_buffer("inv_freq", inv_freq, persistent=False)
 
 
 
 
 
 
 
 
222
 
223
- # Generate and save the scale buffer (non-trainable)
224
- scale = (
225
- (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
226
- if scale_base is not None
227
- else None
228
- )
229
- self.register_buffer("scale", scale, persistent=False)
230
 
231
- # Initialize cached attributes since ONNX can't rely on dynamic initialization
232
- self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32)
233
 
234
- def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
235
- return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
236
 
237
- def _update_cos_sin_cache(
238
- self,
239
- seqlen: int,
240
- device: Optional[str] = None,
241
- dtype: Optional[torch.dtype] = None,
242
- ) -> None:
243
- self._seq_len_cached = seqlen
244
-
245
- # fp32 is preferred since the output of `torch.arange` can be quite large
246
- # and bf16 would lose a lot of precision
247
- if self.pos_idx_in_fp32:
248
- t = torch.arange(seqlen, device=device, dtype=torch.float32)
249
- if self.inv_freq.dtype != torch.float32:
250
- inv_freq = self._compute_inv_freq(device=device)
251
- else:
252
- inv_freq = self.inv_freq
253
- else:
254
- t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
255
- inv_freq = self.inv_freq
256
-
257
- # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
258
- freqs = torch.outer(t, inv_freq)
259
- if self.scale is None:
260
- self._cos_cached = torch.cos(freqs).to(dtype)
261
- self._sin_cached = torch.sin(freqs).to(dtype)
262
- else:
263
- power = (
264
- torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
265
- ) / self.scale_base
266
- scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
267
 
268
- # Force the scale multiplication to happen in fp32
269
- self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
270
- self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
271
- self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
272
- self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
273
 
274
- def forward(
275
- self,
276
- qkv: torch.Tensor,
277
- kv: Optional[torch.Tensor] = None,
278
- seqlen_offset: int = 0,
279
- **kwargs,
280
- ) -> Tuple[torch.Tensor, torch.Tensor]:
281
- if (
282
- self._seq_len_cached < qkv.shape[1] + seqlen_offset
283
- or self._cos_cached.device != qkv.device
284
- or self._cos_cached.dtype != qkv.dtype
285
- or (self.training and self._cos_cached.is_inference())
286
- ):
287
- self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
288
-
289
- if kv is None:
290
- return _apply_rotary_emb_qkv(
291
- qkv,
292
- self._cos_cached[seqlen_offset:],
293
- self._sin_cached[seqlen_offset:],
294
- )
295
- else:
296
- q = _apply_rotary_emb(
297
- qkv,
298
- self._cos_cached[seqlen_offset:],
299
- self._sin_cached[seqlen_offset:],
300
  )
301
- kv = _apply_rotary_emb_kv(
302
- kv,
303
- self._cos_cached[seqlen_offset:],
304
- self._sin_cached[seqlen_offset:],
305
  )
306
 
307
- return q, kv
308
-
309
-
310
- class MLP(nn.Module):
311
- """Multi-Layer Perceptron.
312
-
313
- Reference:
314
- Attention Is All You Need.
315
- https://arxiv.org/pdf/1706.03762.pdf.
316
-
317
- """
318
-
319
- def __init__(
320
- self,
321
- config: PretrainedConfig,
322
- n_inner: Optional[int] = None,
323
- act_fn: Optional[str] = None,
324
- ) -> None:
325
- super().__init__()
326
-
327
- act_fn = config.activation_function if act_fn is None else act_fn
328
-
329
- n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
330
- n_inner = n_inner if n_inner is not None else 4 * config.n_embd
331
-
332
- self.fc1 = nn.Linear(config.n_embd, n_inner)
333
- self.fc2 = nn.Linear(n_inner, config.n_embd)
334
- self.act = ACT2FN[act_fn]
335
 
336
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
337
- hidden_states = self.fc1(hidden_states)
338
- hidden_states = self.act(hidden_states)
339
- hidden_states = self.fc2(hidden_states)
340
-
341
- return hidden_states
342
-
343
-
344
- class SelfAttention(nn.Module):
345
- """Self-attention layer (compatible with PyTorch).
346
-
347
- Reference:
348
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
349
-
350
- """
351
-
352
- def __init__(
353
- self,
354
- causal: bool = True,
355
- softmax_scale: Optional[float] = None,
356
- attention_dropout: float = 0.0,
357
- ) -> None:
358
- super().__init__()
359
-
360
- self.causal = causal
361
- self.softmax_scale = softmax_scale
362
- self.drop = nn.Dropout(attention_dropout)
363
 
 
364
  @torch.autocast("cpu", enabled=False)
365
  @torch.autocast("cuda", enabled=False)
366
  def forward(
367
  self,
368
- qkv: torch.FloatTensor,
369
- causal: bool = None,
370
- key_padding_mask: Optional[torch.BoolTensor] = None,
371
- **kwargs,
372
- ) -> torch.FloatTensor:
373
- batch_size, seqlen = qkv.shape[0], qkv.shape[1]
374
- q, k, v = qkv.unbind(dim=2)
375
-
376
- q = q.to(torch.float32)
377
- k = k.to(torch.float32)
378
 
379
- causal = self.causal if causal is None else causal
380
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
381
 
382
- # Autocast is manually disabled to avoid `torch.einsum` performing the operation
383
- # using float16, which might lead to overflow
384
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
385
 
386
- if key_padding_mask is not None:
387
- padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
388
- padding_mask.masked_fill_(key_padding_mask, 0.0)
 
 
 
389
 
390
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
 
 
391
 
392
- if causal:
393
- causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
394
- scores = scores + causal_mask.to(dtype=scores.dtype)
395
 
396
- attention = torch.softmax(scores, dim=-1).to(v.dtype)
397
- attention = self.drop(attention)
 
 
 
398
 
399
- output = torch.einsum("bhts,bshd->bthd", attention, v)
 
400
 
401
- return output
402
 
 
403
 
404
- class CrossAttention(nn.Module):
405
- """Cross-attention layer (compatible with PyTorch).
406
 
407
- Reference:
408
- https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
409
 
 
 
 
 
 
410
  """
411
 
412
- def __init__(
413
- self,
414
- causal: bool = True,
415
- softmax_scale: Optional[float] = None,
416
- attention_dropout: float = 0.0,
417
- ) -> None:
418
- super().__init__()
419
 
420
- self.causal = causal
421
- self.softmax_scale = softmax_scale
422
- self.drop = nn.Dropout(attention_dropout)
 
423
 
424
- @torch.autocast("cpu", enabled=False)
425
- @torch.autocast("cuda", enabled=False)
426
  def forward(
427
  self,
428
- q: torch.FloatTensor,
429
- kv: torch.FloatTensor,
430
- causal: bool = None,
431
- key_padding_mask: Optional[torch.BoolTensor] = None,
 
 
432
  **kwargs,
433
- ) -> torch.FloatTensor:
434
- batch_size, seqlen_q = q.shape[0], q.shape[1]
435
- seqlen_k = kv.shape[1]
436
-
437
- if kv.shape[3] != q.shape[2]:
438
- kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
439
- k, v = kv.unbind(dim=2)
440
-
441
- q = q.to(torch.float32)
442
- k = k.to(torch.float32)
443
-
444
- causal = self.causal if causal is None else causal
445
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
446
-
447
- # Autocast is manually disabled to avoid `torch.einsum` performing the operation
448
- # using float16, which might lead to overflow
449
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
450
-
451
- if key_padding_mask is not None:
452
- padding_mask = torch.full(
453
- (batch_size, seqlen_k),
454
- -10000.0,
455
- dtype=scores.dtype,
456
- device=scores.device,
457
- )
458
- padding_mask.masked_fill_(key_padding_mask, 0.0)
459
-
460
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
461
-
462
- if causal:
463
- rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
464
- cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
465
- causal_mask = cols > rows + seqlen_k - seqlen_q
466
 
467
- scores = scores.masked_fill(causal_mask, -10000.0)
468
 
469
- attention = torch.softmax(scores, dim=-1).to(v.dtype)
470
- attention = self.drop(attention)
471
 
472
- output = torch.einsum("bhts,bshd->bthd", attention, v)
 
 
473
 
474
- return output
 
 
475
 
 
 
 
 
 
 
476
 
477
- def _find_mha_dims(
478
- config: PretrainedConfig,
479
- n_head: Optional[int] = None,
480
- n_head_kv: Optional[int] = None,
481
- head_dim: Optional[int] = None,
482
- ) -> Tuple[int, int]:
483
- if n_head is None and head_dim is None:
484
- head_dim = config.n_embd // config.n_head
485
- n_head = config.n_head
486
- elif n_head is None or head_dim is None:
487
- raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
488
 
489
- if n_head_kv is None:
490
- n_head_kv = getattr(config, "n_head_kv", None) or n_head
491
-
492
- return n_head, n_head_kv, head_dim
493
-
494
-
495
- def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
496
- num_heads, head_dim = kv.shape[-2:]
497
-
498
- if layer_idx not in inference_params.key_value_memory_dict:
499
- inference_params.key_value_memory_dict[layer_idx] = torch.empty(
500
- inference_params.max_batch_size,
501
- inference_params.max_seqlen,
502
- 2,
503
- num_heads,
504
- head_dim,
505
- dtype=kv.dtype,
506
- device=kv.device,
507
  )
508
-
509
- batch_start = inference_params.batch_size_offset
510
- batch_end = batch_start + kv.shape[0]
511
-
512
- sequence_start = inference_params.seqlen_offset
513
- sequence_end = sequence_start + kv.shape[1]
514
-
515
- # When the current sequence length is equal to or larger than the maximum sequence length,
516
- # we need to concatenate the current `kv` with the cached `kv` to expand its length
517
- if sequence_end >= inference_params.max_seqlen:
518
- inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
519
-
520
- inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
521
- kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
522
-
523
- return kv
524
-
525
-
526
- class MHA(nn.Module):
527
- """Multi-head attention layer."""
528
-
529
- def __init__(
530
- self,
531
- config: PretrainedConfig,
532
- dtype: Optional[torch.dtype] = None,
533
- device: Optional[str] = None,
534
- rotary_dim: Optional[int] = None,
535
- rotary_base: float = 10000.0,
536
- rotary_scale_base: Optional[float] = None,
537
- n_head: Optional[int] = None,
538
- n_head_kv: Optional[int] = None,
539
- head_dim: Optional[int] = None,
540
- bias: bool = True,
541
- causal: bool = True,
542
- softmax_scale: Optional[float] = None,
543
- layer_idx: Optional[int] = None,
544
- return_residual: bool = False,
545
- checkpointing: bool = False,
546
- ) -> None:
547
- super().__init__()
548
-
549
- # Rotary embedding
550
- self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
551
- if self.rotary_dim > 0:
552
- rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
553
- if rotary_cls is None:
554
- rotary_cls = RotaryEmbedding
555
-
556
- rotary_kwargs = {}
557
- if rotary_cls is RotaryEmbedding:
558
- rotary_kwargs["max_position_embeddings"] = config.n_positions
559
-
560
- self.rotary_emb = rotary_cls(
561
- self.rotary_dim,
562
- base=rotary_base,
563
- scale_base=rotary_scale_base,
564
- device=device,
565
- **rotary_kwargs,
566
- )
567
-
568
- # MLP
569
- self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
570
- config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
571
  )
572
- op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
573
- hidden_size = config.n_embd
574
-
575
- linear_cls = FusedDense if config.fused_dense else nn.Linear
576
- if linear_cls is None:
577
- linear_cls = nn.Linear
578
-
579
- self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
580
- self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
581
-
582
- # Attention
583
- attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
584
- if attn_cls is None:
585
- attn_cls = SelfAttention
586
-
587
- cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
588
- if cross_attn_cls is None:
589
- cross_attn_cls = CrossAttention
590
-
591
- self.inner_attn = attn_cls(
592
- causal=causal,
593
- softmax_scale=softmax_scale,
594
- attention_dropout=config.attn_pdrop,
595
- )
596
- self.inner_cross_attn = cross_attn_cls(
597
- causal=causal,
598
- softmax_scale=softmax_scale,
599
- attention_dropout=config.attn_pdrop,
600
- )
601
-
602
- self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
603
- self.layer_idx = layer_idx
604
- self.return_residual = return_residual
605
- self.checkpointing = checkpointing
606
-
607
- def _forward_self_attn(
608
- self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
609
- ) -> torch.FloatTensor:
610
- qkv = self.Wqkv(x)
611
- qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
612
-
613
- if self.rotary_dim > 0:
614
- qkv = self.rotary_emb(qkv)
615
-
616
- if self.flash_attn:
617
- batch_size, seqlen = qkv.shape[0], qkv.shape[1]
618
-
619
- cu_seqlens, max_seqlen = None, None
620
- if key_padding_mask is not None:
621
- # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
622
- # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
623
- qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
624
-
625
- if self.checkpointing:
626
- attn_output = torch.utils.checkpoint.checkpoint(
627
- self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
628
- )
629
  else:
630
- attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
631
-
632
- # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
633
- return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
634
-
635
- if self.checkpointing:
636
- return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
637
-
638
- return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
639
 
640
- def _forward_cross_attn(
641
- self,
642
- x: torch.FloatTensor,
643
- past_key_values: Optional[InferenceParams],
644
- key_padding_mask: Optional[torch.BoolTensor],
645
- ) -> torch.FloatTensor:
646
- batch_size = x.shape[0]
647
-
648
- qkv = self.Wqkv(x)
649
 
650
- q = qkv[..., : self.n_head * self.head_dim]
651
- q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
 
652
 
653
- kv = qkv[..., self.n_head * self.head_dim :]
654
- kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
 
655
 
656
- seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
657
- causal = None if seqlen_offset == 0 else False
658
- if self.rotary_dim > 0:
659
- q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
660
 
661
- if past_key_values is not None:
662
- kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
663
 
664
- if self.flash_attn:
665
- batch_size, seqlen_q = q.shape[0], q.shape[1]
666
- seqlen_k = kv.shape[1]
667
 
668
- cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
669
- None,
670
- None,
671
- None,
672
- None,
673
- )
674
- if key_padding_mask is not None:
675
- kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
676
-
677
- if seqlen_q == 1:
678
- key_padding_mask = torch.ones(batch_size, 1, device=q.device)
679
- elif seqlen_q != seqlen_k:
680
- key_padding_mask = key_padding_mask[:, -seqlen_q:]
681
-
682
- q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
683
-
684
- if self.checkpointing:
685
- attn_output = torch.utils.checkpoint.checkpoint(
686
- self.inner_cross_attn,
687
- q,
688
- kv,
689
- causal=causal,
690
- cu_seqlens=cu_seqlens_q,
691
- max_seqlen=max_seqlen_q,
692
- cu_seqlens_k=cu_seqlens_k,
693
- max_seqlen_k=max_seqlen_k,
694
- )
695
- else:
696
- attn_output = self.inner_cross_attn(
697
- q,
698
- kv,
699
- causal=causal,
700
- cu_seqlens=cu_seqlens_q,
701
- max_seqlen=max_seqlen_q,
702
- cu_seqlens_k=cu_seqlens_k,
703
- max_seqlen_k=max_seqlen_k,
704
- )
705
 
706
- return (
707
- pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
708
- if key_padding_mask is not None
709
- else attn_output
 
710
  )
711
 
712
- if self.checkpointing:
713
- return torch.utils.checkpoint.checkpoint(
714
- self.inner_cross_attn,
715
- q,
716
- kv,
717
- key_padding_mask=key_padding_mask,
 
 
 
 
 
 
 
718
  causal=causal,
719
  )
720
 
721
- return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
722
-
723
- def forward(
724
- self,
725
- x: torch.FloatTensor,
726
- past_key_values: Optional[InferenceParams] = None,
727
- attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
728
- **kwargs,
729
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
730
- if attention_mask is not None:
731
- attention_mask = attention_mask.bool()
732
  else:
733
- attention_mask = None
 
 
734
 
735
- # MHA
736
- if self.n_head == self.n_head_kv:
737
- if past_key_values is None:
738
- # If `past_key_values` are not supplied, we run self-attention
739
- attn_output = self._forward_self_attn(x, attention_mask)
740
- else:
741
- # If `past_key_values` are supplied, it means that we might have cached values and
742
- # could take advantage of cross-attention
743
- attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
744
- # MQA / GQA
745
- else:
746
- # Regardless of `past_key_values` being supplied or not, it always use cross-attention
747
- # because `q` and `kv` lengths might be different
748
- attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
749
 
750
- output = rearrange(attn_output, "... h d -> ... (h d)")
751
- output = self.out_proj(output)
 
 
752
 
753
- return output if not self.return_residual else (output, x)
754
 
755
 
756
- class ParallelBlock(nn.Module):
757
- """Parallel block.
758
 
759
- This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
 
 
 
760
 
761
- """
762
 
763
- def __init__(
764
- self,
765
- config: PretrainedConfig,
766
- block_idx: Optional[int] = None,
767
- ) -> None:
768
  super().__init__()
769
-
770
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
 
771
  self.resid_dropout = nn.Dropout(config.resid_pdrop)
772
- self.block_idx = block_idx
773
-
774
- self.mixer = MHA(config, layer_idx=block_idx)
775
- self.mlp = MLP(config)
776
 
777
  def forward(
778
  self,
779
- hidden_states: torch.FloatTensor,
780
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
781
- attention_mask: Optional[torch.BoolTensor] = None,
782
- **kwargs,
783
- ) -> torch.FloatTensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
784
  residual = hidden_states
785
- hidden_states = self.ln(hidden_states)
786
 
787
- attn_outputs = self.mixer(
788
- hidden_states,
789
- past_key_values=past_key_values,
 
 
790
  attention_mask=attention_mask,
 
 
 
 
791
  )
792
- if isinstance(attn_outputs, tuple):
793
- attn_outputs = attn_outputs[0]
794
-
795
  attn_outputs = self.resid_dropout(attn_outputs)
796
- feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
797
 
 
798
  hidden_states = attn_outputs + feed_forward_hidden_states + residual
 
799
 
800
- return hidden_states
801
-
802
-
803
- class CausalLMHead(nn.Module):
804
- """Causal Language Modeling head.
805
-
806
- Reference:
807
- Improving Language Understanding by Generative Pre-Training.
808
- https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
809
-
810
- """
811
-
812
- def __init__(self, config: PretrainedConfig) -> None:
813
- super().__init__()
814
-
815
- self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
816
- self.linear = nn.Linear(config.n_embd, config.vocab_size)
817
 
818
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
819
- hidden_states = self.ln(hidden_states)
820
- logits = self.linear(hidden_states).to(torch.float32)
821
 
822
- return logits
823
 
824
 
825
  class PhiPreTrainedModel(PreTrainedModel):
826
  """Phi pre-trained model."""
827
 
828
  config_class = PhiConfig
829
- base_model_prefix = "transformer"
830
  supports_gradient_checkpointing = True
831
- _no_split_modules = ["ParallelBlock", "CLIPEncoderLayer", "Block"]
 
 
 
832
 
833
  def __init__(self, *inputs, **kwargs) -> None:
834
  super().__init__(*inputs, **kwargs)
835
 
836
- def _init_weights(self, module: nn.Module) -> None:
837
- if isinstance(module, (nn.Linear,)):
838
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
 
839
  if module.bias is not None:
840
  module.bias.data.zero_()
841
  elif isinstance(module, nn.Embedding):
842
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
843
  if module.padding_idx is not None:
844
  module.weight.data[module.padding_idx].zero_()
845
- elif isinstance(module, nn.LayerNorm):
846
- if module.bias is not None:
847
- module.bias.data.zero_()
848
- module.weight.data.fill_(1.0)
849
 
850
  def prepare_inputs_for_generation(
851
  self,
852
  input_ids: torch.LongTensor,
853
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
 
854
  attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
855
  **kwargs,
856
  ) -> Dict[str, Any]:
857
- if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
858
- past_key_values = InferenceParams(
859
- max_seqlen=self.config.n_positions,
860
- max_batch_size=input_ids.shape[0],
861
- seqlen_offset=0,
862
- batch_size_offset=0,
863
- key_value_memory_dict={},
864
- lengths_per_sample=None,
865
- )
866
  else:
867
- # ======================================================================
868
- # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
869
- # inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...]
870
- # past_key_values.seqlen_offset = input_ids.shape[1] - 1
871
- # ======================================================================
872
- # I change the way of updating `past_key_values.seqlen_offset` to make the inference of imp work.
873
- # [Edited by zhenwei - 2024-01-20 21:15]
874
- input_ids = input_ids[:, -1].unsqueeze(-1)
875
-
876
- return {
877
- "input_ids": input_ids,
878
- "past_key_values": past_key_values,
879
- "attention_mask": attention_mask,
880
- }
881
 
882
 
883
  class LlavaMetaModel(ABC):
@@ -922,15 +814,20 @@ class LlavaMetaModel(ABC):
922
  class ImpModel(PhiPreTrainedModel, LlavaMetaModel):
923
  """Imp model. This implementation is modified from the implementation of Phi-2"""
924
 
925
- config_class = ImpConfig
926
- # _keys_to_ignore_on_load_missing = [""]
927
- # _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
928
 
929
  def __init__(self, config: ImpConfig) -> None:
930
  super().__init__(config)
 
 
932
- self.embd = Embedding(config)
933
- self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)])
934
  self.gradient_checkpointing = False
935
 
936
  if hasattr(config, "mm_vision_tower"):
@@ -939,57 +836,139 @@ class ImpModel(PhiPreTrainedModel, LlavaMetaModel):
939
 
940
  self.post_init()
941
 
942
- def embed_tokens(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
943
- return self.embd(input_ids)[0]
944
 
945
  def get_input_embeddings(self) -> nn.Embedding:
946
- return self.embd.wte
 
947
 
948
- def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
949
- self.embd.wte = new_embeddings
950
 
951
  def forward(
952
  self,
953
  input_ids: torch.LongTensor,
954
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
955
  attention_mask: Optional[torch.BoolTensor] = None,
956
- inputs_embeds: Optional[torch.FloatTensor] = None
957
- ) -> torch.FloatTensor:
 
958
 
959
- if inputs_embeds is None:
960
- hidden_states = self.embd(input_ids)
 
 
 
 
 
 
 
961
  else:
962
- hidden_states = inputs_embeds
 
 
963
 
964
- for layer in self.h:
965
- if self.gradient_checkpointing and self.training:
 
966
 
967
- def create_custom_forward(module):
968
- def custom_forward(*inputs):
969
- # None for past_key_value
970
- return module(*inputs)
 
 
 
 
 
971
 
972
- return custom_forward
973
 
974
- hidden_states = torch.utils.checkpoint.checkpoint(
975
- create_custom_forward(layer),
976
  hidden_states,
977
- None,
978
  attention_mask,
 
 
 
979
  )
980
  else:
981
- hidden_states = layer(
982
  hidden_states,
983
- past_key_values=past_key_values,
984
  attention_mask=attention_mask,
 
 
 
 
985
  )
986
 
987
- # I change the way of updating `past_key_values.seqlen_offset` to make the inference of imp work.
988
- # [Edited by zhenwei - 2024-01-20 21:15]
989
- if past_key_values is not None: # FIXME: when multi-batch inference, it is a bug
990
- past_key_values.seqlen_offset += hidden_states.shape[1]
991
-
992
- return hidden_states
993
 
994
 
995
  class LlavaMetaForCausalLM(ABC):
@@ -1016,18 +995,40 @@ class LlavaMetaForCausalLM(ABC):
1016
  self, input_ids, position_ids, attention_mask, past_key_values, labels, images
1017
  ):
1018
  vision_tower = self.get_vision_tower()
1019
- # if vision_tower is None or images is None or past_key_values.seqlen_offset != 0:
1020
  if vision_tower is None or images is None or input_ids.shape[1] == 1:
1021
- if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
1022
- target_shape = past_key_values.seqlen_offset + 1
1023
- # inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...]
1024
- attention_mask = torch.cat((attention_mask, torch.ones(
1025
- (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
1026
- dtype=attention_mask.dtype,
1027
- device=attention_mask.device
1028
- )), dim=1)
1029
- position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
1030
- return input_ids, position_ids, attention_mask, past_key_values, None, labels
1031
 
1032
  if type(images) is list or images.ndim == 5:
1033
  concat_images = torch.cat([image for image in images], dim=0)
@@ -1159,6 +1160,7 @@ class LlavaMetaForCausalLM(ABC):
1159
  position_ids = None
1160
 
1161
  return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
 
1162
 
1163
 
1164
  class ImpForCausalLM(PhiPreTrainedModel, LlavaMetaForCausalLM):
@@ -1171,37 +1173,36 @@ class ImpForCausalLM(PhiPreTrainedModel, LlavaMetaForCausalLM):
1171
  def __init__(self, config: ImpConfig) -> None:
1172
  super().__init__(config)
1173
 
1174
- self.transformer = ImpModel(config)
1175
- self.lm_head = CausalLMHead(config)
 
1176
 
1177
  self.post_init()
1178
  self.init_constants(config)
1179
 
 
 
 
 
 
 
1180
  def get_output_embeddings(self) -> nn.Linear:
1181
- return self.lm_head.linear
1182
 
1183
  def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
1184
- self.lm_head.linear = new_embeddings
1185
 
1186
  def get_model(self):
1187
- return self.transformer
 
 
 
 
 
 
1188
 
1189
  def image_preprocess(self, images):
1190
  return self.get_vision_tower().image_processor(images)['pixel_values']
1191
-
1192
- def backbone_forward(
1193
- self,
1194
- input_ids: torch.LongTensor,
1195
- past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
1196
- attention_mask: Optional[torch.BoolTensor] = None,
1197
- labels: Optional[torch.LongTensor] = None,
1198
- inputs_embeds: Optional[torch.FloatTensor] = None,
1199
- **kwargs,
1200
- ) -> CausalLMOutputWithPast:
1201
- hidden_states = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
1202
- lm_logits = self.lm_head(hidden_states)
1203
-
1204
- return CausalLMOutputWithPast(loss=None, logits=lm_logits, past_key_values=past_key_values)
1205
 
1206
  def forward(
1207
  self,
@@ -1217,6 +1218,12 @@ class ImpForCausalLM(PhiPreTrainedModel, LlavaMetaForCausalLM):
1217
  images: Optional[torch.FloatTensor] = None,
1218
  return_dict: Optional[bool] = None,
1219
  ) -> Union[Tuple, CausalLMOutputWithPast]:
 
 
 
 
 
 
1220
 
1221
  if inputs_embeds is None:
1222
  (
@@ -1235,17 +1242,44 @@ class ImpForCausalLM(PhiPreTrainedModel, LlavaMetaForCausalLM):
1235
  images
1236
  )
1237
 
1238
- return self.backbone_forward(
1239
  input_ids=input_ids,
 
1240
  attention_mask=attention_mask,
1241
- position_ids=position_ids,
1242
- past_key_values=past_key_values,
1243
  inputs_embeds=inputs_embeds,
1244
- labels=labels,
1245
  use_cache=use_cache,
1246
  output_attentions=output_attentions,
1247
  output_hidden_states=output_hidden_states,
1248
  return_dict=return_dict
1249
  )
1250
 
1251
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
 
4
  # Some code here is copied from the project Phi-2 (https://huggingface.co/microsoft/phi-2),
5
  # SigLIP@transformers==4.37.0.dev0 (https://huggingface.co/google/siglip-so400m-patch14-384),
6
  # and Llava (https://github.com/haotian-liu/LLaVA), and modified by
7
+ # Zhenwei Shao ([email protected]) and Xuecheng Ouyang (ouyangxc@hdu.edu.cn) @ MILVLG. We thank them for their great works.
8
  # And their original licenses and copyright should be inherited (see the statements
9
  # in `configuration_imp.py` for more details).
10
 
 
16
  import os
17
  import math
18
  import re
 
19
  from typing import Any, Dict, Optional, Tuple, Union, List
20
  from abc import ABC, abstractmethod
21
 
22
  import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+ import torch.utils.checkpoint
28
  from transformers import (
29
  PretrainedConfig,
30
  PreTrainedModel,
 
32
  AutoModelForCausalLM
33
  )
34
  from transformers.activations import ACT2FN
35
+ from transformers.cache_utils import Cache, DynamicCache
36
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
37
+ from transformers.modeling_outputs import (
38
+ BaseModelOutputWithPast,
39
+ CausalLMOutputWithPast,
40
+ SequenceClassifierOutputWithPast,
41
+ TokenClassifierOutput,
42
+ )
43
+ from transformers.modeling_utils import PreTrainedModel
44
+ from transformers.utils import (
45
+ add_code_sample_docstrings,
46
+ add_start_docstrings,
47
+ add_start_docstrings_to_model_forward,
48
+ is_flash_attn_2_available,
49
+ is_flash_attn_greater_or_equal_2_10,
50
+ logging,
51
+ replace_return_docstrings,
52
+ )
53
  import sys
54
  from .configuration_imp import PhiConfig, ImpConfig
55
  from .vision_encoder import VisionTower
56
 
57
  try:
58
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
59
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
 
 
60
  except:
61
+ pass
 
 
 
62
 
63
+ logger = logging.get_logger(__name__)
64
 
65
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
66
+ class PhiRotaryEmbedding(nn.Module):
67
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
68
  super().__init__()
69
 
70
+ self.dim = dim
71
+ self.max_position_embeddings = max_position_embeddings
72
+ self.base = base
73
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
74
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
 
 
 
 
 
 
75
 
76
+ # Build here to make `torch.jit.trace` work.
77
+ self._set_cos_sin_cache(
78
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
79
+ )
80
 
81
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
82
+ self.max_seq_len_cached = seq_len
83
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
84
 
85
+ freqs = torch.outer(t, self.inv_freq)
86
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
87
+ emb = torch.cat((freqs, freqs), dim=-1)
88
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
89
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
 
90
 
91
+ def forward(self, x, seq_len=None):
92
+ # x: [bs, num_attention_heads, seq_len, head_size]
93
+ if seq_len > self.max_seq_len_cached:
94
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
95
 
96
+ return (
97
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
98
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
99
+ )
100
 
 
101
 
102
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
103
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
104
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
105
 
106
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
107
+ self.scaling_factor = scaling_factor
108
+ super().__init__(dim, max_position_embeddings, base, device)
109
 
110
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
111
+ self.max_seq_len_cached = seq_len
112
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
113
+ t = t / self.scaling_factor
 
 
 
 
 
 
114
 
115
+ freqs = torch.outer(t, self.inv_freq)
116
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
117
+ emb = torch.cat((freqs, freqs), dim=-1)
118
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
119
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
120
 
 
 
 
121
 
122
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
123
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
124
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
125
 
126
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
127
+ self.scaling_factor = scaling_factor
128
+ super().__init__(dim, max_position_embeddings, base, device)
 
 
 
 
129
 
130
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
131
+ self.max_seq_len_cached = seq_len
132
 
133
+ if seq_len > self.max_position_embeddings:
134
+ base = self.base * (
135
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
136
+ ) ** (self.dim / (self.dim - 2))
137
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
138
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
 
 
 
 
139
 
140
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
141
 
142
+ freqs = torch.outer(t, self.inv_freq)
143
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
144
+ emb = torch.cat((freqs, freqs), dim=-1)
145
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
146
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
147
 
 
 
 
 
148
 
149
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
150
+ def rotate_half(x):
151
+ """Rotates half the hidden dims of the input."""
152
+ x1 = x[..., : x.shape[-1] // 2]
153
+ x2 = x[..., x.shape[-1] // 2 :]
154
+ return torch.cat((-x2, x1), dim=-1)
155
 
 
 
 
 
 
 
 
 
156
 
157
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
158
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
159
+ """Applies Rotary Position Embedding to the query and key tensors.
160
 
161
+ Args:
162
+ q (`torch.Tensor`): The query tensor.
163
+ k (`torch.Tensor`): The key tensor.
164
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
165
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
166
+ position_ids (`torch.Tensor`):
167
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
168
+ used to pass offsetted position ids when working with a KV-cache.
169
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
170
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
171
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
172
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
173
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
174
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
175
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
176
+ Returns:
177
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
178
+ """
179
+ temp_type = q.dtype  # ouyang modified
180
+ q, k, cos, sin = [t.to(dtype=torch.float32) for t in [q, k, cos, sin]] #ouyang modified
181
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
182
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
183
+ q_embed = (q * cos) + (rotate_half(q) * sin)
184
+ k_embed = (k * cos) + (rotate_half(k) * sin)
185
+ q_embed, k_embed = q_embed.to(temp_type), k_embed.to(temp_type)  # ouyang modified
186
+ return q_embed, k_embed
187
 
 
 
 
188
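The partial rotary embedding above only rotates the first `partial_rotary_factor` fraction of each head and passes the remaining channels through untouched. A minimal sketch of that cos/sin application, using illustrative sizes (head_dim=80 and rotary_dim=32 mirror this config's hidden_size/num_heads and partial_rotary_factor; the tensors are random):

```python
import torch

def rotate_half(x):
    # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1), as in the function above.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Illustrative sizes only: 1 sequence, 2 heads, 8 positions, head_dim=80, rotary_dim=32 (0.4 * 80).
bsz, n_heads, seq_len, head_dim, rotary_dim = 1, 2, 8, 80, 32
q = torch.randn(bsz, n_heads, seq_len, head_dim)

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
t = torch.arange(seq_len).float()
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)                   # [seq_len, rotary_dim]
cos, sin = emb.cos()[None, None], emb.sin()[None, None]   # broadcast over batch and heads

# Only the first rotary_dim channels are rotated; the rest pass through unchanged.
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
q = torch.cat((q_rot, q_pass), dim=-1)
print(q.shape)  # torch.Size([1, 2, 8, 80])
```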
 
 
189
 
190
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
191
+ class PhiMLP(nn.Module):
192
+ def __init__(self, config):
 
 
 
 
 
 
 
193
  super().__init__()
194
+ self.config = config
195
+ self.activation_fn = ACT2FN[config.hidden_act]
196
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
197
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
198
 
199
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
200
+ hidden_states = self.fc1(hidden_states)
201
+ hidden_states = self.activation_fn(hidden_states)
202
+ hidden_states = self.fc2(hidden_states)
203
+ return hidden_states
204
 
 
 
 
 
 
 
205
 
206
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
207
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
208
+ """
209
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
210
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
211
+ """
212
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
213
+ if n_rep == 1:
214
+ return hidden_states
215
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
216
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
217
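`repeat_kv` expands grouped key/value heads so every query head has a matching KV head; in this config `num_key_value_heads == num_attention_heads` (32), so `n_rep` is 1 and the call is a no-op. The shape check below uses a smaller, hypothetical GQA layout purely for illustration:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as above: expand each KV head n_rep times along the head axis.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 16, 64)       # hypothetical: 4 KV heads
print(repeat_kv(kv, 8).shape)        # torch.Size([2, 32, 16, 64]) -> serves 32 query heads
```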
218
 
 
 
219
 
220
+ class PhiAttention(nn.Module):
221
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
222
 
223
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
224
+ super().__init__()
225
+ self.config = config
226
+ self.layer_idx = layer_idx
227
+ # if layer_idx is None:
228
+ # logger.warning_once(
229
+ # f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
230
+ # "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
231
+ # "when creating this class."
232
+ # )
233
+
234
+ self.attention_dropout = config.attention_dropout
235
+ self.hidden_size = config.hidden_size
236
+ self.num_heads = config.num_attention_heads
237
+ self.head_dim = self.hidden_size // self.num_heads
238
+ self.num_key_value_heads = config.num_key_value_heads
239
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
240
+ self.max_position_embeddings = config.max_position_embeddings
241
+ self.rope_theta = config.rope_theta
242
+ self.partial_rotary_factor = config.partial_rotary_factor
243
+ self.is_causal = True
244
+
245
+ if (self.head_dim * self.num_heads) != self.hidden_size:
246
+ raise ValueError(
247
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
248
+ f" and `num_heads`: {self.num_heads})."
249
+ )
 
 
 
250
 
251
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
252
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
253
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
254
+ self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
 
255
 
256
+ self.qk_layernorm = config.qk_layernorm
257
+ if self.qk_layernorm:
258
+ self.q_layernorm = nn.LayerNorm(
259
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
260
  )
261
+ self.k_layernorm = nn.LayerNorm(
262
+ config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
 
 
263
  )
264
 
265
+ self._init_rope()
266
 
267
+ def _init_rope(self):
268
+ if self.config.rope_scaling is None:
269
+ self.rotary_emb = PhiRotaryEmbedding(
270
+ int(self.partial_rotary_factor * self.head_dim),
271
+ max_position_embeddings=self.max_position_embeddings,
272
+ base=self.rope_theta,
273
+ )
274
+ else:
275
+ scaling_type = self.config.rope_scaling["type"]
276
+ scaling_factor = self.config.rope_scaling["factor"]
277
+ if scaling_type == "linear":
278
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
279
+ int(self.partial_rotary_factor * self.head_dim),
280
+ max_position_embeddings=self.max_position_embeddings,
281
+ scaling_factor=scaling_factor,
282
+ base=self.rope_theta,
283
+ )
284
+ elif scaling_type == "dynamic":
285
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
286
+ int(self.partial_rotary_factor * self.head_dim),
287
+ max_position_embeddings=self.max_position_embeddings,
288
+ scaling_factor=scaling_factor,
289
+ base=self.rope_theta,
290
+ )
291
+ else:
292
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
293
 
294
+ # Phi-2 has an attention overflow issue (with FP16) and requires autocast to be disabled
295
  @torch.autocast("cpu", enabled=False)
296
  @torch.autocast("cuda", enabled=False)
297
  def forward(
298
  self,
299
+ hidden_states: torch.Tensor,
300
+ attention_mask: Optional[torch.Tensor] = None,
301
+ position_ids: Optional[torch.LongTensor] = None,
302
+ past_key_value: Optional[Cache] = None,
303
+ output_attentions: bool = False,
304
+ use_cache: bool = False,
305
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
306
+ bsz, q_len, _ = hidden_states.size()
307
+
 
308
 
309
+ query_states = self.q_proj(hidden_states)
310
+ key_states = self.k_proj(hidden_states)
311
+ value_states = self.v_proj(hidden_states)
312
+
313
+ if self.qk_layernorm:
314
+ query_states = self.q_layernorm(query_states)
315
+ key_states = self.k_layernorm(key_states)
316
+
317
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
318
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
319
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
320
+
321
+ kv_seq_len = key_states.shape[-2]
322
+ if past_key_value is not None:
323
+ if self.layer_idx is None:
324
+ raise ValueError(
325
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
326
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
327
+ "with a layer index."
328
+ )
329
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
330
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
331
 
332
+ # Partial rotary embedding
333
+ query_rot, query_pass = (
334
+ query_states[..., : self.rotary_emb.dim],
335
+ query_states[..., self.rotary_emb.dim :],
336
+ )
337
+ key_rot, key_pass = (
338
+ key_states[..., : self.rotary_emb.dim],
339
+ key_states[..., self.rotary_emb.dim :],
340
+ )
341
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
342
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
343
+
344
+ # [batch_size, seq_length, num_heads, head_dim]
345
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
346
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
347
+
348
+ if past_key_value is not None:
349
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
350
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
351
+
352
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
353
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
354
+
355
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
356
+ # attn_weights = torch.matmul(
357
+ # query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
358
+ # ) / math.sqrt(self.head_dim)
359
+
360
+ softmax_scale = 1.0 / math.sqrt(query_states.shape[-1])
361
+ attn_weights = torch.matmul(
362
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)*softmax_scale
363
+ )#ouyang modified
364
+
365
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
366
+ raise ValueError(
367
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
368
+ f" {attn_weights.size()}"
369
+ )
370
 
371
+ if attention_mask is not None:
372
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
373
+ raise ValueError(
374
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
375
+ )
376
+ attn_weights = attn_weights + attention_mask
377
 
378
+ # upcast attention to fp32
379
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
380
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
381
 
382
+ attn_output = torch.matmul(attn_weights, value_states)
 
 
383
 
384
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
385
+ raise ValueError(
386
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
387
+ f" {attn_output.size()}"
388
+ )
389
 
390
+ attn_output = attn_output.transpose(1, 2).contiguous()
391
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
392
 
 
393
 
394
+ attn_output = self.dense(attn_output)
395
 
396
+ if not output_attentions:
397
+ attn_weights = None
398
 
399
+ return attn_output, attn_weights, past_key_value
 
400
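The eager path above upcasts the QK^T product to float32 (and disables autocast) because Phi-2 attention logits can overflow in fp16. A minimal sketch of that pattern, with made-up shapes and no KV cache:

```python
import math
import torch

def stable_attn(q, k, v, mask=None):
    # Compute QK^T and the softmax in float32 to avoid fp16 overflow,
    # then cast the probabilities back to the value dtype before the second matmul.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * scale
    if mask is not None:
        scores = scores + mask
    probs = torch.softmax(scores, dim=-1).to(v.dtype)
    return torch.matmul(probs, v)

# Toy sizes; in the model these tensors would typically arrive in fp16/bf16 on GPU.
q = k = v = torch.randn(1, 32, 8, 80)
print(stable_attn(q, k, v).shape)  # torch.Size([1, 32, 8, 80])
```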
 
401
+ class PhiFlashAttention2(PhiAttention):
402
+ """
403
+ Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays
404
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
405
+ flash attention and deal with padding tokens in case the input contains any of them.
406
  """
407
 
408
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
409
+ def __init__(self, *args, **kwargs):
410
+ super().__init__(*args, **kwargs)
 
 
 
 
411
 
412
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
413
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
414
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
415
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
416
 
 
 
417
  def forward(
418
  self,
419
+ hidden_states: torch.Tensor,
420
+ attention_mask: Optional[torch.LongTensor] = None,
421
+ position_ids: Optional[torch.LongTensor] = None,
422
+ past_key_value: Optional[Cache] = None,
423
+ output_attentions: bool = False,
424
+ use_cache: bool = False,
425
  **kwargs,
426
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
427
+ # PhiFlashAttention2 attention does not support output_attentions
428
 
429
+ output_attentions = False
430
 
431
+ bsz, q_len, _ = hidden_states.size()
 
432
 
433
+ query_states = self.q_proj(hidden_states)
434
+ key_states = self.k_proj(hidden_states)
435
+ value_states = self.v_proj(hidden_states)
436
 
437
+ if self.qk_layernorm:
438
+ query_states = self.q_layernorm(query_states)
439
+ key_states = self.k_layernorm(key_states)
440
 
441
+ # Flash attention requires the input to have the shape
442
+ # batch_size x seq_length x num_heads x head_dim
443
+ # therefore we just need to keep the original shape
444
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
445
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
446
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
447
 
448
+ kv_seq_len = key_states.shape[-2]
449
+ if past_key_value is not None:
450
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
451
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
 
 
 
 
 
 
452
 
453
+ # Partial rotary embedding
454
+ query_rot, query_pass = (
455
+ query_states[..., : self.rotary_emb.dim],
456
+ query_states[..., self.rotary_emb.dim :],
457
  )
458
+ key_rot, key_pass = (
459
+ key_states[..., : self.rotary_emb.dim],
460
+ key_states[..., self.rotary_emb.dim :],
461
  )
462
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
463
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
464
+
465
+ # [batch_size, seq_length, num_heads, head_dim]
466
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
467
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
468
+
469
+ if past_key_value is not None:
470
+ cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
471
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
472
+
473
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
474
+ # to be able to avoid many of these transpose/reshape/view.
475
+ query_states = query_states.transpose(1, 2)
476
+ key_states = key_states.transpose(1, 2)
477
+ value_states = value_states.transpose(1, 2)
478
+
479
+ attn_dropout = self.attention_dropout if self.training else 0.0
480
+
481
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
482
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
484
+ # cast them back to the correct dtype just to be sure everything works as expected.
485
+ # This might slow down training & inference so it is recommended not to cast the LayerNorms
485
+ # in fp32.
486
+
487
+ if query_states.dtype == torch.float32:
488
+ if torch.is_autocast_enabled():
489
+ target_dtype = torch.get_autocast_gpu_dtype()
490
+ # Handle the case where the model is quantized
491
+ elif hasattr(self.config, "_pre_quantization_dtype"):
492
+ target_dtype = self.config._pre_quantization_dtype
493
  else:
494
+ target_dtype = self.q_proj.weight.dtype
495
 
496
+ logger.warning_once(
497
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
498
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
499
+ f" {target_dtype}."
500
+ )
 
 
 
 
501
 
502
+ query_states = query_states.to(target_dtype)
503
+ key_states = key_states.to(target_dtype)
504
+ value_states = value_states.to(target_dtype)
505
 
506
+ attn_output = self._flash_attention_forward(
507
+ query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
508
+ )
509
 
510
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
511
+ attn_output = self.dense(attn_output)
 
 
512
 
513
+ if not output_attentions:
514
+ attn_weights = None
515
 
516
+ return attn_output, attn_weights, past_key_value
 
 
517
 
518
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
519
+ def _flash_attention_forward(
520
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
521
+ ):
522
+ """
523
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
524
+ first unpad the input, then compute the attention scores and pad the final attention scores.
525
+
526
+ Args:
527
+ query_states (`torch.Tensor`):
528
+ Input query states to be passed to Flash Attention API
529
+ key_states (`torch.Tensor`):
530
+ Input key states to be passed to Flash Attention API
531
+ value_states (`torch.Tensor`):
532
+ Input value states to be passed to Flash Attention API
533
+ attention_mask (`torch.Tensor`):
534
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
535
+ position of padding tokens and 1 for the position of non-padding tokens.
536
+ dropout (`float`, *optional*):
537
+ Attention dropout
538
+ softmax_scale (`float`, *optional*):
539
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
540
+ """
541
+ if not self._flash_attn_uses_top_left_mask:
542
+ causal = self.is_causal
543
+ else:
544
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
545
+ causal = self.is_causal and query_length != 1
546
 
547
+ # Contains at least one padding token in the sequence
548
+ if attention_mask is not None:
549
+ batch_size = query_states.shape[0]
550
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
551
+ query_states, key_states, value_states, attention_mask, query_length
552
  )
553
 
554
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
555
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
556
+
557
+ attn_output_unpad = flash_attn_varlen_func(
558
+ query_states,
559
+ key_states,
560
+ value_states,
561
+ cu_seqlens_q=cu_seqlens_q,
562
+ cu_seqlens_k=cu_seqlens_k,
563
+ max_seqlen_q=max_seqlen_in_batch_q,
564
+ max_seqlen_k=max_seqlen_in_batch_k,
565
+ dropout_p=dropout,
566
+ softmax_scale=softmax_scale,
567
  causal=causal,
568
  )
569
 
570
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
571
  else:
572
+ attn_output = flash_attn_func(
573
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
574
+ )
575
 
576
+ return attn_output
577
 
578
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
579
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
580
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
581
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
582
 
583
+ key_layer = index_first_axis(
584
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
585
+ )
586
+ value_layer = index_first_axis(
587
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
588
+ )
589
+ if query_length == kv_seq_len:
590
+ query_layer = index_first_axis(
591
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
592
+ )
593
+ cu_seqlens_q = cu_seqlens_k
594
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
595
+ indices_q = indices_k
596
+ elif query_length == 1:
597
+ max_seqlen_in_batch_q = 1
598
+ cu_seqlens_q = torch.arange(
599
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
600
+ ) # There is a memcpy here, that is very bad.
601
+ indices_q = cu_seqlens_q[:-1]
602
+ query_layer = query_layer.squeeze(1)
603
+ else:
604
+ # The -q_len: slice assumes left padding.
605
+ attention_mask = attention_mask[:, -query_length:]
606
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
607
+
608
+ return (
609
+ query_layer,
610
+ key_layer,
611
+ value_layer,
612
+ indices_q,
613
+ (cu_seqlens_q, cu_seqlens_k),
614
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
615
+ )
616
 
617
 
 
 
618
 
619
+ PHI_ATTENTION_CLASSES = {
620
+ "eager": PhiAttention,
621
+ "flash_attention_2": PhiFlashAttention2,
622
+ }
623
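Because the mapping above is keyed by `config._attn_implementation`, the attention backend is selected at load time. A hedged usage sketch, assuming flash-attn 2 and a CUDA device are available; `<repo-id>` is a placeholder for this model's Hub id:

```python
import torch
from transformers import AutoModelForCausalLM

# trust_remote_code is needed because the modeling code in this diff lives in the repo,
# not in the transformers library itself.
model = AutoModelForCausalLM.from_pretrained(
    "<repo-id>",                               # placeholder, not a real Hub id
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",   # omit this to fall back to the eager PhiAttention path
    trust_remote_code=True,
)
```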
 
 
624
 
625
+ class PhiDecoderLayer(nn.Module):
626
+ def __init__(self, config: PhiConfig, layer_idx: int):
 
 
 
627
  super().__init__()
628
+ self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
629
+ self.mlp = PhiMLP(config)
630
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
631
  self.resid_dropout = nn.Dropout(config.resid_pdrop)
 
 
 
 
632
 
633
  def forward(
634
  self,
635
+ hidden_states: torch.Tensor,
636
+ attention_mask: Optional[torch.Tensor] = None,
637
+ position_ids: Optional[torch.LongTensor] = None,
638
+ output_attentions: Optional[bool] = False,
639
+ use_cache: Optional[bool] = False,
640
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
641
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
642
+ """
643
+ Args:
644
+ hidden_states (`torch.FloatTensor`):
645
+ input to the layer of shape `(batch, seq_len, embed_dim)`
646
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
647
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
648
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
649
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
650
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
651
+ output_attentions (`bool`, *optional*):
652
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
653
+ returned tensors for more detail.
654
+ use_cache (`bool`, *optional*):
655
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
656
+ (see `past_key_values`).
657
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
658
+ """
659
+
660
  residual = hidden_states
 
661
 
662
+ hidden_states = self.input_layernorm(hidden_states)
663
+
664
+ # Self Attention
665
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
666
+ hidden_states=hidden_states,
667
  attention_mask=attention_mask,
668
+ position_ids=position_ids,
669
+ past_key_value=past_key_value,
670
+ output_attentions=output_attentions,
671
+ use_cache=use_cache,
672
  )
 
 
 
673
  attn_outputs = self.resid_dropout(attn_outputs)
 
674
 
675
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
676
  hidden_states = attn_outputs + feed_forward_hidden_states + residual
677
+ outputs = (hidden_states,)
678
 
679
+ if output_attentions:
680
+ outputs += (self_attn_weights,)
681
 
682
+ if use_cache:
683
+ outputs += (present_key_value,)
 
684
 
685
+ return outputs
686
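`PhiDecoderLayer.forward` above uses a parallel residual: attention and MLP both read the same pre-normed input, and their outputs are summed with the residual rather than applied sequentially. A toy sketch of that dataflow (the modules here are illustrative stand-ins, not the real layers):

```python
import torch
import torch.nn as nn

class ToyParallelBlock(nn.Module):
    def __init__(self, d=16):
        super().__init__()
        self.ln = nn.LayerNorm(d)
        self.attn = nn.Linear(d, d)   # stand-in for self-attention
        self.mlp = nn.Linear(d, d)    # stand-in for PhiMLP

    def forward(self, x):
        h = self.ln(x)                          # single pre-norm shared by both branches
        return self.attn(h) + self.mlp(h) + x   # attn + mlp + residual, as in PhiDecoderLayer

x = torch.randn(2, 5, 16)
print(ToyParallelBlock()(x).shape)  # torch.Size([2, 5, 16])
```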
 
687
 
688
  class PhiPreTrainedModel(PreTrainedModel):
689
  """Phi pre-trained model."""
690
 
691
  config_class = PhiConfig
692
+ base_model_prefix = "model"
693
  supports_gradient_checkpointing = True
694
+ _no_split_modules = ["PhiDecoderLayer"]
695
+ _skip_keys_device_placement = "past_key_values"
696
+ _supports_flash_attn_2 = True
697
+ _supports_cache_class = True
698
 
699
  def __init__(self, *inputs, **kwargs) -> None:
700
  super().__init__(*inputs, **kwargs)
701
 
702
+ def _init_weights(self, module):
703
+ std = self.config.initializer_range
704
+ if isinstance(module, nn.Linear):
705
+ module.weight.data.normal_(mean=0.0, std=std)
706
  if module.bias is not None:
707
  module.bias.data.zero_()
708
  elif isinstance(module, nn.Embedding):
709
+ module.weight.data.normal_(mean=0.0, std=std)
710
  if module.padding_idx is not None:
711
  module.weight.data[module.padding_idx].zero_()
 
 
 
 
712
 
713
  def prepare_inputs_for_generation(
714
  self,
715
  input_ids: torch.LongTensor,
716
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
717
+ inputs_embeds: Optional[torch.FloatTensor] = None,
718
  attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
719
  **kwargs,
720
  ) -> Dict[str, Any]:
721
+ if past_key_values is not None:
722
+ if isinstance(past_key_values, Cache):
723
+ cache_length = past_key_values.get_seq_length()
724
+ past_length = past_key_values.seen_tokens
725
+ max_cache_length = past_key_values.get_max_length()
726
+ else:
727
+ cache_length = past_length = past_key_values[0][0].shape[2]
728
+ max_cache_length = None
729
+
730
+ # Keep only the unprocessed tokens:
731
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
732
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
733
+ # input)
734
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
735
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
736
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
737
+ # input_ids based on the past_length.
738
+ elif past_length < input_ids.shape[1]:
739
+ input_ids = input_ids[:, past_length:]
740
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
741
+
742
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
743
+ if (
744
+ max_cache_length is not None
745
+ and attention_mask is not None
746
+ and cache_length + input_ids.shape[1] > max_cache_length
747
+ ):
748
+ attention_mask = attention_mask[:, -max_cache_length:]
749
+
750
+ position_ids = kwargs.get("position_ids", None)
751
+ if attention_mask is not None and position_ids is None:
752
+ # create position_ids on the fly for batch generation
753
+ position_ids = attention_mask.long().cumsum(-1) - 1
754
+ position_ids.masked_fill_(attention_mask == 0, 1)
755
+ if past_key_values:
756
+ position_ids = position_ids[:, -input_ids.shape[1] :]
757
+
758
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
759
+ if inputs_embeds is not None and past_key_values is None:
760
+ model_inputs = {"inputs_embeds": inputs_embeds}
761
  else:
762
+ model_inputs = {"input_ids": input_ids}
763
+
764
+ model_inputs.update(
765
+ {
766
+ "position_ids": position_ids,
767
+ "past_key_values": past_key_values,
768
+ "use_cache": kwargs.get("use_cache"),
769
+ "attention_mask": attention_mask,
770
+ }
771
+ )
772
+ return model_inputs
 
 
 
773
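The `position_ids` bookkeeping in `prepare_inputs_for_generation` above (cumsum of the attention mask minus one, with padded slots overwritten) is easier to see on a tiny worked example; the mask values here are made up:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])       # one left-padded row, illustrative only
position_ids = attention_mask.long().cumsum(-1) - 1     # [[-1, -1, 0, 1, 2]]
position_ids.masked_fill_(attention_mask == 0, 1)       # padded slots get a harmless dummy id
print(position_ids)                                     # tensor([[1, 1, 0, 1, 2]])
```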
 
774
 
775
  class LlavaMetaModel(ABC):
 
814
  class ImpModel(PhiPreTrainedModel, LlavaMetaModel):
815
  """Imp model. This implementation is modified from the implementation of Phi-2"""
816
 
 
 
 
817
 
818
  def __init__(self, config: ImpConfig) -> None:
819
  super().__init__(config)
820
+ self.padding_idx = config.pad_token_id
821
+ self.vocab_size = config.vocab_size
822
+
823
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
824
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
825
+ self.layers = nn.ModuleList(
826
+ [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
827
+ )
828
+ self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
829
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
830
 
 
 
831
  self.gradient_checkpointing = False
832
 
833
  if hasattr(config, "mm_vision_tower"):
 
836
 
837
  self.post_init()
838
 
839
+ # def embed_tokens(self, input_ids: torch.LongTensor) -> torch.FloatTensor: #old
840
+ # return self.embd(input_ids)[0]
841
 
842
  def get_input_embeddings(self) -> nn.Embedding:
843
+ # return self.embd.wte#old
844
+ return self.embed_tokens
845
 
846
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
847
+ self.embed_tokens = value
848
 
849
  def forward(
850
  self,
851
  input_ids: torch.LongTensor,
 
852
  attention_mask: Optional[torch.BoolTensor] = None,
853
+ position_ids: Optional[torch.LongTensor] = None,
854
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
855
+ inputs_embeds: Optional[torch.FloatTensor] = None,
856
+ use_cache: Optional[bool] = None,
857
+ output_attentions: Optional[bool] = None,
858
+ output_hidden_states: Optional[bool] = None,
859
+ return_dict: Optional[bool] = None,
860
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
861
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
862
+ output_hidden_states = (
863
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
864
+ )
865
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
866
 
867
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
868
+
869
+ # retrieve input_ids and inputs_embeds
870
+ if input_ids is not None and inputs_embeds is not None:
871
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
872
+ elif input_ids is not None:
873
+ batch_size, seq_length = input_ids.shape
874
+ elif inputs_embeds is not None:
875
+ batch_size, seq_length, _ = inputs_embeds.shape
876
  else:
877
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
878
+
879
+ past_key_values_length = 0
880
 
881
+ if self.gradient_checkpointing and self.training:
882
+ if use_cache:
883
+ logger.warning_once(
884
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
885
+ )
886
+ use_cache = False
887
+ if use_cache:
888
+ use_legacy_cache = not isinstance(past_key_values, Cache)
889
+ if use_legacy_cache:
890
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
891
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
892
+
893
+
894
 
895
+ if position_ids is None:
896
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
897
+ position_ids = torch.arange(
898
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
899
+ )
900
+ position_ids = position_ids.unsqueeze(0)
901
+
902
+ if inputs_embeds is None:
903
+ inputs_embeds = self.embed_tokens(input_ids)
904
 
905
+ inputs_embeds = self.embed_dropout(inputs_embeds)
906
 
907
+ if self._use_flash_attention_2:
908
+ # 2d mask is passed through the layers
909
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
910
+ else:
911
+ # 4d mask is passed through the layers
912
+ attention_mask = _prepare_4d_causal_attention_mask(
913
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
914
+ )
915
+ hidden_states = inputs_embeds
916
+ # ok
917
+
918
+ # decoder layers
919
+ all_hidden_states = () if output_hidden_states else None
920
+ all_self_attns = () if output_attentions else None
921
+ next_decoder_cache = None
922
+
923
+
924
+ for nums, decoder_layer in enumerate(self.layers):
925
+ if output_hidden_states:
926
+ all_hidden_states += (hidden_states,)
927
+
928
+ if self.gradient_checkpointing and self.training:
929
+ layer_outputs = self._gradient_checkpointing_func(
930
+ decoder_layer.__call__,
931
  hidden_states,
 
932
  attention_mask,
933
+ position_ids,
934
+ past_key_values,
935
+ output_attentions,
936
  )
937
  else:
938
+ layer_outputs = decoder_layer(
939
  hidden_states,
 
940
  attention_mask=attention_mask,
941
+ position_ids=position_ids,
942
+ past_key_value=past_key_values,
943
+ output_attentions=output_attentions,
944
+ use_cache=use_cache,
945
  )
946
+ hidden_states = layer_outputs[0]
947
+
948
+ if use_cache:
949
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
950
+ if output_attentions:
951
+ all_self_attns += (layer_outputs[1],)
952
+
953
+
954
+ hidden_states = self.final_layernorm(hidden_states) #final_new_phi
955
+
956
+ # add hidden states from the last decoder layer
957
+ if output_hidden_states:
958
+ all_hidden_states += (hidden_states,)
959
+
960
+ next_cache = None
961
+ if use_cache:
962
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
963
+ if not return_dict:
964
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
965
+ return BaseModelOutputWithPast(
966
+ last_hidden_state=hidden_states,
967
+ past_key_values=next_cache,
968
+ hidden_states=all_hidden_states,
969
+ attentions=all_self_attns,
970
+ )
971
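`ImpModel.forward` above accepts either the legacy tuple-of-tuples `past_key_values` or the `Cache` objects introduced in transformers v4.36, converting on the way in and back out when a legacy cache was passed. A sketch of that round trip with toy shapes:

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each [bsz, num_heads, seq_len, head_dim].
legacy = tuple(
    (torch.randn(1, 32, 4, 80), torch.randn(1, 32, 4, 80)) for _ in range(2)  # 2 layers, toy sizes
)

cache = DynamicCache.from_legacy_cache(legacy)   # what the forward does on entry
print(cache.get_seq_length())                    # 4
round_trip = cache.to_legacy_cache()             # what it returns when a legacy cache came in
print(len(round_trip), round_trip[0][0].shape)   # 2 torch.Size([1, 32, 4, 80])
```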
 
 
 
 
 
 
 
972
 
973
 
974
  class LlavaMetaForCausalLM(ABC):
 
995
  self, input_ids, position_ids, attention_mask, past_key_values, labels, images
996
  ):
997
  vision_tower = self.get_vision_tower()
998
+ if past_key_values is not None:
999
+ target_shape = past_key_values[0][0].shape[2] + 1
1000
+ attention_mask = torch.ones(
1001
+ (attention_mask.shape[0], target_shape),
1002
+ dtype=attention_mask.dtype,
1003
+ device=attention_mask.device
1004
+ )
1005
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
1006
+ # print(input_ids[:, -1:].item())
1007
+ return input_ids[:, -1:], position_ids, attention_mask, past_key_values, None, labels
1008
  if vision_tower is None or images is None or input_ids.shape[1] == 1:
1009
+ # if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
1010
+ # target_shape = past_key_values.seqlen_offset + 1
1011
+ # attention_mask = torch.cat((attention_mask, torch.ones(
1012
+ # (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
1013
+ # dtype=attention_mask.dtype,
1014
+ # device=attention_mask.device
1015
+ # )), dim=1)
1016
+ # position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
1017
+ return input_ids, None, None, past_key_values, None, None
1018
+ # return input_ids, position_ids, attention_mask, past_key_values, None, labels
1019
+
1020
+ # if vision_tower is None or images is None or past_key_values.seqlen_offset != 0:
1021
+ # if vision_tower is None or images is None or input_ids.shape[1] == 1:
1022
+ # if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
1023
+ # target_shape = past_key_values.seqlen_offset + 1
1024
+ # # inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...]
1025
+ # attention_mask = torch.cat((attention_mask, torch.ones(
1026
+ # (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
1027
+ # dtype=attention_mask.dtype,
1028
+ # device=attention_mask.device
1029
+ # )), dim=1)
1030
+ # position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
1031
+ # return input_ids, position_ids, attention_mask, past_key_values, None, labels
1032
 
1033
  if type(images) is list or images.ndim == 5:
1034
  concat_images = torch.cat([image for image in images], dim=0)
 
1160
  position_ids = None
1161
 
1162
  return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
1163
+ #return input_ids, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
1164
 
1165
 
1166
  class ImpForCausalLM(PhiPreTrainedModel, LlavaMetaForCausalLM):
 
1173
  def __init__(self, config: ImpConfig) -> None:
1174
  super().__init__(config)
1175
 
1176
+ self.model = ImpModel(config)
1177
+ self.vocab_size = config.vocab_size
1178
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
1179
 
1180
  self.post_init()
1181
  self.init_constants(config)
1182
 
1183
+ def get_input_embeddings(self):
1184
+ return self.model.embed_tokens
1185
+
1186
+ def set_input_embeddings(self, value):
1187
+ self.model.embed_tokens = value
1188
+
1189
  def get_output_embeddings(self) -> nn.Linear:
1190
+ return self.lm_head
1191
 
1192
  def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
1193
+ self.lm_head = new_embeddings
1194
 
1195
  def get_model(self):
1196
+ return self.model
1197
+
1198
+ def get_decoder(self):
1199
+ return self.model
1200
+
1201
+ def set_decoder(self, decoder):
1202
+ self.model = decoder
1203
 
1204
  def image_preprocess(self, images):
1205
  return self.get_vision_tower().image_processor(images)['pixel_values']
 
1206
 
1207
  def forward(
1208
  self,
 
1218
  images: Optional[torch.FloatTensor] = None,
1219
  return_dict: Optional[bool] = None,
1220
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1221
+
1222
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1223
+ output_hidden_states = (
1224
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1225
+ )
1226
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1227
 
1228
  if inputs_embeds is None:
1229
  (
 
1242
  images
1243
  )
1244
 
1245
+ outputs = self.model(
1246
  input_ids=input_ids,
1247
+ past_key_values=past_key_values,
1248
  attention_mask=attention_mask,
1249
+ position_ids=position_ids,
 
1250
  inputs_embeds=inputs_embeds,
 
1251
  use_cache=use_cache,
1252
  output_attentions=output_attentions,
1253
  output_hidden_states=output_hidden_states,
1254
  return_dict=return_dict
1255
+ )
1256
+ hidden_states = outputs[0]
1257
+ logits = self.lm_head(hidden_states)
1258
+ logits = logits.float()
1259
+
1260
+ loss = None
1261
+ if labels is not None:
1262
+ # Shift so that tokens < n predict n
1263
+ shift_logits = logits[..., :-1, :].contiguous()
1264
+ shift_labels = labels[..., 1:].contiguous()
1265
+ # Flatten the tokens
1266
+ loss_fct = CrossEntropyLoss()
1267
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1268
+ shift_labels = shift_labels.view(-1)
1269
+ # Enable model parallelism
1270
+ shift_labels = shift_labels.to(shift_logits.device)
1271
+ loss = loss_fct(shift_logits, shift_labels)
1272
+ if not return_dict:
1273
+ loss = None
1274
+ output = (logits,) + outputs[1:]
1275
+ return (loss,) + output if loss is not None else output
1276
+
1277
+ return CausalLMOutputWithPast(
1278
+ loss=loss,
1279
+ logits=logits,
1280
+ past_key_values=outputs.past_key_values,
1281
+ hidden_states=outputs.hidden_states,
1282
+ attentions=outputs.attentions,
1283
  )
1284
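The loss branch above is the standard causal-LM shift: the logits at position t are scored against the label at position t+1. A small worked sketch with random tensors (the sizes are made up):

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size, seq_len = 11, 6                      # toy sizes
logits = torch.randn(1, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (1, seq_len))

shift_logits = logits[..., :-1, :].contiguous()  # predictions for positions 0..T-2
shift_labels = labels[..., 1:].contiguous()      # targets are the next tokens 1..T-1
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())
```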
 
1285
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
special_tokens_map.json CHANGED
@@ -1,5 +1,23 @@
1
  {
2
- "bos_token": "<|endoftext|>",
3
- "eos_token": "</s>",
4
- "unk_token": "<|endoftext|>"
5
  }
 
1
  {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
  }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json CHANGED
@@ -2,22 +2,6 @@
2
  "add_bos_token": false,
3
  "add_prefix_space": false,
4
  "added_tokens_decoder": {
5
- "50296": {
6
- "content": "<image>",
7
- "lstrip": false,
8
- "normalized": false,
9
- "rstrip": false,
10
- "single_word": false,
11
- "special": true
12
- },
13
- "50295": {
14
- "content": "</s>",
15
- "lstrip": false,
16
- "normalized": false,
17
- "rstrip": false,
18
- "single_word": false,
19
- "special": true
20
- },
21
  "50256": {
22
  "content": "<|endoftext|>",
23
  "lstrip": false,
@@ -329,35 +313,30 @@
329
  "rstrip": false,
330
  "single_word": false,
331
  "special": false
332
  }
333
  },
334
- "bos_token": {
335
- "__type": "AddedToken",
336
- "content": "<|endoftext|>",
337
- "lstrip": false,
338
- "normalized": true,
339
- "rstrip": false,
340
- "single_word": false
341
- },
342
  "clean_up_tokenization_spaces": true,
343
- "eos_token": {
344
- "__type": "AddedToken",
345
- "content": "<|endoftext|>",
346
- "lstrip": false,
347
- "normalized": true,
348
- "rstrip": false,
349
- "single_word": false
350
- },
351
  "errors": "replace",
352
  "model_max_length": 3072,
353
  "pad_token": null,
354
  "tokenizer_class": "CodeGenTokenizer",
355
- "unk_token": {
356
- "__type": "AddedToken",
357
- "content": "<|endoftext|>",
358
- "lstrip": false,
359
- "normalized": true,
360
- "rstrip": false,
361
- "single_word": false
362
- }
363
  }
 
2
  "add_bos_token": false,
3
  "add_prefix_space": false,
4
  "added_tokens_decoder": {
5
  "50256": {
6
  "content": "<|endoftext|>",
7
  "lstrip": false,
 
313
  "rstrip": false,
314
  "single_word": false,
315
  "special": false
316
+ },
317
+ "50295": {
318
+ "content": "</s>",
319
+ "lstrip": false,
320
+ "normalized": false,
321
+ "rstrip": false,
322
+ "single_word": false,
323
+ "special": true
324
+ },
325
+ "50296": {
326
+ "content": "<image>",
327
+ "lstrip": false,
328
+ "normalized": false,
329
+ "rstrip": false,
330
+ "single_word": false,
331
+ "special": true
332
  }
333
  },
334
+ "bos_token": "<|endoftext|>",
 
 
 
 
 
 
 
335
  "clean_up_tokenization_spaces": true,
336
+ "eos_token": "<|endoftext|>",
 
 
 
 
 
 
 
337
  "errors": "replace",
338
  "model_max_length": 3072,
339
  "pad_token": null,
340
  "tokenizer_class": "CodeGenTokenizer",
341
+ "unk_token": "<|endoftext|>"
 
 
 
 
 
 
 
342
  }
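With this tokenizer_config change, `</s>` (id 50295) and `<image>` (id 50296) are registered through `added_tokens_decoder`, and bos/eos/unk all resolve to `<|endoftext|>`. A quick check, assuming the files above are loaded from the Hub; `<repo-id>` is a placeholder for this model's id:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("<repo-id>", trust_remote_code=True)
print(tok.bos_token, tok.eos_token, tok.unk_token)   # all "<|endoftext|>" per this config
print(tok.convert_tokens_to_ids("</s>"))             # expected: 50295, per added_tokens_decoder
print(tok.convert_tokens_to_ids("<image>"))          # expected: 50296, per added_tokens_decoder
```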
vocab.json CHANGED
The diff for this file is too large to render. See raw diff