Spaces:
Runtime error
import io

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel
CHECKPOINT = "g8a9/vit-geppetto-captioning"

# Streamlit re-executes this whole script on every widget interaction, so heavy
# objects must live in Streamlit's resource cache to be loaded only once.
# `st.cache_resource` exists in Streamlit >= 1.18; older installs only have the
# legacy `st.cache`, where `allow_output_mutation=True` skips hashing the
# returned (mutable) model objects on every access.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_resource
def get_model():
    """Load the ViT + GePpeTto captioning model from ``CHECKPOINT``.

    The result is cached by Streamlit, so the weights are loaded once instead
    of on every caption request.

    Returns:
        The ``VisionEncoderDecoderModel``, switched to evaluation mode.
    """
    model = VisionEncoderDecoderModel.from_pretrained(CHECKPOINT)
    model.eval()
    return model


@_cache_resource
def _load_preprocessors():
    """Load the image feature extractor and the tokenizer paired with ``CHECKPOINT``."""
    return (
        AutoFeatureExtractor.from_pretrained(CHECKPOINT),
        AutoTokenizer.from_pretrained(CHECKPOINT),
    )


# Module-level names kept for the rest of the script: the feature extractor
# turns the image into pixel values, the tokenizer decodes generated ids.
feature_extractor, tokenizer = _load_preprocessors()
# Page header. The trailing emoji is the Italian flag (regional indicators
# I + T), written as escapes so the literal survives any re-encoding of the file.
st.title("Image Captioning with ViT & GePpeTto \U0001F1EE\U0001F1F9")
st.sidebar.markdown("## Generation parameters")
max_length = st.sidebar.number_input("Max length", value=20, min_value=1)
no_repeat_ngram_size = st.sidebar.number_input(
    "no repeat ngrams size", value=2, min_value=1
)
num_return_sequences = st.sidebar.number_input(
    "Generated sequences", value=3, min_value=1
)
gen_mode = st.sidebar.selectbox("Generation mode", ["beam search", "sampling"])

# `gen_params` holds the extra keyword arguments for `model.generate` that are
# specific to the selected decoding mode.
if gen_mode == "beam search":
    # Beam search cannot return more sequences than it keeps beams, so the
    # beam size is never allowed below the requested number of sequences.
    num_beams = st.sidebar.number_input(
        "Beam size",
        value=max(5, num_return_sequences),
        min_value=num_return_sequences,
    )
    early_stopping = st.sidebar.checkbox("Early stopping", value=True)
    gen_params = {
        "num_beams": num_beams,
        "early_stopping": early_stopping,
    }
else:  # "sampling" -- the only other selectbox option
    do_sample = True
    top_k = st.sidebar.number_input("top_k", value=30, min_value=0)
    # Nucleus-sampling probability mass, a float in [0, 1]. 1.0 disables the
    # filter; 0 would keep only the most likely token, i.e. greedy decoding.
    # All numeric arguments are floats because Streamlit requires one type.
    top_p = st.sidebar.number_input(
        "top_p", value=1.0, min_value=0.0, max_value=1.0
    )
    # transformers' temperature warper requires a strictly positive value.
    temperature = st.sidebar.number_input(
        "temperature", value=0.7, min_value=0.01
    )
    gen_params = {
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
    }
def generate_caption(url):
    """Download the image at *url* and caption it with the cached model.

    Decoding uses the sidebar settings (``max_length``,
    ``no_repeat_ngram_size``, ``num_return_sequences``) plus the
    mode-specific ``gen_params``, instead of hard-coded values.

    Args:
        url: HTTP(S) address of the image to caption.

    Returns:
        All generated captions joined into one string, one caption per line.

    Raises:
        requests.RequestException: if the download fails, times out, or the
            server answers with an HTTP error status.
        PIL.UnidentifiedImageError: if the payload is not a readable image.
    """
    # Buffer the body instead of reading `response.raw`, which bypasses
    # requests' Content-Encoding decoding. The timeout keeps an unresponsive
    # host from hanging the app indefinitely.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content)).convert("RGB")

    inputs = feature_extractor(image, return_tensors="pt")
    model = get_model()
    model.eval()
    generated_ids = model.generate(
        inputs["pixel_values"],
        max_length=int(max_length),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        num_return_sequences=int(num_return_sequences),
        **gen_params,
    )
    captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # The user asked for `num_return_sequences` captions, so return all of them
    # rather than discarding everything but the first.
    return "\n".join(captions)
url = st.text_input(
    "Insert your URL", "https://iheartcats.com/wp-content/uploads/2015/08/c84.jpg"
).strip()

if not url:
    # An empty string is not a valid image source, so there is nothing to
    # preview or caption yet.
    st.info("Insert an image URL to get started.")
else:
    st.image(url)
    if st.button("Run captioning"):
        try:
            with st.spinner("Processing image..."):
                caption = generate_caption(url)
        except (requests.RequestException, OSError) as err:
            # Failed downloads and unreadable images are input problems:
            # report them in the page rather than as a traceback.
            st.error(f"Could not caption this image: {err}")
        else:
            st.text(caption)