from transformers import PreTrainedModel
from .configuration_avGFP import avGFPConfig
from evo_prot_grad.models import OneHotCNN


class avGFPModel(PreTrainedModel):
    """Expose a ``OneHotCNN`` network through the ``transformers`` model API.

    The CNN's hyperparameters are taken from an ``avGFPConfig``, which is
    registered as this model's ``config_class``.
    """

    config_class = avGFPConfig

    def __init__(self, config):
        """Initialise the base model and build the wrapped CNN.

        Args:
            config: An ``avGFPConfig``. Its ``vocab_size``, ``kernel_size``
                and ``input_size`` attributes are forwarded unchanged to
                ``OneHotCNN``.
        """
        super().__init__(config)
        # Collect the CNN hyperparameters from the config in one place so
        # the constructor call below stays a single readable line.
        cnn_hparams = {
            "vocab_size": config.vocab_size,
            "kernel_size": config.kernel_size,
            "input_size": config.input_size,
        }
        self.model = OneHotCNN(**cnn_hparams)

    def forward(self, x):
        """Run ``x`` through the wrapped CNN and return its output unchanged.

        NOTE(review): the expected shape/encoding of ``x`` and the meaning
        of the output are defined by ``OneHotCNN``, which is not visible
        here. Confirm against that class.
        """
        # Invoke the module via __call__ (not .forward) so any registered
        # hooks on the submodule still fire.
        prediction = self.model(x)
        return prediction