from typing import List, Optional

from transformers import PretrainedConfig


class TransmxmConfig(PretrainedConfig):
    """Configuration holder for the ``transmxm`` model type.

    All constructor arguments are stored verbatim as attributes of the same
    name; any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        dim: The dimension of the input feature.
        n_layer: The number of GCN layers.
        cutoff: The cutoff distance for neighbor searching.
        num_spherical: The number of spherical harmonics.
        num_radial: The number of radial basis functions.
        envelope_exponent: The envelope exponent.
        smiles: Optional list of SMILES strings kept on the config
            (how they are consumed is decided outside this file).
        processor_class: Name of the processor class recorded on the config.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    model_type = "transmxm"

    def __init__(
        self,
        dim: int = 128,
        n_layer: int = 6,
        cutoff: float = 5.0,
        num_spherical: int = 7,
        num_radial: int = 6,
        envelope_exponent: int = 5,
        smiles: Optional[List[str]] = None,
        processor_class: str = "SmilesProcessor",
        **kwargs,
    ):
        # Model hyper-parameters.
        self.dim = dim
        self.n_layer = n_layer
        self.cutoff = cutoff
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.envelope_exponent = envelope_exponent

        # Input-processing settings.
        self.smiles = smiles
        self.processor_class = processor_class

        # Attributes are assigned first and the base initializer runs last,
        # preserving the original statement order.
        super().__init__(**kwargs)


if __name__ == "__main__":
    # Build a config with explicit values and write it to ./custom-transmxm.
    transmxm_config = TransmxmConfig(
        dim=128,
        n_layer=6,
        cutoff=5.0,
        num_spherical=7,
        num_radial=6,
        envelope_exponent=5,
        smiles=["C", "CC", "CCC"],
        processor_class="SmilesProcessor",
    )
    transmxm_config.save_pretrained("custom-transmxm")