format

modeling_mpt.py  (+274, −80)  CHANGED
[The pre-format ("−") lines of each hunk are truncated in this capture; the reformatted ("+") lines are reproduced in full below, hunk by hunk.]
@@ -9,63 +9,90 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from .attention import attn_bias_shape, build_attn_bias
from .blocks import MPTBlock
from .norm import NORM_CLASS_REGISTRY
from .configuration_mpt import MPTConfig
from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
from .hf_prefixlm_converter import (
    add_bidirectional_mask_if_missing,
    convert_hf_causal_lm_to_prefix_lm,
)
from .meta_init_context import init_empty_weights
from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


class MPTPreTrainedModel(PreTrainedModel):
    config_class = MPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []


class MPTModel(MPTPreTrainedModel):
    def __init__(self, config: MPTConfig):
        config._validate_config()
        super().__init__(config)
        self.attn_impl = config.attn_config["attn_impl"]
        self.prefix_lm = config.attn_config["prefix_lm"]
        self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
        self.alibi = config.attn_config["alibi"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
            norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
            raise NotImplementedError(
                f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
            )
        norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
        self.embedding_fraction = config.embedding_fraction
        self.wte = nn.Embedding(
            config.vocab_size, config.d_model, device=config.init_device
        )
        if not self.alibi:
            self.wpe = nn.Embedding(
                config.max_seq_len, config.d_model, device=config.init_device
            )
        self.emb_drop = nn.Dropout(config.emb_pdrop)
        self.blocks = nn.ModuleList(
            [
                MPTBlock(device=config.init_device, **config.to_dict())
                for _ in range(config.n_layers)
            ]
        )
        self.norm_f = norm_class(config.d_model, device=config.init_device)
        if config.init_device != "meta":
            self.apply(self.param_init_fn)
        self.is_causal = not self.prefix_lm
        self._attn_bias_initialized = False
        self.attn_bias = None
        self.attn_bias_shape = attn_bias_shape(
            self.attn_impl,
            config.n_heads,
            config.max_seq_len,
            self.alibi,
            prefix_lm=self.prefix_lm,
            causal=self.is_causal,
            use_sequence_id=self.attn_uses_sequence_id,
        )
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
                    if config.verbose:
                        warnings.warn(f"Removing bias ({module.bias}) from {module}.")
                    module.register_parameter("bias", None)
        if config.verbose and config.verbose > 2:
            print(self)
        if "verbose" not in self.config.init_config:
            self.config.init_config["verbose"] = self.config.verbose
        if self.config.init_config["verbose"] > 1:
            init_fn_name = self.config.init_config["name"]
            warnings.warn(f"Using {init_fn_name} initialization.")

    def get_input_embeddings(self):
        return self.wte
@@ -74,13 +101,30 @@ class MPTModel(MPTPreTrainedModel):
        self.wte = value

    @torch.no_grad()
    def _attn_bias(
        self,
        device,
        dtype,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
    ):
        if not self._attn_bias_initialized:
            if self.attn_bias_shape:
                self.attn_bias = torch.zeros(
                    self.attn_bias_shape, device=device, dtype=dtype
                )
                self.attn_bias = build_attn_bias(
                    self.attn_impl,
                    self.attn_bias,
                    self.config.n_heads,
                    self.config.max_seq_len,
                    causal=self.is_causal,
                    alibi=self.alibi,
                    alibi_bias_max=self.alibi_bias_max,
                )
            self._attn_bias_initialized = True
        if self.attn_impl == "flash":
            return (self.attn_bias, attention_mask)
        if self.attn_bias is not None:
            self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
@@ -99,58 +143,110 @@ class MPTModel(MPTPreTrainedModel):
            else:
                attn_bias = attn_bias[:, :, :, -s_k:]
            if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
                raise ValueError(
                    f"attention_mask shape={attention_mask.shape} "
                    + f"and prefix_mask shape={prefix_mask.shape} are not equal."
                )
            min_val = torch.finfo(attn_bias.dtype).min
            attn_bias = attn_bias.masked_fill(
                ~attention_mask.view(-1, 1, 1, s_k), min_val
            )
        return (attn_bias, None)

    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
        (s_k, s_q) = attn_bias.shape[-2:]
        if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
            raise ValueError(
                "attn_bias does not match the expected shape. "
                + f"The last two dimensions should both be {self.config.max_length} "
                + f"but are {s_k} and {s_q}."
            )
        seq_len = prefix_mask.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
            )
        attn_bias = attn_bias[..., :seq_len, :seq_len]
        causal = torch.tril(
            torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)
        ).view(1, 1, seq_len, seq_len)
        prefix = prefix_mask.view(-1, 1, 1, seq_len)
        cannot_attend = ~torch.logical_or(causal, prefix.bool())
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def _apply_sequence_id(
        self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor
    ):
        seq_len = sequence_id.shape[-1]
        if seq_len > self.config.max_seq_len:
            raise ValueError(
                f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
            )
        attn_bias = attn_bias[..., :seq_len, :seq_len]
        cannot_attend = torch.logical_not(
            torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
        ).unsqueeze(1)
        min_val = torch.finfo(attn_bias.dtype).min
        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
        return attn_bias

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if attention_mask is not None:
            attention_mask = attention_mask.bool()
        if prefix_mask is not None:
            prefix_mask = prefix_mask.bool()
        if not return_dict:
            raise NotImplementedError(
                "return_dict False is not implemented yet for MPT"
            )
        if output_attentions:
            raise NotImplementedError(
                "output_attentions is not implemented yet for MPT"
            )
        if (
            attention_mask is not None
            and attention_mask[:, 0].sum() != attention_mask.shape[0]
            and self.training
        ):
            raise NotImplementedError(
                "MPT does not support training with left padding."
            )
        if self.prefix_lm and prefix_mask is None:
            raise ValueError(
                "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
            )
        if self.training:
            if self.attn_uses_sequence_id and sequence_id is None:
                raise ValueError(
                    "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
                    + "and the model is in train mode."
                )
            elif self.attn_uses_sequence_id is False and sequence_id is not None:
                warnings.warn(
                    "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
                    + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
                )
        S = input_ids.size(1)
        assert (
            S <= self.config.max_seq_len
        ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
        tok_emb = self.wte(input_ids)
        if self.alibi:
            x = tok_emb
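For reference, a minimal standalone sketch (not part of this commit) of the prefix-LM masking that `_apply_prefix_mask` implements, using toy shapes: a position may attend causally or, additionally, to any token inside the bidirectional prefix.

# Illustration only; shapes and values are made up.
import torch

seq_len = 6
attn_bias = torch.zeros(1, 1, seq_len, seq_len)                  # one batch, one head
prefix_mask = torch.tensor([[1, 1, 1, 0, 0, 0]])                 # first 3 tokens form the prefix

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).view(1, 1, seq_len, seq_len)
prefix = prefix_mask.view(-1, 1, 1, seq_len).bool()
cannot_attend = ~torch.logical_or(causal, prefix)                # allowed if causal OR in-prefix
attn_bias = attn_bias.masked_fill(cannot_attend, torch.finfo(attn_bias.dtype).min)
print(attn_bias[0, 0])                                           # rows = queries, cols = keys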
@@ -158,39 +254,80 @@ class MPTModel(MPTPreTrainedModel):
            past_position = 0
            if past_key_values is not None:
                if len(past_key_values) != self.config.n_layers:
                    raise ValueError(
                        f"past_key_values must provide a past_key_value for each attention "
                        + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
                    )
                past_position = past_key_values[0][0].size(1)
            if S + past_position > self.config.max_seq_len:
                raise ValueError(
                    f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
                )
            pos = torch.arange(
                past_position,
                S + past_position,
                dtype=torch.long,
                device=input_ids.device,
            ).unsqueeze(0)
            if attention_mask is not None:
                pos = torch.clamp(
                    pos
                    - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
                        :, past_position:
                    ],
                    min=0,
                )
            pos_emb = self.wpe(pos)
            x = tok_emb + pos_emb
        if self.embedding_fraction == 1:
            x = self.emb_drop(x)
        else:
            x_shrunk = x * self.embedding_fraction + x.detach() * (
                1 - self.embedding_fraction
            )
            assert isinstance(self.emb_drop, nn.Module)
            x = self.emb_drop(x_shrunk)
        (attn_bias, attention_mask) = self._attn_bias(
            device=x.device,
            dtype=x.dtype,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
        )
        if use_cache and past_key_values is None:
            past_key_values = [() for _ in range(self.config.n_layers)]
        all_hidden_states = () if output_hidden_states else None
        for b_idx, block in enumerate(self.blocks):
            if output_hidden_states:
                assert all_hidden_states is not None
                all_hidden_states = all_hidden_states + (x,)
            past_key_value = (
                past_key_values[b_idx] if past_key_values is not None else None
            )
            (x, past_key_value) = block(
                x,
                past_key_value=past_key_value,
                attn_bias=attn_bias,
                attention_mask=attention_mask,
                is_causal=self.is_causal,
            )
            if past_key_values is not None:
                past_key_values[b_idx] = past_key_value
        x = self.norm_f(x)
        return BaseModelOutputWithPast(
            last_hidden_state=x,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
        )

    def param_init_fn(self, module):
        init_fn_name = self.config.init_config["name"]
        MODEL_INIT_REGISTRY[init_fn_name](
            module=module,
            n_layers=self.config.n_layers,
            d_model=self.config.d_model,
            **self.config.init_config,
        )

    def fsdp_wrap_fn(self, module):
        return isinstance(module, MPTBlock)
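As an aside (not part of this commit), a small sketch of the position-id computation the hunk above reformats: with learned positional embeddings, left-pad tokens are subtracted out via a cumulative sum so real tokens still start at position 0.

# Illustration only; one sequence with two left-pad tokens.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.bool)
S = attention_mask.size(1)
pos = torch.arange(0, S, dtype=torch.long).unsqueeze(0)
pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), min=0)
print(pos)  # tensor([[0, 0, 0, 1, 2]])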
@@ -198,21 +335,23 @@ class MPTModel(MPTPreTrainedModel):
    def activation_checkpointing_fn(self, module):
        return isinstance(module, MPTBlock)


class MPTForCausalLM(MPTPreTrainedModel):
    def __init__(self, config: MPTConfig):
        super().__init__(config)
        if not config.tie_word_embeddings:
            raise ValueError("MPTForCausalLM only supports tied word embeddings")
        self.transformer = MPTModel(config)
        self.logit_scale = None
        if config.logit_scale is not None:
            logit_scale = config.logit_scale
            if isinstance(logit_scale, str):
                if logit_scale == "inv_sqrt_d_model":
                    logit_scale = 1 / math.sqrt(config.d_model)
                else:
                    raise ValueError(
                        f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
                    )
            self.logit_scale = logit_scale

    def get_input_embeddings(self):
@@ -233,25 +372,63 @@ class MPTForCausalLM(MPTPreTrainedModel):
    def get_decoder(self):
        return self.transformer

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        prefix_mask: Optional[torch.ByteTensor] = None,
        sequence_id: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            prefix_mask=prefix_mask,
            sequence_id=sequence_id,
            return_dict=return_dict,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
        )
        logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
        if self.logit_scale is not None:
            if self.logit_scale == 0:
                warnings.warn(
                    f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
                )
            logits *= self.logit_scale
        loss = None
        if labels is not None:
            labels = torch.roll(labels, shifts=-1)
            labels[:, -1] = -100
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
        )

    def param_init_fn(self, module):
        init_fn_name = self.config.init_config["name"]
        MODEL_INIT_REGISTRY[init_fn_name](
            module=module,
            n_layers=self.config.n_layers,
            d_model=self.config.d_model,
            **self.config.init_config,
        )

    def fsdp_wrap_fn(self, module):
        return isinstance(module, MPTBlock)
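As an aside (not part of this commit), the label handling in `MPTForCausalLM.forward` shifts the targets left by one instead of slicing the logits; a toy example of that roll:

# Illustration only.
import torch

labels = torch.tensor([[10, 11, 12, 13]])
labels = torch.roll(labels, shifts=-1)
labels[:, -1] = -100                      # -100 is ignored by F.cross_entropy
print(labels)                             # tensor([[  11,   12,   13, -100]])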
@@ -259,12 +436,16 @@ class MPTForCausalLM(MPTPreTrainedModel):
    def activation_checkpointing_fn(self, module):
        return isinstance(module, MPTBlock)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        if inputs_embeds is not None:
            raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
        attention_mask = kwargs["attention_mask"].bool()
        if attention_mask[:, -1].sum() != attention_mask.shape[0]:
            raise NotImplementedError(
                "MPT does not support generation with right padding."
            )
        if self.transformer.attn_uses_sequence_id and self.training:
            sequence_id = torch.zeros_like(input_ids[:1])
        else:
@@ -273,11 +454,20 @@ class MPTForCausalLM(MPTPreTrainedModel):
            input_ids = input_ids[:, -1].unsqueeze(-1)
        if self.transformer.prefix_lm:
            prefix_mask = torch.ones_like(attention_mask)
            if kwargs.get("use_cache") == False:
                raise NotImplementedError(
                    "MPT with prefix_lm=True does not support use_cache=False."
                )
        else:
            prefix_mask = None
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "prefix_mask": prefix_mask,
            "sequence_id": sequence_id,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
@@ -288,5 +478,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
        """
        reordered_past = []
        for layer_past in past_key_values:
            reordered_past += [
                tuple(
                    (past_state.index_select(0, beam_idx) for past_state in layer_past)
                )
            ]
        return reordered_past
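For context, a minimal usage sketch, assuming this modeling file ships as the custom code of an MPT checkpoint on the Hugging Face Hub (the repo id and tokenizer below are illustrative, e.g. mosaicml/mpt-7b, which reuses the EleutherAI/gpt-neox-20b tokenizer):

# Illustration only; loading a remote-code MPT checkpoint and generating.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b", torch_dtype=torch.bfloat16, trust_remote_code=True
)

inputs = tokenizer("MosaicML is", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))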