|
from typing import List, Optional

from transformers import PretrainedConfig
|
|
|
|
|
class MXMConfig(PretrainedConfig):
    """Hugging Face configuration for the ``"mxm"`` model type.

    Stores the model hyperparameters, an optional list of SMILES strings and
    the name of a processor class as plain attributes. All remaining keyword
    arguments are forwarded to :class:`transformers.PretrainedConfig`.

    Args:
        dim: Feature width of the model (presumably the hidden size --
            confirm against the model implementation).
        n_layer: Number of layers (presumably interaction / message-passing
            layers -- confirm).
        cutoff: Distance cutoff (units are not established here -- confirm).
        num_spherical: Number of spherical basis functions (inferred from
            the name -- confirm).
        num_radial: Number of radial basis functions (inferred from the
            name -- confirm).
        envelope_exponent: Exponent of the envelope function (inferred from
            the name -- confirm).
        smiles: Optional list of SMILES strings, stored on the config as
            given; ``None`` when not provided.
        processor_class: Name of the processor class, stored as given.
        **kwargs: Passed unchanged to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier consumed by the transformers config machinery.
    model_type = "mxm"

    def __init__(
        self,
        dim: int = 128,
        n_layer: int = 6,
        cutoff: float = 5.0,
        num_spherical: int = 7,
        num_radial: int = 6,
        envelope_exponent: int = 5,
        # Explicit Optional: the default is None, so a bare List[str]
        # annotation would be an implicit Optional (disallowed by PEP 484).
        smiles: Optional[List[str]] = None,
        processor_class: str = "SmilesProcessor",
        **kwargs,
    ):
        # Model hyperparameters.
        self.dim = dim
        self.n_layer = n_layer
        self.cutoff = cutoff
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.envelope_exponent = envelope_exponent

        # Input-related settings, stored verbatim.
        self.smiles = smiles
        self.processor_class = processor_class

        # Let the base class consume the generic config options
        # (anything not named above).
        super().__init__(**kwargs)
|
|
|
|
|
if __name__ == "__main__":
    # Demo: build a configuration and serialize it with save_pretrained.
    # The hyperparameter values match MXMConfig's defaults; three sample
    # SMILES strings are supplied as well.
    demo_settings = {
        "dim": 128,
        "n_layer": 6,
        "cutoff": 5.0,
        "num_spherical": 7,
        "num_radial": 6,
        "envelope_exponent": 5,
        "smiles": ["C", "CC", "CCC"],
        "processor_class": "SmilesProcessor",
    }
    mxm_config = MXMConfig(**demo_settings)
    mxm_config.save_pretrained("custom-mxm")