balakrish181's picture
first-commit
d957918
raw
history blame contribute delete
462 Bytes
import torch
from torch import nn
import torchvision
from torchvision import models
def model_classification(num_classes: int = 2):
    """Build a pretrained EfficientNet-B0 set up for transfer learning.

    The pretrained parameters are frozen. The final linear layer of the
    classifier is then replaced by a freshly initialised layer with
    ``num_classes`` outputs, which remains trainable.

    Args:
        num_classes: Number of output classes for the new head.
            Defaults to 2 (binary classification).

    Returns:
        A ``(model, transforms)`` tuple:
        - the adapted EfficientNet-B0 model;
        - the preprocessing transform associated with the loaded pretrained
          weights.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
    model = models.efficientnet_b0(weights=weights)
    # Take the transforms from the same weights object used to build the
    # model, so preprocessing always matches the loaded weights.
    transforms = weights.transforms()

    # Freeze the pretrained parameters BEFORE swapping the head. Freezing
    # afterwards would also freeze the new layer, leaving the model with no
    # trainable parameters at all.
    for param in model.parameters():
        param.requires_grad = False

    # classifier[1] is the final Linear layer. Read its input width instead
    # of hard-coding 1280. Newly created modules default to
    # requires_grad=True, so only this head will receive gradients.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)

    return model, transforms