|
import torch |
|
|
|
from transformers import PreTrainedModel |
|
|
|
from .custom_configuration import CustomConfig, NoSuperInitConfig |
|
|
|
|
|
class CustomModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` subclass wrapping one square linear layer.

    The layer maps ``config.hidden_size`` features to ``config.hidden_size``
    features. ``CustomConfig`` is registered as the configuration class.
    """

    config_class = CustomConfig

    def __init__(self, config):
        """Build the model from ``config``.

        Args:
            config: Configuration object; its ``hidden_size`` attribute sets
                both the input and the output width of the linear layer.
        """
        super().__init__(config)
        self.linear = torch.nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
        )

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        projected = self.linear(x)
        return projected

    def _init_weights(self, module):
        """Intentionally a no-op: ``module`` is left exactly as it was built."""
|
|
|
|
|
class NoSuperInitModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` subclass configured by ``NoSuperInitConfig``.

    It holds a single square linear layer whose width is read from
    ``config.attribute``.
    """

    config_class = NoSuperInitConfig

    def __init__(self, config):
        """Build the model from ``config``.

        Args:
            config: Configuration object; its ``attribute`` value sets both
                the input and the output width of the linear layer.
        """
        super().__init__(config)
        self.linear = torch.nn.Linear(
            in_features=config.attribute,
            out_features=config.attribute,
        )

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        projected = self.linear(x)
        return projected

    def _init_weights(self, module):
        """Intentionally a no-op: ``module`` is left exactly as it was built."""
|
|