nz committed on
Commit
cd79b59
1 Parent(s): 76ae007

Create rita_configuration.py

Browse files
Files changed (1) hide show
  1. rita_configuration.py +29 -0
rita_configuration.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+
4
+ logger = logging.get_logger(__name__)
5
+
6
class RITAConfig(PretrainedConfig):
    """Configuration holder for models registered under ``model_type = "rita"``.

    The constructor stores its arguments as instance attributes and derives
    ``d_feedforward`` as ``d_model * ff_ratio``. Any additional keyword
    arguments are forwarded unchanged to ``PretrainedConfig.__init__``.

    Args:
        vocab_size: Stored as ``self.vocab_size`` (presumably the tokenizer
            vocabulary size -- confirm against the model code).
        d_model: Stored as ``self.d_model``; also the base of the derived
            ``d_feedforward`` value.
        num_layers: Stored as ``self.num_layers``.
        max_seq_len: Stored as ``self.max_seq_len``.
        num_heads: Stored as ``self.num_heads``.
        dropout: Stored as ``self.dropout``.
        ff_ratio: Multiplier applied to ``d_model`` to obtain
            ``self.d_feedforward``; also stored as ``self.ff_ratio``.
        eos_token_id: Forwarded to ``PretrainedConfig.__init__`` and stored
            as ``self.eos_token_id``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "rita"

    def __init__(
        self,
        vocab_size=26,
        d_model=1536,
        num_layers=24,
        max_seq_len=1024,
        num_heads=24,
        dropout=0.,
        ff_ratio=4,
        eos_token_id=2,
        **kwargs,
    ):
        super().__init__(eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        # Keep the ratio itself on the instance. d_feedforward is always
        # recomputed below from d_model * ff_ratio, so a configuration that is
        # rebuilt from its stored attributes needs ff_ratio to be among them;
        # otherwise a non-default ratio would silently revert to the default.
        self.ff_ratio = ff_ratio
        self.d_feedforward = d_model * ff_ratio
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        # Set explicitly as well as via super().__init__, as the original did.
        self.eos_token_id = eos_token_id