Commit
•
4434bf3
1
Parent(s):
95fd08c
change rotary base (#31)
Browse files- feat: rotary base as a property (c1200891411b6198ca6448cfebf5123d15bf2c31)
- Merge branch 'main' into pr/31 (c2ead96805f8278295d48fda36eba1d96ed3bffb)
Co-authored-by: Jack Min Ong <[email protected]>
- configuration_xlm_roberta.py +2 -0
- modeling_lora.py +8 -0
- modeling_xlm_roberta.py +14 -3
- rotary.py +12 -1
configuration_xlm_roberta.py
CHANGED
@@ -20,6 +20,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
20 |
bos_token_id=0,
|
21 |
eos_token_id=2,
|
22 |
position_embedding_type="absolute",
|
|
|
23 |
use_cache=True,
|
24 |
classifier_dropout=None,
|
25 |
lora_adaptations=None,
|
@@ -52,6 +53,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
|
|
52 |
self.initializer_range = initializer_range
|
53 |
self.layer_norm_eps = layer_norm_eps
|
54 |
self.position_embedding_type = position_embedding_type
|
|
|
55 |
self.use_cache = use_cache
|
56 |
self.classifier_dropout = classifier_dropout
|
57 |
self.load_trained_adapters = load_trained_adapters
|
|
|
20 |
bos_token_id=0,
|
21 |
eos_token_id=2,
|
22 |
position_embedding_type="absolute",
|
23 |
+
rotary_emb_base=10000.0,
|
24 |
use_cache=True,
|
25 |
classifier_dropout=None,
|
26 |
lora_adaptations=None,
|
|
|
53 |
self.initializer_range = initializer_range
|
54 |
self.layer_norm_eps = layer_norm_eps
|
55 |
self.position_embedding_type = position_embedding_type
|
56 |
+
self.rotary_emb_base = rotary_emb_base
|
57 |
self.use_cache = use_cache
|
58 |
self.classifier_dropout = classifier_dropout
|
59 |
self.load_trained_adapters = load_trained_adapters
|
modeling_lora.py
CHANGED
@@ -262,6 +262,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
262 |
self.main_params_trainable = config.lora_main_params_trainable
|
263 |
|
264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
@property
|
266 |
def main_params_trainable(self):
|
267 |
return self._main_params_trainable
|
|
|
262 |
self.main_params_trainable = config.lora_main_params_trainable
|
263 |
|
264 |
|
265 |
+
@property
|
266 |
+
def rotary_emb_base(self):
|
267 |
+
return self.roberta.rotary_emb_base
|
268 |
+
|
269 |
+
@rotary_emb_base.setter
|
270 |
+
def rotary_emb_base(self, base):
|
271 |
+
self.roberta.rotary_emb_base = base
|
272 |
+
|
273 |
@property
|
274 |
def main_params_trainable(self):
|
275 |
return self._main_params_trainable
|
modeling_xlm_roberta.py
CHANGED
@@ -93,7 +93,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
93 |
rotary_kwargs["rotary_emb_dim"] = getattr(
|
94 |
config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
|
95 |
)
|
96 |
-
rotary_kwargs["rotary_emb_base"] =
|
97 |
rotary_kwargs["rotary_emb_scale_base"] = getattr(
|
98 |
config, "rotary_emb_scale_base", None
|
99 |
)
|
@@ -450,6 +450,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
450 |
|
451 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
452 |
self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
|
|
|
453 |
|
454 |
@torch.inference_mode()
|
455 |
def encode(
|
@@ -599,7 +600,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
599 |
self.train(is_training)
|
600 |
return all_embeddings
|
601 |
|
602 |
-
|
603 |
def truncate_embeddings(self, embeddings, truncate_dim):
|
604 |
if not self.config.matryoshka_dimensions:
|
605 |
logger.warning(
|
@@ -622,12 +622,23 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
622 |
input_mask_expanded.sum(1), min=1e-9
|
623 |
)
|
624 |
|
625 |
-
|
626 |
def cls_pooling(
|
627 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
628 |
):
|
629 |
return token_embeddings[:,0]
|
630 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
631 |
|
632 |
def forward(
|
633 |
self,
|
|
|
93 |
rotary_kwargs["rotary_emb_dim"] = getattr(
|
94 |
config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
|
95 |
)
|
96 |
+
rotary_kwargs["rotary_emb_base"] = config.rotary_emb_base
|
97 |
rotary_kwargs["rotary_emb_scale_base"] = getattr(
|
98 |
config, "rotary_emb_scale_base", None
|
99 |
)
|
|
|
450 |
|
451 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
452 |
self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
|
453 |
+
self._rotary_emb_base = config.rotary_emb_base
|
454 |
|
455 |
@torch.inference_mode()
|
456 |
def encode(
|
|
|
600 |
self.train(is_training)
|
601 |
return all_embeddings
|
602 |
|
|
|
603 |
def truncate_embeddings(self, embeddings, truncate_dim):
|
604 |
if not self.config.matryoshka_dimensions:
|
605 |
logger.warning(
|
|
|
622 |
input_mask_expanded.sum(1), min=1e-9
|
623 |
)
|
624 |
|
|
|
625 |
def cls_pooling(
|
626 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
627 |
):
|
628 |
return token_embeddings[:,0]
|
629 |
|
630 |
+
@property
|
631 |
+
def rotary_emb_base(self):
|
632 |
+
return self._rotary_emb_base
|
633 |
+
|
634 |
+
@rotary_emb_base.setter
|
635 |
+
def rotary_emb_base(self, base):
|
636 |
+
if not isinstance(base, (int, float)):
|
637 |
+
raise TypeError("Base must be an integer or float")
|
638 |
+
logger.info(f'Changing RoPE base value to {base}')
|
639 |
+
for layer in self.encoder.layers:
|
640 |
+
layer.mixer.rotary_emb.base = base
|
641 |
+
self._rotary_emb_base = base
|
642 |
|
643 |
def forward(
|
644 |
self,
|
rotary.py
CHANGED
@@ -443,7 +443,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
443 |
"""
|
444 |
super().__init__()
|
445 |
self.dim = dim
|
446 |
-
self.
|
447 |
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
448 |
# Generate and save the inverse frequency buffer (non trainable)
|
449 |
inv_freq = self._compute_inv_freq(device)
|
@@ -463,6 +463,17 @@ class RotaryEmbedding(torch.nn.Module):
|
|
463 |
self._cos_k_cached = None
|
464 |
self._sin_k_cached = None
|
465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
def _compute_inv_freq(self, device=None):
|
467 |
return 1.0 / (
|
468 |
self.base
|
|
|
443 |
"""
|
444 |
super().__init__()
|
445 |
self.dim = dim
|
446 |
+
self._base = float(base)
|
447 |
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
448 |
# Generate and save the inverse frequency buffer (non trainable)
|
449 |
inv_freq = self._compute_inv_freq(device)
|
|
|
463 |
self._cos_k_cached = None
|
464 |
self._sin_k_cached = None
|
465 |
|
466 |
+
@property
|
467 |
+
def base(self):
|
468 |
+
return self._base
|
469 |
+
|
470 |
+
@base.setter
|
471 |
+
def base(self, new_base):
|
472 |
+
if new_base > 0:
|
473 |
+
self._base = float(new_base)
|
474 |
+
else:
|
475 |
+
raise ValueError("Rotary base value must be positive")
|
476 |
+
|
477 |
def _compute_inv_freq(self, device=None):
|
478 |
return 1.0 / (
|
479 |
self.base
|