# scgpt / modeling_scgpt.py
# Author: agemagician
# Commit e58b0a2 (verified): "Create modeling_scgpt.py"
# File size: 935 bytes
from transformers import PreTrainedModel
#from timm.models.resnet import BasicBlock, Bottleneck, ResNet
from .configuration_scgpt import ScgptConfig
#BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}
class ScgptModel(PreTrainedModel):
    """Placeholder ``PreTrainedModel`` subclass for scGPT.

    ``config_class`` ties this model class to ``ScgptConfig``. No backbone
    network is constructed yet: ``self.model`` is left as ``None`` and
    :meth:`forward` returns ``None``.

    NOTE(review): this file previously carried a commented-out timm ResNet
    template (apparently copied from a custom-model example). It was removed
    for two reasons. If re-enabled, it would have instantiated ``ScgptModel``
    inside ``ScgptModel.__init__`` (infinite recursion). It would also have
    read ResNet-specific config fields (``block_type``, ``layers``,
    ``cardinality``, ``stem_type``, ...).
    TODO: implement and attach the real scGPT backbone.
    """

    config_class = ScgptConfig

    def __init__(self, config: ScgptConfig) -> None:
        """Initialise the model from *config*.

        Args:
            config: Model configuration, forwarded unchanged to
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config)
        # No backbone is built yet. The attribute is still defined so that
        # code inspecting ``self.model`` keeps seeing ``None`` as before.
        self.model = None

    def forward(self, tensor) -> None:
        """Stub forward pass.

        Args:
            tensor: Model input. It is currently ignored.

        Returns:
            Always ``None`` until a backbone is implemented.
        """
        # TODO: delegate to the backbone once ``self.model`` is populated.
        return None