# Author: youssefabdelmottaleb
# Commit 9b01bfc: "Add: SWIN-Transformer-Model-Deployment"
import torch
from transformers import SwinForImageClassification, AutoImageProcessor
from PIL import Image
import matplotlib.pyplot as plt
from transformers import SwinForImageClassification, AutoImageProcessor
class GarbageClassifier:
    """Classify garbage images with a fine-tuned Swin Transformer checkpoint.

    The predicted class index ``i`` (argmax over the model's logits) is
    reported as ``self.labels[i]``.
    """

    # Default class names; position i is the name reported for class index i.
    DEFAULT_LABELS = ('glass', 'metal', 'paper', 'plastic')

    def __init__(self, model_dir, num_labels=4, labels=None):
        """Set up class names, then load the model and image processor.

        Args:
            model_dir: Path or hub id passed to ``from_pretrained`` for both
                the Swin model and its image processor.
            num_labels: Number of outputs of the classification head.
            labels: Optional sequence of class names, one per class index.
                Defaults to ``DEFAULT_LABELS``.

        Raises:
            ValueError: If the number of class names differs from
                ``num_labels``; otherwise some predicted indices could not be
                named (or some names could never be predicted).
        """
        names = self.DEFAULT_LABELS if labels is None else labels
        self.labels = list(names)
        # Fail fast, before the comparatively slow model load.
        if len(self.labels) != num_labels:
            raise ValueError(
                f"Expected {num_labels} labels, got {len(self.labels)}: {self.labels}"
            )
        self.model, self.processor = self.load_model_and_processor(model_dir, num_labels)

    def load_model_and_processor(self, model_dir, num_labels):
        """Return ``(model, processor)`` loaded from ``model_dir``.

        Args:
            model_dir: Path or hub id of the saved checkpoint.
            num_labels: Number of outputs for the classification head.
        """
        model = SwinForImageClassification.from_pretrained(model_dir, num_labels=num_labels)
        # Make inference mode explicit so dropout-style layers are disabled.
        model.eval()
        processor = AutoImageProcessor.from_pretrained(model_dir)
        return model, processor

    def evaluate_image(self, image_path, show=True):
        """Predict the class of one image and optionally display it.

        Args:
            image_path: Anything ``PIL.Image.open`` accepts (path or file object).
            show: If true (the default), plot the image titled with the
                prediction and call ``plt.show()``.

        Returns:
            The predicted class name taken from ``self.labels``.
        """
        # Close the underlying file promptly. Convert to RGB so the processor
        # always receives 3-channel input, whatever the file's original mode.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits
        label = self._label_from_logits(logits)
        if show:
            self._show(image, label)
        return label

    def _label_from_logits(self, logits):
        """Map the logits for a single image to its class name via argmax."""
        # .item() requires exactly one element, i.e. a batch of one image.
        index = int(logits.argmax(-1).item())
        return self.labels[index]

    @staticmethod
    def _show(image, label):
        """Display ``image`` with the predicted ``label`` in the title."""
        plt.imshow(image)
        plt.title(f"Predicted Class: {label}")
        plt.axis('off')
        plt.show()