File size: 935 Bytes
e58b0a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from transformers import PreTrainedModel
#from timm.models.resnet import BasicBlock, Bottleneck, ResNet
from .configuration_scgpt import ScgptConfig
#BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
class ScgptModel(PreTrainedModel):
    """Placeholder Hugging Face model class for scGPT.

    The class is paired with ``ScgptConfig`` but does not yet build any
    network: ``self.model`` is left as ``None`` and ``forward`` returns
    ``None`` without looking at its input.
    """

    # Configuration class associated with this model; PreTrainedModel
    # consults it when loading/saving (e.g. ``from_pretrained``).
    config_class = ScgptConfig

    def __init__(self, config: ScgptConfig):
        """Initialise the model from ``config``.

        Args:
            config: The scGPT configuration. It is forwarded to
                ``PreTrainedModel.__init__`` and not otherwise read here.
        """
        super().__init__(config)
        # TODO: build the actual scGPT network from ``config``.
        # Placeholder until then; no submodule is created.
        self.model = None

    def forward(self, tensor) -> None:
        """Run the model on ``tensor``.

        Currently a stub: ``tensor`` is ignored and ``None`` is returned.
        """
        # TODO: delegate to ``self.model`` once the network is implemented.
        return None