"""Re-usable :class:`.ComposerModel` for LLM HF Models."""

from __future__ import annotations

from collections import UserDict
from typing import TYPE_CHECKING, List, Mapping, Optional

import transformers
from composer.models.huggingface import HuggingFaceModel
from torchmetrics import Metric
from transformers import PreTrainedTokenizerBase
from transformers.utils.generic import ModelOutput

from .hf_fsdp import prepare_hf_model_for_fsdp

if TYPE_CHECKING:
    from peft import PeftConfig

_HF_IGNORE_INDEX = -100


class HuggingFaceModelWithFSDP(HuggingFaceModel):
    """Wrapper around HuggingFaceModel.

    Handles preparation for FSDP wrapping.
    """

    def __init__(
        self,
        model: transformers.PreTrainedModel,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        metrics: Optional[List[Metric]] = None,
        eval_metrics: Optional[List[Metric]] = None,
        shift_labels: bool = False,
        init_device: Optional[str] = None,
        peft_config: Optional['PeftConfig'] = None,
    ):
        super().__init__(
            model,
            tokenizer,
            use_logits=True,
            metrics=metrics,
            eval_metrics=eval_metrics,
            shift_labels=shift_labels,
            peft_config=peft_config,
            should_save_peft_only=True,
        )

        # Prepare the HF model's module hierarchy for FSDP wrapping.
        prepare_hf_model_for_fsdp(self.model, init_device)

        # Allow FSDP to re-initialize parameters that were created on the
        # `meta` device (deferred initialization).
        self.model.param_init_fn = lambda module: self.model._init_weights(module)

    def forward(self, batch: Mapping):
        if isinstance(batch, (dict, UserDict)):
            # Keep only the keys accepted by the HF forward signature; further
            # input validation is left to the Hugging Face forward call.
            batch = {
                k: v for k, v in batch.items() if k in self.model_forward_args
            }
            output = self.model(**batch)
        else:
            raise ValueError(
                'Unexpected batch type. Expected a dictionary with keys '
                'corresponding to the inputs to the forward function of the '
                'Hugging Face model.',
            )
        return output

    def loss(self, outputs: ModelOutput, batch: Mapping):
        if self.config.use_return_dict:
            return outputs['loss']
        # With `return_dict` disabled, HF models return a tuple in which the
        # loss is at index 0 and the logits are at index 1.
        return outputs[:2]
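

# --- Usage sketch (illustrative only; kept as comments so importing this
# module has no side effects) -------------------------------------------------
# A minimal sketch, assuming a causal-LM checkpoint fine-tuned with Composer
# under FSDP. The checkpoint name and the `shift_labels` choice below are
# illustrative assumptions, not requirements of this module.
#
#   import transformers
#
#   hf_model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
#   hf_tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
#
#   composer_model = HuggingFaceModelWithFSDP(
#       model=hf_model,
#       tokenizer=hf_tokenizer,
#       shift_labels=True,  # assumption: shift labels for causal-LM metrics
#   )
#
#   # `composer_model` can then be passed to a Composer Trainer configured
#   # with FSDP; `prepare_hf_model_for_fsdp` has already tagged the module
#   # hierarchy for wrapping.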