File size: 1,251 Bytes
8bf4df8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

import torch
import torchvision
from torch import nn
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url

def create_effnetb2_model(num_classes: int = 3, seed: int = 42):
  """Create an EfficientNet-B2 feature extractor and its matching transforms.

  The pretrained ``features`` backbone is frozen and the classifier head is
  replaced by a fresh ``Dropout`` + ``Linear`` pair sized for ``num_classes``.

  Side effects:
    * Replaces ``WeightsEnum.get_state_dict`` process-wide (see below).
    * Calls ``torch.manual_seed(seed)``, which reseeds the global torch RNG.

  Args:
    num_classes: Number of output units of the new classifier head.
    seed: Seed applied just before the new head is initialised, so its
      initial weights are reproducible.

  Returns:
    A ``(model, transforms)`` tuple: the modified EfficientNet-B2 model and
    the preprocessing transforms bundled with the pretrained weights.
  """
  # https://pytorch.org/vision/main/models/generated/torchvision.models.efficientnet_b2.html#torchvision.models.efficientnet_b2

  def get_state_dict(self, *args, **kwargs):
    # Drop ``check_hash`` before delegating to torch.hub.
    # NOTE(review): presumably a workaround for a checkpoint hash-check
    # failure when downloading the weights -- confirm it is still needed.
    # A default is supplied so callers that do not pass ``check_hash`` no
    # longer trigger a KeyError.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)

  # Monkey-patch applied on the base enum class, so it affects every
  # torchvision weights enum in this process, not only EfficientNet-B2.
  WeightsEnum.get_state_dict = get_state_dict

  # 1. Setup pretrained EffNetB2 weights
  effnetb2_weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT  # DEFAULT = BEST

  # 2. Get EffNetB2 transforms
  effnetb2_transforms = effnetb2_weights.transforms()

  # 3. Setup pretrained model instance
  effnetb2_model = torchvision.models.efficientnet_b2(weights=effnetb2_weights)

  # 4. Freeze the base layers in the model
  for param in effnetb2_model.features.parameters():
    param.requires_grad = False

  # 5. Modify the classifier
  # Read the head's input width from the stock classifier's final Linear
  # layer instead of hard-coding it (1408 for EfficientNet-B2).
  head_in_features = effnetb2_model.classifier[-1].in_features

  # Seed right before constructing the new head so its init is reproducible.
  torch.manual_seed(seed)
  effnetb2_model.classifier = nn.Sequential(
    nn.Dropout(p=0.3, inplace=True),
    nn.Linear(in_features=head_in_features, out_features=num_classes, bias=True),
  )

  return effnetb2_model, effnetb2_transforms