"""Gradio demo that captions an uploaded image with a ViT + GPT-2 model.

At import time the script loads the feature extractor, tokenizer and
vision-encoder-decoder checkpoint, then builds and launches a Gradio
interface whose single callback is :func:`predict`.
"""

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    VisionEncoderDecoderModel,
    ViTFeatureExtractor,
)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

encoder_checkpoint = "google/vit-base-patch16-224-in21k"
decoder_checkpoint = "gpt2"
model_checkpoint = "gagan3012/ViTGPT2I2A"

# Pre-processing and tokenization come from the base encoder/decoder
# checkpoints; the captioning weights come from `model_checkpoint`.
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)


def _clean_caption(text):
    """Return *text* without ``<|endoftext|>`` markers, cut at the first newline."""
    return text.replace("<|endoftext|>", "").split("\n", 1)[0]


def predict(image):
    """Generate a caption for *image*.

    Args:
        image: The image handed over by the Gradio input component
            (configured below with ``type="pil"``).

    Returns:
        The decoded caption string: the first line of the generated text
        with ``<|endoftext|>`` markers removed.
    """
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    # A single image goes in, so only the first generated sequence is used.
    caption_ids = model.generate(pixel_values, max_length=50)[0]
    return _clean_caption(tokenizer.decode(caption_ids))


# NOTE(review): `gr.inputs` / `gr.outputs`, `theme="huggingface"` and
# `enable_queue` look like legacy Gradio APIs, and `ViTFeatureExtractor` a
# legacy Transformers name - confirm against the pinned library versions
# before upgrading either package.
inputs = [
    gr.inputs.Image(type="pil", label="Original Image"),
]

outputs = [
    gr.outputs.Textbox(label="Caption"),
]

title = "Image Captioning using ViT + GPT2"
description = "ViT and GPT2 are used to generate Image Caption for the uploaded images"
article = " Model Repo on Hugging Face Model Hub"

# Example images are referenced by relative path, so they are presumably
# expected next to this script - verify the files exist under these names.
examples = [
    ["duck.jpg"],
    ["dice.jpg"],
    ["banana.jpg"],
    ["avacado.jpg"],
]

gr.Interface(
    predict,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
).launch(debug=True, enable_queue=True)