# Spaces: Build error
import streamlit as st
import torch
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    VisionEncoderDecoderModel,
)
# Load the pretrained captioning checkpoint together with the two helpers it
# needs: a feature extractor that turns an image into pixel tensors, and a
# tokenizer that turns generated token ids back into text.
model_name = "nlpconnect/vit-gpt2-image-captioning"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# An image-captioning checkpoint is an encoder-decoder (image in, text out);
# a sequence-classification head cannot produce a caption.
model = VisionEncoderDecoderModel.from_pretrained(model_name)


def generate_caption(image, max_length=16, num_beams=4):
    """Generate a text caption for a PIL image.

    Args:
        image: A ``PIL.Image.Image`` to describe.
        max_length: Upper bound on the number of generated tokens.
        num_beams: Beam width used by ``model.generate``.

    Returns:
        The decoded caption with special tokens removed and surrounding
        whitespace stripped.
    """
    # Uploaded PNGs are often RGBA or palette-based; normalise to three
    # channels before preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Images go through the feature extractor, not the text tokenizer.
    pixel_values = feature_extractor(
        images=image, return_tensors="pt"
    ).pixel_values

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values, max_length=max_length, num_beams=num_beams
        )

    # A single image was passed in, so the batch has exactly one sequence.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
def main():
    """Render the Streamlit page: upload an image, then show its caption.

    The uploader and the submit button live in one form. When the form is
    submitted with an image, the caption returned by ``generate_caption`` is
    written to the page; when it is submitted without one, a warning is shown
    instead of silently doing nothing.
    """
    st.title("Image to Text Captioning")

    # Explicit sentinel instead of probing ``locals()`` for the name later.
    image = None

    with st.form("my_form"):
        uploaded_file = st.file_uploader(
            "Choose an image file", type=["jpg", "jpeg", "png"]
        )

        if uploaded_file is not None:
            # Display the uploaded image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)

        clicked = st.form_submit_button("Generate Caption")

        if clicked:
            if image is None:
                st.warning("Please upload an image first.")
            else:
                caption = generate_caption(image)
                st.subheader("Generated Caption:")
                st.write(caption)


if __name__ == "__main__":
    main()