"""Gradio demo that classifies waste images with a fine-tuned ViT model."""

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Fine-tuned classification checkpoint, paired with the preprocessing
# configuration of the base ViT checkpoint.
MODEL_NAME = "jjuarez/Vit_waste_image_class"
FEATURE_EXTRACTOR_NAME = "google/vit-base-patch16-224"

# Manual mapping from the model's output index to a human-readable label.
# NOTE(review): assumes this ordering matches the label order used when the
# checkpoint was trained -- confirm against the model card.
INDEX_TO_LABEL = {
    0: "Aluminium",
    1: "Batteries",
    2: "Cardboard",
    3: "Glass",
    4: "Hard Plastic",
    5: "Paper",
    6: "Soft Plastics",
}

# Kept under the original module-level names so existing references to
# `model_name`, `model` and `feature_extractor` continue to work.
model_name = MODEL_NAME
model = ViTForImageClassification.from_pretrained(model_name)
feature_extractor = ViTFeatureExtractor.from_pretrained(FEATURE_EXTRACTOR_NAME)


def classify_image(image) -> str:
    """Return the predicted waste category for ``image``.

    Args:
        image: The image delivered by the Gradio ``Image`` input. Anything
            ``np.array`` can turn into a pixel array is accepted
            (presumably a NumPy array or a PIL image, depending on how the
            component is configured).

    Returns:
        The human-readable label of the highest-scoring class, or
        ``"Unknown Label"`` when the predicted index is not present in
        ``INDEX_TO_LABEL``.

    Raises:
        gr.Error: If no image was supplied (``image`` is ``None``).
    """
    if image is None:
        # Submitting without an upload passes None; report it in the UI
        # instead of failing somewhere inside the preprocessing step.
        raise gr.Error("Please upload an image before submitting.")

    # Normalise the input to an array for the feature extractor.
    pixels = np.array(image)

    # Preprocess the image into the tensors the model expects.
    inputs = feature_extractor(images=pixels, return_tensors="pt")

    # Inference only: gradients are not needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Index of the highest-scoring class along the last axis.
    predicted_class_idx = logits.argmax(-1).item()

    return INDEX_TO_LABEL.get(predicted_class_idx, "Unknown Label")


# Create Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),  # Accepts image of any size
    outputs=gr.Label(),
    title="Waste Classification with ViT",
    description="Upload an image of waste, and the model will classify it.",
)

# Launch the app
iface.launch()