|
import torch |
|
from torchvision import transforms as tt |
|
from PIL import Image |
|
import cv2 |
|
|
|
def predict_potato(image_path, model):
    """Classify a potato-leaf image with ``model`` and report the top class.

    The image is loaded from ``image_path``, converted to RGB, resized to
    224x224, turned into a tensor and passed through ``model`` as a batch
    of one. A softmax over the first row of the output gives the class
    probabilities, and the highest-scoring class is reported.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        model: Model exposing ``eval()`` and callable on an image batch.
            NOTE(review): assumed to return logits of shape
            (1, num_classes) whose index order matches ``class_labels``
            below -- confirm against the training code.

    Returns:
        A 3-tuple ``(label, probability, image)``: the predicted class
        label, its softmax probability as a Python float, and the RGB
        ``PIL.Image`` that was classified (before resizing).

    Raises:
        IndexError: If the predicted index has no entry in the label list
            (i.e. the model emits more classes than are listed here).

    Side effects:
        ``model`` is switched to evaluation mode and left in that mode.
    """
    preprocess = tt.Compose([
        tt.Resize((224, 224)),
        tt.ToTensor(),
    ])

    # Close the underlying file deterministically; ``convert`` returns an
    # independent in-memory image, so it stays valid after the handle closes.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Add the leading batch dimension expected by the model.
    image_tensor = preprocess(image).unsqueeze(0)

    # Run the input on the same device as the model's weights; a CPU tensor
    # fed to a model on another device raises a device-mismatch error.
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first_param = next(iter(parameters()), None)
        if first_param is not None:
            image_tensor = image_tensor.to(first_param.device)

    model.eval()

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(image_tensor)

    # ``output[0]`` is the single sample in the batch; normalise its scores.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    predicted_class = torch.argmax(probabilities).item()
    predicted_probability = probabilities[predicted_class].item()

    class_labels = ['Potato Early Blight', 'Potato Late Blight', 'Potato Healthy']

    return class_labels[predicted_class], predicted_probability, image