import torch
from torch import nn
import torchvision
from torchvision import models


def model_classification(num_classes: int = 2):
    """Build an EfficientNet-B0 transfer-learning classifier and its preprocessing.

    The backbone is loaded with torchvision's ``EfficientNet_B0_Weights.DEFAULT``
    pretrained weights (torchvision may download them on first use) and all of
    its parameters are frozen. The final linear layer ``classifier[1]`` is then
    replaced by a fresh, trainable ``nn.Linear`` with ``num_classes`` outputs,
    so only the new head is updated during training.

    Args:
        num_classes: Number of outputs of the new classification head.
            Defaults to 2, matching the previous hard-coded binary head.

    Returns:
        ``(model, transforms)``: the modified model and the preprocessing
        transforms bundled with the pretrained weights, intended to be applied
        to input images before they are fed to the model.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
    model = models.efficientnet_b0(weights=weights)
    transforms = weights.transforms()

    # Freeze the pretrained parameters BEFORE swapping in the new head.
    # Freezing after the swap would also freeze the new Linear layer and
    # leave the model with no trainable parameters at all.
    for param in model.parameters():
        param.requires_grad = False

    # Size the new head from the layer it replaces rather than hard-coding
    # the feature width.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)

    return model, transforms