kz919 committed
Commit b95f302
1 Parent(s): 4d297e6

Update configuration_sliding_llama.py

Files changed (1)
  1. configuration_sliding_llama.py +7 -24
configuration_sliding_llama.py CHANGED
@@ -21,7 +21,7 @@
 
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
-
+from transformers.modeling_rope_utils import rope_config_validation
 
 logger = logging.get_logger(__name__)
 
@@ -159,37 +159,20 @@ class LlamaConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
-        self._rope_scaling_validation()
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.mlp_bias = mlp_bias
         self.sliding_windows = sliding_windows if sliding_windows is not None else [0 for _ in range(num_hidden_layers)]
         assert len(self.sliding_windows) == self.num_hidden_layers
-
+
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
-        )
-
-    def _rope_scaling_validation(self):
-        """
-        Validate the `rope_scaling` configuration.
-        """
-        if self.rope_scaling is None:
-            return
-
-        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
-            raise ValueError(
-                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
-            )
-        rope_scaling_type = self.rope_scaling.get("type", None)
-        rope_scaling_factor = self.rope_scaling.get("factor", None)
-        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
-            raise ValueError(
-                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
-            )
-        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
-            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
+        )
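
For context, a minimal usage sketch of the updated validation path (not part of the commit). It assumes this file is importable locally as `configuration_sliding_llama`, that the installed `transformers` version ships `transformers.modeling_rope_utils.rope_config_validation`, and the `num_hidden_layers` / `sliding_windows` values are made up for illustration. Unlike the deleted `_rope_scaling_validation`, which only accepted the "linear" and "dynamic" types, the shared upstream helper also covers the newer rope types handled by `transformers`.

# Minimal sketch, not from the commit. Assumes the module is importable as
# `configuration_sliding_llama` and a transformers release that provides
# rope_config_validation is installed.
from configuration_sliding_llama import LlamaConfig

# A legacy-style rope_scaling dict only carries "type"; the updated __init__
# mirrors it into "rope_type" before calling rope_config_validation(self).
config = LlamaConfig(
    num_hidden_layers=4,                 # small, illustrative value
    sliding_windows=[0, 4096, 0, 4096],  # hypothetical; length must equal num_hidden_layers
    rope_scaling={"type": "linear", "factor": 2.0},
)

print(config.rope_scaling)
# Expected to contain both the original "type" and the mirrored "rope_type",
# e.g. {'type': 'linear', 'factor': 2.0, 'rope_type': 'linear'}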