"""Gradio demo: click an ASL alphabet image to compare the ViT prediction with its label.

The script loads an ASL alphabet image dataset and a fine-tuned ViT image
classifier, shows every image in a gallery, and, when an image is selected,
displays the dataset label next to the label predicted by the model.
"""

import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from datasets import load_dataset
import gradio as gr

# Named `eval_dataset` (not `eval`) so the builtin `eval` is not shadowed.
eval_dataset = load_dataset("Marxulia/asl_sign_languages_alphabets_v02", split="train")
eval_dataset = eval_dataset.rename_column('label', 'labels')

# Maps the stringified class index to the class name declared by the dataset.
id2label = {str(i): lab for i, lab in enumerate(eval_dataset.features["labels"].names)}

trained_model = ViTForImageClassification.from_pretrained("falba/google-vit-base-ASL")
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

with gr.Blocks() as demo:
    # The gallery is filled in dataset order, so a gallery position is also
    # the row index of the same example in `eval_dataset`.
    gallery = gr.Gallery(list(eval_dataset['image']))
    statement = gr.Label()

    def on_select(evt: gr.SelectData) -> str:
        """Classify the selected gallery image and report actual vs. predicted label.

        Args:
            evt: Selection event; ``evt.index`` is used as the row index of
                the chosen example in ``eval_dataset``.

        Returns:
            A string of the form
            ``"Actual Label: <name> | Predicted label: <name>"``.
        """
        chosen_index = evt.index

        # Fetch the single row instead of indexing into a whole column:
        # `eval_dataset['image'][i]` would load every image just to pick one.
        example = eval_dataset[chosen_index]
        chosen_image = example['image']
        actual_label = example['labels']

        inputs = processor(images=chosen_image, return_tensors="pt")

        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = trained_model(**inputs)

        predicted_label_id = outputs.logits.argmax(-1).item()
        predicted_label = id2label[str(predicted_label_id)]

        return f"Actual Label: {id2label[str(actual_label)]} | Predicted label: {predicted_label}"

    gallery.select(on_select, None, statement)

demo.launch()