import torch
from transformers import SwinForImageClassification, AutoImageProcessor
from PIL import Image
import matplotlib.pyplot as plt


class GarbageClassifier:
    """Classify an image of garbage with a Swin image-classification model.

    The class names are fixed to ``['glass', 'metal', 'paper', 'plastic']``
    and are indexed by the model's arg-max output, so ``num_labels`` should
    match that list's length (4). A larger ``num_labels`` can make
    ``evaluate_image`` raise ``IndexError`` when the predicted index has no
    corresponding name.
    """

    def __init__(self, model_dir, num_labels=4):
        """Load the model and image processor stored under ``model_dir``.

        Args:
            model_dir: Location passed to ``from_pretrained`` for both the
                model and the image processor.
            num_labels: Number of output classes requested from the model.
        """
        self.labels = ['glass', 'metal', 'paper', 'plastic']
        self.model, self.processor = self.load_model_and_processor(
            model_dir, num_labels
        )

    def load_model_and_processor(self, model_dir, num_labels):
        """Return ``(model, processor)`` loaded from ``model_dir``.

        Args:
            model_dir: Location passed to ``from_pretrained``.
            num_labels: Forwarded to
                ``SwinForImageClassification.from_pretrained``.

        Returns:
            A tuple of the Swin classification model and its image processor.
        """
        model = SwinForImageClassification.from_pretrained(
            model_dir, num_labels=num_labels
        )
        processor = AutoImageProcessor.from_pretrained(model_dir)
        return model, processor

    def evaluate_image(self, image_path, *, show=True):
        """Predict the class of one image and optionally display it.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.
            show: When true (the default), draw the image with the predicted
                class as its title and call ``plt.show()``. Pass ``False``
                to get only the returned label.

        Returns:
            The predicted class name taken from ``self.labels``.

        Raises:
            IndexError: If the predicted index is outside ``self.labels``.
        """
        # Use a context manager so the file handle is released promptly, and
        # convert to RGB so grayscale, palette and RGBA files reach the
        # processor as three-channel images. ``convert`` returns a fully
        # loaded copy that stays usable after the file is closed.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")

        inputs = self.processor(images=image, return_tensors="pt")

        # Inference only: skip gradient tracking.
        with torch.no_grad():
            outputs = self.model(**inputs)

        logits = outputs.logits
        predicted_class = logits.argmax(-1).item()
        predicted_label = self.labels[predicted_class]

        if show:
            plt.imshow(image)
            plt.title(f"Predicted Class: {predicted_label}")
            plt.axis('off')
            plt.show()

        return predicted_label