fix integration with huggingface #2
by not-lain - opened
- README.md +10 -0
- __init__.py +1 -1
- config.json +2 -2
- configuration_gemma.py +2 -9
- modeling_cerule_gemma.py +9 -9
- requirements.txt +2 -0
README.md
CHANGED
@@ -39,6 +39,16 @@ The training setup was `4xA100's 80GB` and took ~6 hours to pretrain and ~13 hou
 | ![extreme_ironing](examples/extreme_ironing.jpg) | **What's funny about this image?**<br>The image is quite humorous as it depicts a man ironing clothes on the back of a yellow taxi cab. This is not a typical sight you'd expect to see in everyday life. |
 ---
 
+## Loading the model
+
+```
+pip install -qr https://huggingface.co/Tensoic/Cerule-v0.1/resolve/main/requirements.txt
+```
+
+```python
+from transformers import AutoModelForCausalLM
+model = AutoModelForCausalLM.from_pretrained("Tensoic/Cerule-v0.1", trust_remote_code=True)
+```
+
 
 ## Training:
 We will release the training code in some time.
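The loading snippet added to the README uses the default precision and device placement. As a rough sketch (not part of this PR), the same model can also be loaded in half precision across available GPUs; the `torch_dtype` and `device_map="auto"` arguments (the latter needs `accelerate`) and the tokenizer load below are illustrative assumptions, not something the README specifies:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative variant of the README snippet: fp16 weights, spread over available
# GPUs via device_map="auto" (requires `pip install accelerate`).
model = AutoModelForCausalLM.from_pretrained(
    "Tensoic/Cerule-v0.1",
    trust_remote_code=True,   # needed: the repo ships custom config/model code
    torch_dtype=torch.float16,
    device_map="auto",
)

# Assumes the repo also ships a tokenizer; adjust if it does not.
tokenizer = AutoTokenizer.from_pretrained("Tensoic/Cerule-v0.1", trust_remote_code=True)
```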
__init__.py
CHANGED
@@ -3,5 +3,5 @@ from .modeling_cerule_gemma import CeruleGemmaForCausalLM
 
 from transformers import AutoConfig, AutoModelForCausalLM
 
-AutoConfig.register("
+AutoConfig.register("phi-msft", CeruleGemmaConfig)
 AutoModelForCausalLM.register(CeruleGemmaConfig, CeruleGemmaForCausalLM)
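The `__init__.py` change registers `CeruleGemmaConfig` under the `"phi-msft"` model type, which matches the `model_type` written to `config.json` below; note that `AutoConfig.register` raises an error if the chosen string collides with a model type transformers already knows, so the string has to be unique to this repo. A hypothetical sanity check (assuming the Hub repo is reachable and remote code is trusted) that the custom class is resolved:

```python
from transformers import AutoConfig

# Hypothetical check, not part of this PR: with the registration above and a
# matching "model_type" in config.json, AutoConfig should resolve the repo's
# custom config class rather than a built-in one.
cfg = AutoConfig.from_pretrained("Tensoic/Cerule-v0.1", trust_remote_code=True)
print(type(cfg).__name__)  # expected: CeruleGemmaConfig
print(cfg.model_type)      # expected: phi-msft
```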
config.json
CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "Tensoic/Cerule",
+  "_name_or_path": "Tensoic/Cerule-v0.1",
   "architectures": [
     "CeruleGemmaForCausalLM"
   ],
@@ -23,7 +23,7 @@
   "mm_projector_lr": null,
   "mm_projector_type": "mlp2x_gelu",
   "mm_vision_tower": "google/siglip-so400m-patch14-384",
-  "model_type": "
+  "model_type": "phi-msft",
   "num_attention_heads": 8,
   "num_hidden_layers": 18,
   "num_key_value_heads": 1,
configuration_gemma.py
CHANGED
@@ -25,8 +25,8 @@ GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class
-    model_type = "
+class CeruleGemmaConfig(PretrainedConfig):
+    model_type = "phi-msft"
     keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
@@ -162,10 +162,3 @@ class SigLipVisionConfig(PretrainedConfig):
 
         return cls.from_dict(config_dict, **kwargs)
 
-
-class CeruleGemmaConfig(GemmaConfig):
-    model_type = "cerule-gemma"
-
-    def __init__(self, **kwargs):
-        self.gemma_config = GemmaConfig(**kwargs)
-        super().__init__(**kwargs)
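The configuration diff folds the old `CeruleGemmaConfig(GemmaConfig)` wrapper into a single `PretrainedConfig` subclass whose `model_type` matches the registered string. A generic sketch of that pattern, with made-up class and field names rather than Cerule's actual parameters:

```python
from transformers import PretrainedConfig

class MyCustomConfig(PretrainedConfig):
    # model_type links serialized config.json files back to this class
    # and must match the string passed to AutoConfig.register.
    model_type = "my-custom-type"

    def __init__(self, hidden_size: int = 2048, num_hidden_layers: int = 18, **kwargs):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        super().__init__(**kwargs)
```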
modeling_cerule_gemma.py
CHANGED
@@ -853,7 +853,7 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from transformers.utils.import_utils import is_torch_fx_available
-from .configuration_gemma import
+from .configuration_gemma import CeruleGemmaConfig
 
 
 if is_flash_attn_2_available():
@@ -872,7 +872,7 @@ if is_torch_fx_available():
 
 logger = logging.get_logger(__name__)
 
-_CONFIG_FOR_DOC = "
+_CONFIG_FOR_DOC = "CeruleGemmaConfig"
 
 
 def _get_unpad_data(attention_mask):
@@ -1003,7 +1003,7 @@ class GemmaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     # Ignore copy
-    def __init__(self, config:
+    def __init__(self, config: CeruleGemmaConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -1396,7 +1396,7 @@ GEMMA_ATTENTION_CLASSES = {
 
 # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma
 class GemmaDecoderLayer(nn.Module):
-    def __init__(self, config:
+    def __init__(self, config: CeruleGemmaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
@@ -1480,7 +1480,7 @@ GEMMA_START_DOCSTRING = r"""
     and behavior.
 
     Parameters:
-        config ([`
+        config ([`CeruleGemmaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -1492,7 +1492,7 @@ GEMMA_START_DOCSTRING = r"""
     GEMMA_START_DOCSTRING,
 )
 class GemmaPreTrainedModel(PreTrainedModel):
-    config_class =
+    config_class = CeruleGemmaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
@@ -1618,7 +1618,7 @@ class GemmaModel(GemmaPreTrainedModel):
         config: GemmaConfig
     """
 
-    def __init__(self, config:
+    def __init__(self, config: CeruleGemmaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
@@ -2155,7 +2155,7 @@ from .configuration_gemma import CeruleGemmaConfig
 class CeruleGemmaModel(CeruleMetaModel, GemmaModel):
     config_class = CeruleGemmaConfig
 
-    def __init__(self, config:
+    def __init__(self, config: CeruleGemmaConfig):
         super(CeruleGemmaModel, self).__init__(config)
 
 
@@ -2264,5 +2264,5 @@ class CeruleGemmaForCausalLM(GemmaForCausalLM, CeruleMetaForCausalLM):
         return new_images
 
 
-AutoConfig.register("
+AutoConfig.register("phi-msft", CeruleGemmaConfig)
 AutoModelForCausalLM.register(CeruleGemmaConfig, CeruleGemmaForCausalLM)
requirements.txt
ADDED
@@ -0,0 +1,2 @@
+flash_attn
+transformers>=4.39.1
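The new requirements file adds `flash_attn`, which generally needs a CUDA toolchain and a GPU environment to build, and a minimum `transformers` version. An optional sketch (not part of this PR) for verifying the installed version at runtime before loading the model; `packaging` is already a dependency of transformers:

```python
# Optional sanity check: confirm the environment satisfies the transformers pin
# from requirements.txt before attempting to load the model.
import transformers
from packaging import version

if version.parse(transformers.__version__) < version.parse("4.39.1"):
    raise RuntimeError(
        f"transformers {transformers.__version__} found; Cerule's requirements "
        "pin transformers>=4.39.1"
    )
```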