Update hf_prefixlm_converter.py
hf_prefixlm_converter.py (CHANGED, +67 -26)
@@ -1,8 +1,6 @@
 """Converts Huggingface Causal LM to Prefix LM.
-
 Conversion does lightweight surgery on a HuggingFace
 Causal LM to convert it to a Prefix LM.
-
 Prefix LMs accepts a `bidirectional_mask` input in `forward`
 and treat the input prompt as the prefix in `generate`.
 """
@@ -12,29 +10,90 @@ from types import MethodType
 from typing import Any, List, MutableMapping, Optional, Tuple, Union
 import torch
 from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
-from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
-from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
+
+#depreciated
+#from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
+def _expand_mask_bloom(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
+    """
+    Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
+    """
+    batch_size, src_length = mask.shape
+    tgt_length = tgt_length if tgt_length is not None else src_length
+
+    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
+    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
+
+#from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
+
+def _make_causal_mask_bloom(
+    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
+) -> torch.BoolTensor:
+    """
+    Make causal mask used for self-attention.
+    """
+    batch_size, target_length = input_ids_shape
+    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+    # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
+    seq_ids = torch.arange(target_length, device=device)
+    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+
+    if past_key_values_length > 0:
+        mask[:, :past_key_values_length] = False
+
+    expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
+    return expanded_mask
+
 from transformers.models.bloom.modeling_bloom import logging
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
 from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
 from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
 from transformers.models.opt.modeling_opt import OPTForCausalLM
-from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
-from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
+#from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
+#from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
+
+def _make_causal_mask_opt(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+def _expand_mask_opt(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
 logger = logging.get_logger(__name__)
 _SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
 CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
 
 def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
     """Converts a GPT-style Causal LM to a Prefix LM.
-
     Supported HuggingFace model classes:
         - `GPT2LMHeadModel`
         - `GPTNeoForCausalLM`
         - `GPTNeoXForCausalLM`
         - `GPTJForCausalLM`
-
     See `convert_hf_causal_lm_to_prefix_lm` for more details.
     """
     if hasattr(model, '_prefix_lm_converted'):
@@ -44,7 +103,6 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
 
     def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
         """Helper that gets a list of the model's attention modules.
-
         Each module has a `bias` buffer used for causal masking. The Prefix LM
         conversion adds logic to dynamically manipulate these biases to support
         Prefix LM attention masking.
@@ -113,10 +171,8 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
 
 def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
     """Converts a BLOOM Causal LM to a Prefix LM.
-
     Supported HuggingFace model classes:
         - `BloomForCausalLM`
-
     See `convert_hf_causal_lm_to_prefix_lm` for more details.
     """
     if hasattr(model, '_prefix_lm_converted'):
@@ -270,10 +326,8 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
 
 def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
     """Converts an OPT Causal LM to a Prefix LM.
-
     Supported HuggingFace model classes:
         - `OPTForCausalLM`
-
     See `convert_hf_causal_lm_to_prefix_lm` for more details.
     """
     if hasattr(model, '_prefix_lm_converted'):
@@ -339,7 +393,6 @@ CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPT
 
 def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
     """Converts a HuggingFace Causal LM to a Prefix LM.
-
     Supported HuggingFace model classes:
         - `GPT2LMHeadModel`
         - `GPTNeoForCausalLM`
@@ -347,49 +400,38 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
         - `GPTJForCausalLM`
         - `BloomForCausalLM`
         - `OPTForCausalLM`
-
     Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
     `generate` method and/or select underlying methods depending on the model class.
-
     These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".
-
     Notes on training:
         To actually train the converted model as a Prefix LM, training batches will need to indicate
         the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.
-
         **This is not a standard input and requires custom layers either within or after your dataloader.**
-
         In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
         such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
         That is, the prefix portion of the sequence should not generate any loss. Loss should only be
         generated by the target portion of the sequence.
-
     Notes on `GPTNeoForCausalLM`:
         To simplify the implementation, "global" and "local" attention layers are handled differently.
        For "global" layers, we handle conversion as described above. For "local" layers, which use a
        causal attention mask within a restricted local window, we do not alter the masking.
-
     Notes on `forward` method conversion:
        After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
        which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
        belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
        0 indicates token positions belonging to the target.
-
        The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
        causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
        the causal masks before returning the result.
-
     Notes on `generate` method conversion:
        After conversion, the `generate` method will have the same signature but will internally
        convert all causal masks to be purely bidirectional, call the original `generate` method, and
        (where appropriate) reset the causal masks before returning the result.
-
        This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
        "prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
        each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
        another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
        previously-generated tokens (also as expected in a Prefix LM).
-
     To preserve the API, the original methods are renamed to `_original_forward` and
     `_original_generate`, and replaced with new `forward` and `generate` methods that wrap
     them, respectively. Although implementation details vary by model class.
@@ -405,7 +447,6 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
 
 def add_bidirectional_mask_if_missing(batch: MutableMapping):
     """Attempts to add bidirectional_mask to batch if missing.
-
     Raises:
         KeyError if bidirectional_mask is missing and can't be inferred
     """
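The docstring of convert_hf_causal_lm_to_prefix_lm spells out the training convention: a bidirectional_mask of 1 over prefix positions and 0 over target positions, with labels set to -100 wherever the mask is 1. A hedged end-to-end sketch of that convention follows; the checkpoint name, prompt, and import path are illustrative assumptions, and tokenization boundary effects are ignored:

import torch
from transformers import AutoTokenizer, OPTForCausalLM
from hf_prefixlm_converter import convert_hf_causal_lm_to_prefix_lm  # assumed import path

tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')        # illustrative checkpoint
model = convert_hf_causal_lm_to_prefix_lm(OPTForCausalLM.from_pretrained('facebook/opt-125m'))

prefix, target = 'Translate to French: cheese =>', ' fromage'
batch = tokenizer(prefix + target, return_tensors='pt')
n_prefix = len(tokenizer(prefix)['input_ids'])                        # number of prefix tokens (approximate)

# bidirectional_mask: 1 over prefix positions, 0 over target positions
batch['bidirectional_mask'] = torch.zeros_like(batch['input_ids'])
batch['bidirectional_mask'][:, :n_prefix] = 1

# labels: copy of input_ids with the prefix ignored, per the docstring convention
batch['labels'] = batch['input_ids'].clone()
batch['labels'][batch['bidirectional_mask'] == 1] = -100

outputs = model(**batch)    # loss is computed only over the target tokens
outputs.loss.backward()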
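For inference, the docstring notes that the converted generate treats the prompt as the prefix, so it is called exactly like the stock HuggingFace API. A short sketch using the converted model and tokenizer from the training sketch (still illustrative, not part of this commit):

prompt = tokenizer('Translate to French: cheese =>', return_tensors='pt')
with torch.no_grad():
    generated = model.generate(**prompt, max_new_tokens=8)  # prompt tokens attend to one another bidirectionally
print(tokenizer.decode(generated[0], skip_special_tokens=True))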