File size: 6,923 Bytes
27140ac d849f5b 27140ac 28ff2e4 27140ac e87428b 27140ac e87428b 27140ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
# -*- coding: utf-8 -*-
"""StripedHyena custom code port for the Hugging Face Hub"""
import torch
import functools
from torch.nn import functional as F
from .configuration_hyena import StripedHyenaConfig
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
from transformers.utils import logging
from typing import Optional, Tuple, Union
from .model import StripedHyena
from .utils import dotdict
from .cache import InferenceParams
from .engine import HyenaInferenceEngine
from .layers import RMSNorm
from .utils import dotdict, column_split
logger = logging.get_logger(__name__)
class StripedHyenaPreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base for the StripedHyena Hub port.

    The class defines no methods; it only declares class-level settings that
    the ``transformers`` loading machinery reads.
    NOTE(review): the precise effect of each private hook below is defined by
    ``transformers.PreTrainedModel`` and should be confirmed against the
    installed ``transformers`` version.
    """

    # Configuration type paired with this model family, and the prefix name
    # registered for the base model.
    config_class = StripedHyenaConfig
    base_model_prefix = "sh"

    # Capability flags. Gradient checkpointing is off at this level; a
    # subclass may override the flag.
    supports_gradient_checkpointing = False
    _supports_flash_attn_2 = False

    # Device-placement hints: block class names that must stay on a single
    # device, and the forward-kwarg key to leave alone when moving inputs
    # (presumably consumed by accelerate-style dispatch -- verify).
    _no_split_modules = ["AttentionBlock", "ParallelGatedConvBlock"]
    _skip_keys_device_placement = "past_key_values"

    # Checkpoint key patterns tolerated when absent from / extra in a loaded
    # state dict.
    _keys_to_ignore_on_load_missing = [r"freq"]
    _keys_to_ignore_on_load_unexpected = [r"fftconv", r"twiddle_factors"]
class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
    """Causal language-modeling wrapper around a ``StripedHyena`` backbone.

    Adapts the backbone to the Hugging Face ``PreTrainedModel`` interface:
    gradient-checkpointing toggles, a ``forward`` that optionally computes the
    next-token cross-entropy loss, and the hooks used by ``generate``.
    """

    supports_gradient_checkpointing = True

    def __init__(self, config, **kwargs):
        """Build the backbone from ``config`` and finalize initialization.

        Args:
            config: Model configuration; must provide ``to_dict()``,
                ``vocab_size`` and ``make_vocab_size_divisible_by``.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(config, **kwargs)
        # The backbone receives the configuration as an attribute-accessible dict.
        model_config = dotdict(config.to_dict())
        self.backbone = StripedHyena(model_config)
        self.backbone.gradient_checkpointing = False
        self.config = config

        # Round the vocabulary size up to the next multiple of
        # ``make_vocab_size_divisible_by`` (unchanged if already a multiple).
        divisor = config.make_vocab_size_divisible_by
        vocab_size = config.vocab_size
        remainder = vocab_size % divisor
        if remainder != 0:
            vocab_size += divisor - remainder
        self.vocab_size = vocab_size

        self.post_init()
        # Runs after post_init so the dtype policy is applied last.
        self.force_dtype()

    def force_dtype(self):
        """Apply the backbone's dtype policy.

        Delegates to ``backbone.to_bfloat16_except_poles_residues()``; as the
        name suggests this presumably casts everything but the pole/residue
        parameters to bfloat16 -- confirm in the backbone implementation.
        """
        self.backbone.to_bfloat16_except_poles_residues()

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Turn on gradient checkpointing for the backbone.

        Args:
            gradient_checkpointing_kwargs: Optional keyword arguments bound
                onto ``torch.utils.checkpoint.checkpoint``. Defaults to
                ``{"use_reentrant": True}`` when ``None``.

        Raises:
            ValueError: If the class does not declare
                ``supports_gradient_checkpointing``.
        """
        # Import the submodule explicitly: a bare ``import torch`` at file
        # level does not by itself guarantee ``torch.utils.checkpoint`` is
        # loaded as an attribute.
        import torch.utils.checkpoint

        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

        if gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = {"use_reentrant": True}

        # TODO support deepspeed checkpoint
        gradient_checkpointing_func = functools.partial(
            torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
        )

        self._set_gradient_checkpointing(
            enable=True, gradient_checkpointing_func=gradient_checkpointing_func
        )

        if getattr(self, "_hf_peft_config_loaded", False):
            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
            # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
            # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
            # the gradients to make sure the gradient flows.
            self.enable_input_require_grads()

    def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
        """Store the checkpointing flag and checkpoint function on the backbone."""
        self.backbone.gradient_checkpointing = enable
        self.backbone._gradient_checkpointing_func = gradient_checkpointing_func

    def get_input_embeddings(self):
        """Return the backbone's ``embedding_layer``."""
        return self.backbone.embedding_layer

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values=None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone and optionally compute the language-modeling loss.

        Args:
            input_ids: Token ids; the first dimension is used as the batch
                size and the last as the sequence length.
            attention_mask: Passed to the backbone as ``padding_mask``.
            labels: If given, a shifted next-token cross-entropy loss is
                computed. Supplying labels disables caching.
            use_cache: Whether to use/return inference state. Defaults to
                ``config.use_cache``.
            output_attentions: Accepted for interface compatibility; unused.
            output_hidden_states: Accepted for interface compatibility; unused.
            past_key_values: Mapping with ``"mha"`` and ``"hyena"`` entries
                holding inference state, or ``None`` to create a fresh one.
                The entries are updated in place.
            return_dict: Defaults to ``config.use_return_dict``.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` is true;
            otherwise only the logits (the loss and cache are not returned in
            that mode).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if use_cache:
            if self.backbone.gradient_checkpointing and self.backbone.training:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
            elif labels is not None:
                logger.warning_once(
                    "`use_cache=True` is incompatible with loss calculation. Setting `use_cache=False`..."
                )
                use_cache = False

        inputs = input_ids
        if use_cache:
            if past_key_values is None:
                # First call: create the inference state and size it for this batch.
                past_key_values = self.backbone.initialize_inference_params()
                batch_size = input_ids.shape[0]
                past_key_values["mha"].max_batch_size = batch_size
                past_key_values["hyena"].max_batch_size = batch_size
            else:
                seqlen_offset = past_key_values["mha"].seqlen_offset
                if seqlen_offset == 0:
                    # second loop through generate will have prompt_len + 1 as seqlen
                    seqlen_offset = input_ids.shape[-1] - 1
                    past_key_values["hyena"].seqlen_offset = seqlen_offset
                    past_key_values["mha"].seqlen_offset = seqlen_offset
                else:
                    past_key_values["mha"].seqlen_offset += 1
                    past_key_values["hyena"].seqlen_offset += 1

                # With existing state only the newest token is fed to the backbone.
                inputs = input_ids[:, -1:]

        logits, past_key_values = self.backbone(
            inputs,
            padding_mask=attention_mask,
            inference_params_dict=past_key_values if use_cache else None,
        )

        loss = None
        if labels is not None:
            # Position t predicts token t + 1: drop the last logit and the first label.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten using the logits' own width instead of config.vocab_size:
            # the two coincide when the backbone emits exactly config.vocab_size
            # columns, and this stays correct if the output width differs
            # (e.g. a vocabulary padded as in ``self.vocab_size``).
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = F.cross_entropy(shift_logits, shift_labels)

        if return_dict:
            return CausalLMOutputWithPast(
                logits=logits,
                hidden_states=None,
                past_key_values=past_key_values if use_cache else None,
                loss=loss,
            )
        else:
            return logits

    @classmethod
    def can_generate(cls) -> bool:
        """Report that this model can be used for generation (always ``True``)."""
        return True

    def prepare_inputs_for_generation(
        self, input_ids, attention_mask=None, past_key_values=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        The full ``input_ids`` are passed through untouched (``forward``
        selects the last token itself when state exists); extra ``kwargs``
        are ignored.
        """
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        }
|