import io
import os

import requests
import streamlit as st
from PIL import Image


# Designing the interface
st.title("🖼️ Image Captioning Demo 📝")
st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")

st.sidebar.markdown(
    """
    An image captioning model built by combining a ViT model with a GPT2 model.
    The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers'
    [Vision Encoder Decoder framework](https://huggingface.co/transformers/master/model_doc/visionencoderdecoder.html).
    The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
    The model was trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
    Follow-up work of the [Hugging Face JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
    """
)

with st.spinner('Loading and compiling ViT-GPT2 model ...'):
    from model import *

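# `model.py` is not shown in this file. From the names used below it is
# expected to provide `sample_dir`, `sample_image_ids`, `get_random_image_id`
# and `predict`. A minimal sketch of such a module (the model id, processor
# class, sample ids and generation settings here are assumptions, not
# necessarily what this app actually uses):
#
#     import random
#     from PIL import Image
#     from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
#
#     sample_dir = "samples"
#     sample_image_ids = ["None", 6771, 19432, 41888]  # hypothetical COCO val2017 ids
#
#     def get_random_image_id():
#         # A fixed pool of known-good val2017 ids keeps the sketch simple.
#         return random.choice([i for i in sample_image_ids if i != "None"])
#
#     model = VisionEncoderDecoderModel.from_pretrained("ydshieh/vit-gpt2-coco-en")
#     processor = ViTImageProcessor.from_pretrained("ydshieh/vit-gpt2-coco-en")
#     tokenizer = AutoTokenizer.from_pretrained("ydshieh/vit-gpt2-coco-en")
#
#     def predict(image: Image.Image) -> str:
#         pixel_values = processor(images=image.convert("RGB"), return_tensors="pt").pixel_values
#         output_ids = model.generate(pixel_values, max_length=16, num_beams=4)
#         return tokenizer.decode(output_ids[0], skip_special_tokens=True)
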
random_image_id = get_random_image_id()

st.sidebar.title("Select a sample image")
sample_image_id = st.sidebar.selectbox(
    "Please choose a sample image",
    sample_image_ids
)

if st.sidebar.button("Random COCO 2017 (val) image"):
    random_image_id = get_random_image_id()
    sample_image_id = "None"

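# File-uploader form: `clear_on_submit=True` resets the widget after each
# submit, so the raw bytes are copied out into `bytes_data` first.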
bytes_data = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Upload")
    if submitted and uploaded_file is not None:
        bytes_data = io.BytesIO(uploaded_file.getvalue())

if (bytes_data is None) and submitted:
    st.write("No file was selected to upload.")
else:
    image_id = random_image_id
    if sample_image_id != "None":
        assert isinstance(sample_image_id, int)
        image_id = sample_image_id

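    # COCO val2017 file names embed the image id, zero-padded to 12 digits.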
    sample_name = f"COCO_val2017_{str(image_id).zfill(12)}.jpg"
    sample_path = os.path.join(sample_dir, sample_name)

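    # Resolve the image source: an uploaded file wins, then a bundled sample,
    # then a download from the COCO image server.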
    if bytes_data is not None:
        image = Image.open(bytes_data)
    elif os.path.isfile(sample_path):
        image = Image.open(sample_path)
    else:
        url = f"http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg"
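        # `stream=True` exposes the raw response as a file-like object that
        # PIL can decode directly, without saving the file to disk first.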
        image = Image.open(requests.get(url, stream=True).raw)

    # Make a display copy capped at 384px tall and 512px wide, preserving the
    # aspect ratio; the untouched `image` is kept for the model.
    resized = image.copy()
    width, height = resized.size
    if height > 384:
        width = int(width / height * 384)
        height = 384
        resized = resized.resize(size=(width, height))
    if width > 512:
        # Compute the new height from the current width before overwriting it.
        height = int(height / width * 512)
        width = 512
        resized = resized.resize(size=(width, height))

    if bytes_data is None:
        st.markdown(f"[{str(image_id).zfill(12)}.jpg](http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg)")
    st.image(resized, caption="Selected Image")
    resized.close()

    # For newline
    st.sidebar.write('\n')

    with st.spinner('Generating image caption ...'):
        caption = predict(image)
        st.header('Predicted caption:')
        st.subheader(caption)

    st.sidebar.header("ViT-GPT2 predicts:")
    st.sidebar.write(caption)

    image.close()