|
|
|
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel |
|
import torch.nn as nn |
|
import subprocess |
|
|
|
class CustomModelConfig(PretrainedConfig):
    """Configuration object for ``CustomModel``.

    Args:
        hidden_size: Stored as ``self.hidden_size``; ``CustomModel`` uses it as
            both the input and output width of its linear layer. Defaults to 128.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    # Key this config is known by; the Auto* registration must use the same value.
    model_type = "custom-model"

    def __init__(self, hidden_size: int = 128, **kwargs):
        # Let the base class consume the generic options first, then record the
        # model-specific field on top of it.
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
|
|
|
|
|
class CustomModel(PreTrainedModel):
    """Minimal model made of a single square ``nn.Linear`` projection.

    The layer maps ``config.hidden_size`` input features to
    ``config.hidden_size`` output features.
    """

    # Ties this model class to its configuration type.
    config_class = CustomModelConfig

    def __init__(self, config):
        super().__init__(config)
        width = config.hidden_size
        self.linear = nn.Linear(width, width)

    def forward(self, input_ids):
        """Return ``self.linear(input_ids)``.

        NOTE(review): despite the name, ``input_ids`` is fed directly to
        ``nn.Linear``. That requires a floating-point tensor whose last
        dimension equals ``config.hidden_size``; an integer token-id tensor
        would raise a dtype error here. Confirm what callers actually pass.
        """
        return self.linear(input_ids)
|
|
|
# Make the custom classes discoverable through the Auto* factories.
# AutoConfig.register rejects a key that differs from the config class's
# ``model_type``. Reusing the class attribute, instead of repeating the
# "custom-model" literal, keeps the two from drifting apart.
AutoConfig.register(CustomModelConfig.model_type, CustomModelConfig)
AutoModel.register(CustomModelConfig, CustomModel)
|
|
|
|
|
|