jupyterjazz and Jackmin108 committed
Commit 4434bf3
Parent(s): 95fd08c

change rotary base (#31)


- feat: rotary base as a property (c1200891411b6198ca6448cfebf5123d15bf2c31)
- Merge branch 'main' into pr/31 (c2ead96805f8278295d48fda36eba1d96ed3bffb)


Co-authored-by: Jack Min Ong <[email protected]>
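
For context, the net effect of this commit: rotary_emb_base becomes a config field and a read/write property at every level (LoRA wrapper, backbone model, rotary module), so the RoPE base can be changed on a live model. A minimal usage sketch (the checkpoint name is a placeholder; assumes a model instantiated from this repo's modeling code):

    from transformers import AutoModel

    # Placeholder checkpoint; any model built from this repo's code works.
    model = AutoModel.from_pretrained("org/checkpoint", trust_remote_code=True)

    print(model.rotary_emb_base)     # 10000.0, the new config default
    model.rotary_emb_base = 20000.0  # propagates to every attention layer's rotary embedding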

configuration_xlm_roberta.py CHANGED
@@ -20,6 +20,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         bos_token_id=0,
         eos_token_id=2,
         position_embedding_type="absolute",
+        rotary_emb_base=10000.0,
         use_cache=True,
         classifier_dropout=None,
         lora_adaptations=None,
@@ -52,6 +53,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.position_embedding_type = position_embedding_type
+        self.rotary_emb_base = rotary_emb_base
         self.use_cache = use_cache
         self.classifier_dropout = classifier_dropout
         self.load_trained_adapters = load_trained_adapters
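
The config change just threads the new default through __init__; a minimal sketch of using it (assuming the class is imported from this repo's configuration_xlm_roberta.py):

    from configuration_xlm_roberta import XLMRobertaFlashConfig

    config = XLMRobertaFlashConfig(rotary_emb_base=20000.0)  # default stays 10000.0
    assert config.rotary_emb_base == 20000.0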
modeling_lora.py CHANGED
@@ -262,6 +262,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         self.main_params_trainable = config.lora_main_params_trainable
 
 
+    @property
+    def rotary_emb_base(self):
+        return self.roberta.rotary_emb_base
+
+    @rotary_emb_base.setter
+    def rotary_emb_base(self, base):
+        self.roberta.rotary_emb_base = base
+
     @property
     def main_params_trainable(self):
         return self._main_params_trainable
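
The wrapper property is pure delegation, so LoRA users keep the same interface as the plain backbone; roughly (construction elided, names as in this diff):

    lora_model.rotary_emb_base = 20000.0   # forwards to lora_model.roberta.rotary_emb_base
    print(lora_model.rotary_emb_base)      # reads back through the same path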
modeling_xlm_roberta.py CHANGED
@@ -93,7 +93,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         rotary_kwargs["rotary_emb_dim"] = getattr(
             config, "rotary_emb_dim", config.hidden_size / config.num_attention_heads
         )
-        rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
+        rotary_kwargs["rotary_emb_base"] = config.rotary_emb_base
         rotary_kwargs["rotary_emb_scale_base"] = getattr(
             config, "rotary_emb_scale_base", None
         )
@@ -450,6 +450,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
         self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
+        self._rotary_emb_base = config.rotary_emb_base
 
     @torch.inference_mode()
     def encode(
@@ -599,7 +600,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.train(is_training)
         return all_embeddings
 
-
     def truncate_embeddings(self, embeddings, truncate_dim):
         if not self.config.matryoshka_dimensions:
             logger.warning(
@@ -622,12 +622,23 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
             input_mask_expanded.sum(1), min=1e-9
         )
 
-
     def cls_pooling(
         self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
     ):
         return token_embeddings[:,0]
 
+    @property
+    def rotary_emb_base(self):
+        return self._rotary_emb_base
+
+    @rotary_emb_base.setter
+    def rotary_emb_base(self, base):
+        if not isinstance(base, (int, float)):
+            raise TypeError("Base must be an integer or float")
+        logger.info(f'Changing RoPE base value to {base}')
+        for layer in self.encoder.layers:
+            layer.mixer.rotary_emb.base = base
+        self._rotary_emb_base = base
 
     def forward(
         self,
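
Sketch of the model-level setter's behavior as added above (names as in this diff; the string assignment is shown only to illustrate the type guard):

    model.rotary_emb_base = 20000.0
    # logs "Changing RoPE base value to 20000.0", sets layer.mixer.rotary_emb.base
    # on every encoder layer, then records the value in model._rotary_emb_base

    # model.rotary_emb_base = "20000"  # would raise TypeError: Base must be an integer or float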
rotary.py CHANGED
@@ -443,7 +443,7 @@ class RotaryEmbedding(torch.nn.Module):
         """
         super().__init__()
         self.dim = dim
-        self.base = float(base)
+        self._base = float(base)
         self.pos_idx_in_fp32 = pos_idx_in_fp32
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = self._compute_inv_freq(device)
@@ -463,6 +463,17 @@ class RotaryEmbedding(torch.nn.Module):
         self._cos_k_cached = None
         self._sin_k_cached = None
 
+    @property
+    def base(self):
+        return self._base
+
+    @base.setter
+    def base(self, new_base):
+        if new_base > 0:
+            self._base = float(new_base)
+        else:
+            raise ValueError("Rotary base value must be positive")
+
     def _compute_inv_freq(self, device=None):
         return 1.0 / (
             self.base
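
For reference, the base set here feeds the standard RoPE inverse-frequency formula, theta_i = base^(-2i/dim); a standalone sketch mirroring the shape of _compute_inv_freq above (illustrative, not this repo's code):

    import torch

    def compute_inv_freq(dim: int, base: float) -> torch.Tensor:
        # theta_i = base ** (-2 * i / dim) for i = 0 .. dim // 2 - 1
        return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

    print(compute_inv_freq(8, 10000.0))
    print(compute_inv_freq(8, 20000.0))  # larger base -> slower rotation per position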