File size: 7,560 Bytes
ab2b3bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import functools
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
from transformers import PreTrainedModel
from transformers.models.opt.modeling_opt import OPTDecoder
if TYPE_CHECKING:
from peft import PeftModel
def rhasattr(obj: Any, attr: str) -> bool:
    """Report whether a dotted attribute path resolves on `obj`.

    This is `hasattr` extended to chains: `rhasattr(obj, "foo.bar.baz")`
    is True only when `obj.foo`, `obj.foo.bar` and `obj.foo.bar.baz` all
    exist.

    Reference: https://stackoverflow.com/a/67303315
    """
    # Everything before the final dot must be walked; only the last
    # component is tested without being fetched.
    *parents, leaf = attr.split('.')
    target = obj
    for name in parents:
        if not hasattr(target, name):
            return False
        target = getattr(target, name)
    return hasattr(target, leaf)
def rgetattr(obj: Any, attr: str, *args: List[Any]) -> Any:
    """Fetch a dotted attribute path from `obj`.

    This is `getattr` extended to chains: `rgetattr(obj, "foo.bar.baz")`
    returns `obj.foo.bar.baz`. Any extra positional arguments are passed
    unchanged to every underlying `getattr` call, so a single extra
    argument acts as the default at each step of the chain.

    Reference: https://stackoverflow.com/a/31174427
    """
    current = obj
    for name in attr.split('.'):
        current = getattr(current, name, *args)
    return current
def findattr(obj: Any, attrs: Iterable[str]) -> Optional[Any]:
    """Return the first dotted attribute path in `attrs` that resolves on `obj`.

    Paths are tried in iteration order. Each one is tested with `rhasattr`
    and, on the first hit, fetched with `rgetattr`. Returns None when no
    path resolves.
    """
    # Lazy generator: nothing past the first matching path is evaluated.
    matches = (rgetattr(obj, path) for path in attrs if rhasattr(obj, path))
    return next(matches, None)
def hf_get_causal_base_model(model: PreTrainedModel) -> Any:
    """Locate the causal decoder backbone of a HuggingFace model.

    If the model exposes `get_decoder()`, its result is returned directly.
    Otherwise the backbone is looked up under a list of known attribute
    paths, tried in order.

    NOTE: Different model configurations have different causal decoder attribute
    names.
        - transformer: (GPT2LMHeadModel, GPTJConfig)
        - model.decoder: (OPTConfig, BloomConfig)
        - gpt_neox: (GPTNeoXConfig)

    Raises:
        ValueError: if there is no `get_decoder()` and none of the known
            attribute paths resolves on the model.
    """
    if not hasattr(model, 'get_decoder'):
        backbone_paths = ('transformer', 'model.decoder', 'gpt_neox', 'model.transformer')
        backbone = findattr(model, backbone_paths)
        if backbone is None:
            raise ValueError(f'Unable to FSDP-wrap model {model}. Please open a github issue to add support.')
        return backbone
    return model.get_decoder()
def hf_get_hidden_layers(model: PreTrainedModel) -> Any:
    """Locate the hidden layers of the given model.

    Expects to receive the causal decoder backbone, not the XXForCausalLM
    wrapper. The layers are looked up under a list of known attribute
    paths, tried in order.

    NOTE: Different model configurations have different hidden layer attribute names.
        - h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM)
        - decoder.layers: (OPTForCausalLM)
        - layers: (GPTNeoXForCausalLM, LlaMaForCausalLM)
        - blocks: (MPTForCausalLM)

    Raises:
        ValueError: if none of the known attribute paths resolves.
    """
    candidate_attrs = ('h', 'decoder.layers', 'layers', 'block', 'blocks')
    found = findattr(model, candidate_attrs)
    if found is not None:
        return found
    raise ValueError(f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {candidate_attrs}')
def hf_get_init_device(init_device: Optional[str]) -> Optional[str]:
    """Resolve the device on which a model should be initialized.

    Any value other than 'mixed' is returned unchanged (including None).
    For 'mixed', the process whose local rank is 0 gets 'cpu' and every
    other process gets 'meta'.
    """
    if init_device != 'mixed':
        return init_device
    # NOTE(review): `dist` is not imported in the visible part of this
    # file -- confirm it is in scope before relying on the 'mixed' path.
    return 'cpu' if dist.get_local_rank() == 0 else 'meta'
def prepare_hf_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None:
    """Prepare a HuggingFace model for FSDP wrapping.

    Dispatches on `model.config.is_encoder_decoder`: encoder/decoder models
    go to `prepare_hf_enc_dec_model_for_fsdp`, everything else goes to
    `prepare_hf_causal_lm_model_for_fsdp`. Both receive `model` and
    `init_device` unchanged.
    """
    prepare = (
        prepare_hf_enc_dec_model_for_fsdp
        if model.config.is_encoder_decoder
        else prepare_hf_causal_lm_model_for_fsdp
    )
    prepare(model, init_device)
def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, 'PeftModel'], init_device: Optional[str]) -> None:
    """Annotate a HuggingFace decoder-only model for FSDP wrapping.

    Wrap any model for FSDP which follows one of the 3 existing conventions from
    HuggingFace for decoder-only LLMs.

    Nothing is wrapped here. The function only sets `_fsdp_wrap` flags on
    selected submodules and installs `fsdp_wrap_fn` and
    `activation_checkpointing_fn` on `model`; both predicates match modules
    whose type is that of the first hidden layer.

    Args:
        model: The causal LM (optionally a PEFT model) to annotate in place.
        init_device: Accepted for signature parity with the other
            `prepare_*` helpers; not read by this function.

    Raises:
        ValueError: if the decoder backbone, hidden layers, output embeddings
            or input embeddings cannot be located.
    """
    causal_base_model = hf_get_causal_base_model(model)

    if isinstance(causal_base_model, OPTDecoder) or model.config.model_type == 'olmo':
        # For these architectures the `.model` attribute of the underlying
        # model is also excluded from wrapping.
        # NOTE(review): `maybe_get_underlying_model` is not imported in the
        # visible part of this file -- confirm it is in scope.
        underlying_model = maybe_get_underlying_model(model)
        underlying_model.model._fsdp_wrap = False

    model_block = hf_get_hidden_layers(causal_base_model)
    lm_head = model.get_output_embeddings()
    # Prefer the backbone's input embeddings and fall back to the wrapper's
    # when the backbone cannot provide them. `Exception` (rather than a bare
    # `except`) keeps the best-effort fallback while letting
    # KeyboardInterrupt / SystemExit propagate.
    try:
        tied_embeddings = causal_base_model.get_input_embeddings()
    except Exception:
        tied_embeddings = model.get_input_embeddings()

    modules = {
        'base_model': causal_base_model,
        'model_block': model_block,
        'lm_head': lm_head,
        'tied_embeddings': tied_embeddings,
    }
    for mod_name, module in modules.items():
        if module is None:
            raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + 'follow common layer/weight naming conventions.')

    block_type = type(model_block[0])

    if model.config.tie_word_embeddings:
        # With tied word embeddings, the backbone, the input embeddings and
        # the LM head are all flagged as not individually wrapped.
        causal_base_model._fsdp_wrap = False
        tied_embeddings._fsdp_wrap = False
        lm_head._fsdp_wrap = False

    if hasattr(model, 'peft_type') and model.peft_type is not None:
        peft_type = model.peft_type.lower()
        active_adapters = [adapter.lower() for adapter in model.active_adapters]
        for name, module in model.named_modules():
            lowered_name = name.lower()
            if peft_type in lowered_name and any(adapter in lowered_name for adapter in active_adapters):
                # Only flag adapter modules that actually own state.
                has_parameters = next(module.parameters(), None) is not None
                has_buffers = next(module.buffers(), None) is not None
                if has_parameters or has_buffers:
                    module._fsdp_wrap = True

    # Both predicates select every module of the hidden-layer block type.
    model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
    model.activation_checkpointing_fn = lambda module: isinstance(module, block_type)
def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None:
    """Annotate an encoder/decoder HF model for FSDP wrapping.

    This works for T5, BART, Pegasus, PegasusX, but not all enc/dec (ProphetNet)
    You have model.shared, model.encoder, model.decoder and model.lm_head, where
    model.shared are the embeddings which are tied to model.lm_head, and
    model.shared == model.encoder.embed_tokens and model.shared ==
    model.decoder.embed_tokens

    Nothing is wrapped here. The function only sets `_fsdp_wrap` flags and
    installs `fsdp_wrap_fn` and `activation_checkpointing_fn` on `model`.
    Both predicates match modules whose type is that of the first decoder
    layer or the first encoder layer.

    Args:
        model: The encoder/decoder model to annotate in place.
        init_device: Accepted for signature parity with the other
            `prepare_*` helpers; not read by this function.

    Raises:
        ValueError: if the encoder, decoder, their hidden layers, the output
            embeddings or the input embeddings cannot be located.
    """
    tied_embeddings = model.get_input_embeddings()
    encoder = model.get_encoder()
    decoder = model.get_decoder()
    lm_head = model.get_output_embeddings()
    encoder_block = hf_get_hidden_layers(encoder)
    decoder_block = hf_get_hidden_layers(decoder)

    modules = {
        'encoder': encoder,
        'decoder': decoder,
        'encoder_block': encoder_block,
        'decoder_block': decoder_block,
        'lm_head': lm_head,
        'tied_embeddings': tied_embeddings,
    }
    for mod_name, module in modules.items():
        if module is None:
            raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + 'follow common layer/weight naming conventions.')

    decoder_block_type = type(decoder_block[0])
    encoder_block_type = type(encoder_block[0])

    if model.config.tie_word_embeddings:
        # With tied word embeddings, the shared embeddings, encoder, decoder
        # and LM head are all flagged as not individually wrapped.
        tied_embeddings._fsdp_wrap = False
        encoder._fsdp_wrap = False
        decoder._fsdp_wrap = False
        lm_head._fsdp_wrap = False

    # Match BOTH block types with one predicate. Assigning a decoder-only
    # predicate and then overwriting it with an encoder-only one (when the
    # two types differ) would leave decoder blocks unmatched. When the types
    # are identical the tuple simply contains the same class twice.
    block_types = (decoder_block_type, encoder_block_type)
    model.fsdp_wrap_fn = lambda module: isinstance(module, block_types)
    model.activation_checkpointing_fn = lambda module: isinstance(module, block_types)