"""Streamlit demo: caption an image from a URL with ViT (encoder) + GePpeTto (decoder).

Generation parameters are chosen in the sidebar; captioning runs when the
"Run captioning" button is pressed and every generated caption is shown.
"""

import io

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel

CHECKPOINT = "g8a9/vit-geppetto-captioning"

# Seconds to wait for the image server before giving up (requests has no
# default timeout, so a stalled server would otherwise hang the app).
REQUEST_TIMEOUT = 10


# allow_output_mutation=True stops st.cache from hashing the returned model on
# every rerun to detect mutation, which is costly for a large model.
@st.cache(allow_output_mutation=True)
def get_model():
    """Load the captioning model once (cached across Streamlit reruns).

    Returns:
        The VisionEncoderDecoderModel for CHECKPOINT, in evaluation mode.
    """
    model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
    # Inference only. The original called eval() after `return` / outside the
    # scope where `model` existed, so it never took effect.
    model.eval()
    return model


@st.cache(allow_output_mutation=True)
def _get_preprocessors():
    """Load the image feature extractor and caption tokenizer once (cached across reruns)."""
    return (
        AutoFeatureExtractor.from_pretrained(CHECKPOINT),
        AutoTokenizer.from_pretrained(CHECKPOINT),
    )


# Module-level names kept so existing references keep working.
feature_extractor, tokenizer = _get_preprocessors()

st.title("Image Captioning with ViT & GePpeTto 🇮🇹")

st.sidebar.markdown("## Generation parameters")
max_length = st.sidebar.number_input("Max length", value=20, min_value=1)
no_repeat_ngram_size = st.sidebar.number_input("no repeat ngrams size", value=2, min_value=1)
num_return_sequences = st.sidebar.number_input("Generated sequences", value=3, min_value=1)
gen_mode = st.sidebar.selectbox("Generation mode", ["beam search", "sampling"])

if gen_mode == "beam search":
    num_beams = st.sidebar.number_input("Beam size", value=5, min_value=1)
    early_stopping = st.sidebar.checkbox("Early stopping", value=True)
    # Beam search can return at most `num_beams` sequences; transformers
    # rejects num_return_sequences > num_beams, so clamp and tell the user.
    if num_return_sequences > num_beams:
        st.sidebar.warning(
            f"Beam search returns at most {num_beams} sequences; "
            f"showing {num_beams} instead of {num_return_sequences}."
        )
        num_return_sequences = num_beams
    gen_params = {"num_beams": num_beams, "early_stopping": early_stopping}
elif gen_mode == "sampling":
    do_sample = True
    top_k = st.sidebar.number_input("top_k", value=30, min_value=0)
    # Float widget in [0, 1]; 1.0 disables nucleus filtering. The previous int
    # widget defaulting to 0 could not express fractional values and 0 limits
    # sampling to (essentially) the single most likely token.
    top_p = st.sidebar.number_input("top_p", value=1.0, min_value=0.0, max_value=1.0)
    # transformers requires a strictly positive temperature.
    temperature = st.sidebar.number_input("temperature", value=0.7, min_value=0.01)
    gen_params = {
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
    }


def generate_caption(url, return_all=False):
    """Download the image at *url* and caption it using the sidebar settings.

    Args:
        url: HTTP(S) URL of the image to caption.
        return_all: if True, return every generated caption instead of only
            the first one (the default, matching the original behaviour).

    Returns:
        The first caption as a str, or the list of all captions when
        ``return_all`` is True.

    Raises:
        requests.RequestException: the download failed, timed out, or the
            server answered with an HTTP error status.
        OSError: the downloaded bytes are not a readable image
            (PIL raises UnidentifiedImageError, an OSError subclass).
    """
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    response.raise_for_status()
    # Decode from the fully-read body: `response.raw` bypasses requests'
    # gzip/deflate content-encoding handling.
    image = Image.open(io.BytesIO(response.content)).convert("RGB")

    inputs = feature_extractor(image, return_tensors="pt")
    model = get_model()
    # Pure inference: make sure no autograd graph is built.
    with torch.no_grad():
        generated_ids = model.generate(
            inputs["pixel_values"],
            # Use the sidebar values instead of the previously hard-coded ones.
            max_length=max_length,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
            **gen_params,
        )
    captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return captions if return_all else captions[0]


url = st.text_input(
    "Insert your URL", "https://iheartcats.com/wp-content/uploads/2015/08/c84.jpg"
)

# st.image cannot render an empty URL.
if url:
    st.image(url)

if st.button("Run captioning"):
    if not url:
        st.error("Please insert an image URL.")
    else:
        captions = []
        with st.spinner("Processing image..."):
            try:
                captions = generate_caption(url, return_all=True)
            except (requests.RequestException, OSError) as err:
                st.error(f"Could not load the image: {err}")
        # Show every generated sequence the user asked for, not just the first.
        for caption in captions:
            st.text(caption)