|
import functools |
|
from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
|
from composer.models.huggingface import maybe_get_underlying_model
from composer.utils import dist
from transformers import PreTrainedModel
|
from transformers.models.opt.modeling_opt import OPTDecoder |
|
if TYPE_CHECKING: |
|
from peft import PeftModel |
|
|
|
def rhasattr(obj: Any, attr: str) -> bool: |
|
"""A chain-able attribute version of hasattr. |
|
|
|
For example, to check if |
|
`obj` has the attribute `foo.bar.baz`, you can use: |
|
`rhasattr(obj, "foo.bar.baz")` |
|
Reference: https://stackoverflow.com/a/67303315 |
|
""" |
|
_nested_attrs = attr.split('.') |
|
_curr_obj = obj |
|
for _a in _nested_attrs[:-1]: |
|
if hasattr(_curr_obj, _a): |
|
_curr_obj = getattr(_curr_obj, _a) |
|
else: |
|
return False |
|
return hasattr(_curr_obj, _nested_attrs[-1]) |
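
# Illustrative usage of `rhasattr` (a sketch; `cfg` is a made-up example object):
#
#     import types
#     cfg = types.SimpleNamespace(foo=types.SimpleNamespace(bar=1))
#     rhasattr(cfg, 'foo.bar')   # True
#     rhasattr(cfg, 'foo.baz')   # False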
|
|
|
def rgetattr(obj: Any, attr: str, *args: Any) -> Any:
|
"""A chain-able attribute version of getattr. |
|
|
|
For example, to get the attribute `foo.bar.baz` from `obj`, you can use: |
|
`rgetattr(obj, "foo.bar.baz")` |
|
Reference: https://stackoverflow.com/a/31174427 |
|
""" |
|
|
|
def _getattr(obj: Any, attr: str): |
|
return getattr(obj, attr, *args) |
|
return functools.reduce(_getattr, [obj] + attr.split('.')) |
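
# Illustrative usage of `rgetattr` (a sketch, continuing the `cfg` example above):
#
#     rgetattr(cfg, 'foo.bar')        # 1
#     rgetattr(cfg, 'foo.baz', None)  # None (the optional default is applied at each level)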
|
|
|
def findattr(obj: Any, attrs: Iterable[str]) -> Optional[Any]: |
|
for attr in attrs: |
|
if rhasattr(obj, attr): |
|
return rgetattr(obj, attr) |
|
return None |
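
# Illustrative usage of `findattr` (a sketch, continuing the `cfg` example above):
# it returns the first attribute that resolves, or None if none of them do.
#
#     findattr(cfg, ('missing', 'foo.bar'))   # 1
#     findattr(cfg, ('missing', 'absent'))    # None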
|
|
|
def hf_get_causal_base_model(model: PreTrainedModel) -> Any: |
|
"""Returns the causal decoder backbone of the specified HuggingFace model. |
|
|
|
Newer HF models have a `self.get_decoder()` method. Older models do not. |
|
|
|
NOTE: Different model configurations have different causal decoder attribute |
|
names. |
|
        - transformer: (GPT2LMHeadModel, GPTJForCausalLM, BloomForCausalLM)
        - model.decoder: (OPTForCausalLM)
        - gpt_neox: (GPTNeoXForCausalLM)
        - model.transformer: (models that nest the backbone under `model`)
|
""" |
|
if hasattr(model, 'get_decoder'): |
|
return model.get_decoder() |
|
decoder_attrs = ('transformer', 'model.decoder', 'gpt_neox', 'model.transformer') |
|
causal_base_model = findattr(model, decoder_attrs) |
|
if causal_base_model is None: |
|
        raise ValueError(f'Unable to FSDP-wrap model {model}. Please open a GitHub issue to add support.')
|
return causal_base_model |
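
# Illustrative usage (a sketch; 'gpt2' is just a convenient public checkpoint):
#
#     from transformers import AutoModelForCausalLM
#     gpt2 = AutoModelForCausalLM.from_pretrained('gpt2')
#     backbone = hf_get_causal_base_model(gpt2)
#     # For GPT-2 this resolves to the `transformer` submodule (a GPT2Model),
#     # via `get_decoder()` on newer transformers versions or `findattr` otherwise.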
|
|
|
def hf_get_hidden_layers(model: PreTrainedModel) -> Any: |
|
"""Returns the hidden layers of the specified model. |
|
|
|
    Expects to receive the causal decoder (or encoder/decoder) backbone, not the XXForCausalLM wrapper.
|
|
|
NOTE: Different model configurations have different hidden layer attribute names. |
|
- h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM) |
|
- decoder.layers: (OPTForCausalLM) |
|
        - layers: (GPTNeoXForCausalLM, LlamaForCausalLM)
        - block: (T5Stack, the T5 encoder/decoder backbone)
        - blocks: (MPTForCausalLM)
|
""" |
|
hidden_layers_attrs = ('h', 'decoder.layers', 'layers', 'block', 'blocks') |
|
layers = findattr(model, hidden_layers_attrs) |
|
if layers is None: |
|
        raise ValueError(f'Unable to find the hidden layers of {model}. The model must have one of the following attributes: {hidden_layers_attrs}')
|
return layers |
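
# Illustrative usage (a sketch, continuing the GPT-2 example above):
#
#     blocks = hf_get_hidden_layers(backbone)  # GPT-2: `backbone.h`, an nn.ModuleList of GPT2Block
#     num_blocks = len(blocks)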
|
|
|
def hf_get_init_device(init_device: Optional[str]) -> Optional[str]: |
|
"""Returns the appropriate device to initialize models.""" |
|
if init_device == 'mixed': |
|
if dist.get_local_rank() == 0: |
|
return 'cpu' |
|
return 'meta' |
|
return init_device |
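
# Illustrative behavior (a sketch; 'mixed' assumes the distributed environment is
# initialized so that `dist.get_local_rank()` is well defined):
#
#     hf_get_init_device('cpu')    # 'cpu'  (passed through unchanged)
#     hf_get_init_device(None)     # None   (passed through unchanged)
#     hf_get_init_device('mixed')  # 'cpu' on local rank 0, 'meta' on every other rank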
|
|
|
def prepare_hf_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None: |
|
"""FSDP wrap a HuggingFace model. |
|
|
|
Call specific functions |
|
""" |
|
if model.config.is_encoder_decoder: |
|
prepare_hf_enc_dec_model_for_fsdp(model, init_device) |
|
else: |
|
prepare_hf_causal_lm_model_for_fsdp(model, init_device) |
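
# Illustrative usage (a sketch; the checkpoint name is arbitrary, and the
# `_fsdp_wrap` / `fsdp_wrap_fn` hints set below are consumed by Composer's FSDP
# wrapping logic):
#
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained('gpt2')
#     prepare_hf_model_for_fsdp(model, init_device='cpu')
#     # `model` can now be passed to a Composer Trainer configured with FSDP.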
|
|
|
def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, 'PeftModel'], init_device: Optional[str]) -> None: |
|
"""FSDP wrap a HuggingFace decoder. |
|
|
|
    Wraps any model that follows one of the three existing HuggingFace
    conventions for decoder-only LLMs.
|
""" |
|
causal_base_model = hf_get_causal_base_model(model) |
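    # OPT and OLMo add an extra level of module nesting, so mark the intermediate
    # wrapper module as not individually FSDP-wrapped.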
|
if isinstance(causal_base_model, OPTDecoder) or model.config.model_type == 'olmo': |
|
underlying_model = maybe_get_underlying_model(model) |
|
underlying_model.model._fsdp_wrap = False |
|
model_block = hf_get_hidden_layers(causal_base_model) |
|
lm_head = model.get_output_embeddings() |
|
    # Prefer the backbone's input embeddings; fall back to the wrapper if the
    # backbone does not expose them.
    try:
        tied_embeddings = causal_base_model.get_input_embeddings()
    except Exception:
        tied_embeddings = model.get_input_embeddings()
|
    modules = {
        'base_model': causal_base_model,
        'model_block': model_block,
        'lm_head': lm_head,
        'tied_embeddings': tied_embeddings,
    }
|
for mod_name, module in modules.items(): |
|
if module is None: |
|
            raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not follow common layer/weight naming conventions.')
|
block_type = type(model_block[0]) |
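    # When input and output embeddings are tied, the tied modules must end up in
    # the same FSDP unit, so exclude the base model, embeddings, and lm_head from
    # individual wrapping and let them fall into the root unit.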
|
if model.config.tie_word_embeddings: |
|
causal_base_model._fsdp_wrap = False |
|
tied_embeddings._fsdp_wrap = False |
|
lm_head._fsdp_wrap = False |
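    # PEFT adapter modules (e.g. LoRA layers) hold the trainable parameters, so
    # wrap each active adapter module in its own FSDP unit.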
|
if hasattr(model, 'peft_type') and model.peft_type is not None: |
|
peft_type = model.peft_type.lower() |
|
active_adapters = [adapter.lower() for adapter in model.active_adapters] |
|
for name, module in model.named_modules(): |
|
if peft_type in name.lower() and any((adapter in name.lower() for adapter in active_adapters)): |
|
has_parameters = next(module.parameters(), None) is not None |
|
has_buffers = next(module.buffers(), None) is not None |
|
if has_parameters or has_buffers: |
|
module._fsdp_wrap = True |
|
model.fsdp_wrap_fn = lambda module: isinstance(module, block_type) |
|
model.activation_checkpointing_fn = lambda module: isinstance(module, block_type) |
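
# Illustrative effect (a sketch; `model` is any decoder-only HF model prepared
# by the function above):
#
#     prepare_hf_causal_lm_model_for_fsdp(model, init_device='cpu')
#     block = hf_get_hidden_layers(hf_get_causal_base_model(model))[0]
#     assert model.fsdp_wrap_fn(block)                 # each block becomes an FSDP unit
#     assert model.activation_checkpointing_fn(block)  # and an activation-checkpointing unit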
|
|
|
def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None: |
|
"""Wrap an encoder/decoder HF model. |
|
|
|
    This works for T5, BART, Pegasus, and PegasusX, but not for every
    encoder/decoder model (e.g. ProphetNet). These models expose `model.shared`,
    `model.encoder`, `model.decoder`, and `model.lm_head`, where `model.shared`
    holds the embeddings tied to `model.lm_head`, and
    `model.shared == model.encoder.embed_tokens == model.decoder.embed_tokens`.
|
""" |
|
tied_embeddings = model.get_input_embeddings() |
|
encoder = model.get_encoder() |
|
decoder = model.get_decoder() |
|
lm_head = model.get_output_embeddings() |
|
encoder_block = hf_get_hidden_layers(encoder) |
|
decoder_block = hf_get_hidden_layers(decoder) |
|
    modules = {
        'encoder': encoder,
        'decoder': decoder,
        'encoder_block': encoder_block,
        'decoder_block': decoder_block,
        'lm_head': lm_head,
        'tied_embeddings': tied_embeddings,
    }
|
for mod_name, module in modules.items(): |
|
if module is None: |
|
            raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not follow common layer/weight naming conventions.')
|
decoder_block_type = type(decoder_block[0]) |
|
encoder_block_type = type(encoder_block[0]) |
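    # As in the causal LM case, tied embeddings must share an FSDP unit, so keep
    # the encoder, decoder, shared embeddings, and lm_head in the root unit.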
|
if model.config.tie_word_embeddings: |
|
tied_embeddings._fsdp_wrap = False |
|
encoder._fsdp_wrap = False |
|
decoder._fsdp_wrap = False |
|
lm_head._fsdp_wrap = False |
|
model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type) |
|
model.activation_checkpointing_fn = lambda module: isinstance(module, decoder_block_type) |
|
if encoder_block_type == decoder_block_type: |
|
return |
|
    # The encoder and decoder use different block classes here (e.g. ProphetNet,
    # Marian), so wrap and checkpoint both block types.
    model.fsdp_wrap_fn = lambda module: isinstance(module, (encoder_block_type, decoder_block_type))
    model.activation_checkpointing_fn = lambda module: isinstance(module, (encoder_block_type, decoder_block_type))
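
# Illustrative usage (a sketch; 't5-small' is just a convenient public checkpoint):
#
#     from transformers import AutoModelForSeq2SeqLM
#     t5 = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
#     prepare_hf_enc_dec_model_for_fsdp(t5, init_device='cpu')
#     # Every T5Block is now matched by `t5.fsdp_wrap_fn`, and the shared/tied
#     # embedding modules stay together in the root FSDP unit.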