Upload folder using huggingface_hub
gemma_model.py +31 -4
gemma_model.py CHANGED
@@ -54,7 +54,7 @@ from transformers.utils import (
 from .gemma_config import CostWiseGemmaConfig
 from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2RotaryEmbedding, rotate_half, apply_rotary_pos_emb
 from transformers.models.gemma2.modeling_gemma2 import Gemma2MLP, repeat_kv, Gemma2Attention, Gemma2FlashAttention2, Gemma2SdpaAttention, GEMMA2_ATTENTION_CLASSES, Gemma2DecoderLayer, GEMMA2_START_DOCSTRING
-from transformers.models.gemma2.modeling_gemma2 import
+from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -77,6 +77,33 @@ def _get_unpad_data(attention_mask):
         max_seqlen_in_batch,
     )
 
+@add_start_docstrings(
+    "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.",
+    GEMMA2_START_DOCSTRING,
+)
+class CostWiseGemma2PreTrainedModel(PreTrainedModel):
+    config_class = CostWiseGemmaConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Gemma2DecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = False
+    _supports_quantized_cache = False
+    _supports_static_cache = True
+    _is_stateful = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
 
 GEMMA2_ATTENTION_CLASSES = {
     "eager": Gemma2Attention,
@@ -213,7 +240,7 @@ def token_compress(compress_ratio,
     "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.",
     GEMMA2_START_DOCSTRING,
 )
-class CostWiseGemmaModel(
+class CostWiseGemmaModel(CostWiseGemma2PreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
 
@@ -466,10 +493,10 @@ class CostWiseHead(nn.Module):
         return self.linear_head(**kwargs)
 
 
-class CostWiseGemmaForCausalLM(
+class CostWiseGemmaForCausalLM(CostWiseGemma2PreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
-    def __init__(self, config):
+    def __init__(self, config: CostWiseGemmaConfig):
         super().__init__(config)
         self.model = CostWiseGemmaModel(config)
         self.vocab_size = config.vocab_size