Spaces:
Runtime error
Runtime error
import torch | |
from transformers import SwinForImageClassification, AutoImageProcessor | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from transformers import SwinForImageClassification, AutoImageProcessor | |
class GarbageClassifier:
    """Classify garbage photos with a Swin image-classification checkpoint.

    The model and its matching image processor are both loaded from
    ``model_dir``. Predicted class ids are mapped to names through
    ``self.labels``.
    """

    # Class names in the order of the classifier head's output indices.
    DEFAULT_LABELS = ('glass', 'metal', 'paper', 'plastic')

    def __init__(self, model_dir, num_labels=4, labels=None):
        """Load the classifier and its image processor.

        Args:
            model_dir: Path or identifier passed to ``from_pretrained`` for
                both the model and the processor.
            num_labels: Number of outputs of the classification head.
            labels: Optional sequence of class names, one per head output,
                in index order. Defaults to ``DEFAULT_LABELS``.

        Raises:
            ValueError: If the number of labels differs from ``num_labels``.
                A mismatch would otherwise silently mislabel predictions or
                raise IndexError only at prediction time.
        """
        if labels is None:
            labels = self.DEFAULT_LABELS
        self.labels = list(labels)
        # Validate before the (potentially slow) model load so that a
        # misconfiguration fails fast.
        if len(self.labels) != num_labels:
            raise ValueError(
                f"got {len(self.labels)} labels for a "
                f"{num_labels}-way classifier head"
            )
        self.model, self.processor = self.load_model_and_processor(model_dir, num_labels)

    def load_model_and_processor(self, model_dir, num_labels):
        """Return ``(model, processor)`` loaded from ``model_dir``.

        ``num_labels`` is forwarded to
        ``SwinForImageClassification.from_pretrained``.
        NOTE(review): if the checkpoint's head has a different size, loading
        may require ``ignore_mismatched_sizes=True`` — confirm against the
        checkpoint in use.
        """
        model = SwinForImageClassification.from_pretrained(model_dir, num_labels=num_labels)
        processor = AutoImageProcessor.from_pretrained(model_dir)
        return model, processor

    def predict(self, image):
        """Return the predicted label name for an already-loaded image.

        ``image`` is passed unchanged to the image processor, so it must be
        something the processor accepts (e.g. an RGB PIL image).
        """
        inputs = self.processor(images=image, return_tensors="pt")
        # Inference only: disable autograd bookkeeping.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # A single image is processed, so the batch has one row and argmax
        # over the class axis yields exactly one id for .item().
        predicted_class = logits.argmax(-1).item()
        return self.labels[predicted_class]

    def evaluate_image(self, image_path, *, show=True):
        """Classify the image at ``image_path`` and optionally display it.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.
            show: When true (the default, matching the previous behavior),
                plot the image titled with the prediction and call
                ``plt.show()``.

        Returns:
            The predicted label name.
        """
        # Image.open keeps the file handle open until the image is closed.
        # Convert inside the context manager: convert() loads the pixel data
        # into a new image and releases the handle deterministically. RGB
        # conversion also normalizes RGBA, grayscale and palette files to
        # the 3-channel input the processor expects.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        label = self.predict(image)
        if show:
            plt.imshow(image)
            plt.title(f"Predicted Class: {label}")
            plt.axis('off')
            plt.show()
        return label