# Image Classification Model for Medical Waste Classification
This is an image classification model trained to sort images of medical waste into four categories: cytotoxic, infectious, pathological, and pharmaceutical. It is based on the Inception v3 architecture and has been adapted to a medical waste dataset for this task.
## Model Description
The model keeps the Inception v3 architecture but modifies its fully connected layers to fit this four-class task. The network consists of the Inception v3 convolutional feature extractor, followed by a global average pooling layer and fully connected layers with ReLU activations and dropout, ending in one output per waste category.
## Usage
You can load the model, which I saved in PyTorch `.pt` format, and classify an image as follows:
```python
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Run on the GPU if one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def predict_image(image_path, model, transform, class_names):
    # Load the image; convert to RGB so grayscale or RGBA files also have 3 channels
    image = Image.open(image_path).convert("RGB")
    # Apply transformations and add a batch dimension
    image = transform(image).unsqueeze(0).to(device)
    # Set the model to evaluation mode (disables dropout)
    model.eval()
    # Make predictions without tracking gradients
    with torch.no_grad():
        outputs = model(image)
    # Convert logits to probabilities and pick the most likely class
    probabilities = torch.softmax(outputs, dim=1)[0]
    confidence, predicted = torch.max(probabilities, dim=0)
    predicted_class = predicted.item()
    predicted_label = class_names[predicted_class]
    return predicted_class, predicted_label, confidence.item()


# Define the transformation to be applied to the input image
image_transform = transforms.Compose([
    transforms.Resize((299, 299)),  # Resize to match the InceptionV3 input size
    transforms.ToTensor(),
    # Add further transformations (e.g. normalization) only if they were used in training
])

# Load the trained model. The .pt file stores the whole model object, so
# weights_only=False is needed on PyTorch >= 2.6; only load files you trust.
model = torch.load('__path to the saved model file__', map_location=device, weights_only=False)
model.to(device)

# Class names, in the same order as the model's output indices
class_names = ['cytotoxic', 'infectious', 'pathological', 'pharmaceutical']

# Path to the image you want to classify
image_path = '__path to the image you want to classify__'

# True label of the image, if known (used only in the plot title)
true_label = 'pathological'

# Predict the class label
predicted_class, predicted_label, confidence = predict_image(image_path, model, image_transform, class_names)

# Display the image with the true and predicted labels
image = Image.open(image_path).convert("RGB")
plt.imshow(np.array(image))
plt.axis('off')
plt.title(f'True Class: {true_label}\nPredicted Class: {predicted_label} (Confidence: {confidence*100:.2f}%)')
plt.show()
```