File size: 935 Bytes
e58b0a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import PreTrainedModel
#from timm.models.resnet import BasicBlock, Bottleneck, ResNet
from .configuration_scgpt import ScgptConfig


#BLOCK_MAPPING = {"basic": BasicBlock, "bottleneck": Bottleneck}


class ScgptModel(PreTrainedModel):
    """Placeholder model class bound to ``ScgptConfig``.

    No network is built yet: ``self.model`` is set to ``None`` and
    ``forward`` returns ``None`` for every input.
    """

    config_class = ScgptConfig

    def __init__(self, config):
        """Initialise the parent class from *config* without building a backbone.

        Args:
            config: Configuration object handed unchanged to the parent
                initialiser (presumably a ``ScgptConfig`` -- TODO confirm
                against callers).
        """
        super().__init__(config)

        # TODO: construct the real backbone and store it on ``self.model``.
        # NOTE(review): the disabled draft that used to sit here looked up a
        # block class via ``BLOCK_MAPPING[config.block_type]`` and passed it,
        # together with ``config.layers``, ``num_classes``, ``input_channels``
        # (as ``in_chans``), ``cardinality``, ``base_width``, ``stem_width``,
        # ``stem_type`` and ``avg_down``, to ``ScgptModel(...)`` itself.
        # Enabling it unchanged would re-enter this ``__init__``; the intended
        # backbone class still has to be chosen.
        self.model = None

    def forward(self, tensor):
        """Accept *tensor* and return ``None``; no computation is performed.

        Args:
            tensor: Model input. It is currently ignored.

        Returns:
            Always ``None``.
        """
        # TODO: delegate to ``self.model`` once it exists (the disabled draft
        # called ``self.model.forward_features(tensor)``).
        return None