File size: 1,985 Bytes
8d9306e
 
 
 
c951094
374fa3e
8d9306e
0bb133b
f884ea7
c951094
0514a3d
c951094
 
 
f884ea7
 
 
fd6cb9f
 
 
8d9306e
c951094
d28411b
8d9306e
 
0bb133b
8d9306e
f705683
8d9306e
f705683
cc62e3c
9a6a97f
f705683
8d9306e
3568832
9a6a97f
8d9306e
f705683
c0cae7b
 
8d9306e
 
 
 
c951094
686f21e
8d9306e
d28411b
c951094
 
de90f6d
e53a130
c951094
 
 
 
8d9306e
686f21e
c951094
 
e53a130
c951094
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import streamlit as st


# ---- Page header ------------------------------------------------------------
st.title("🖼️ Image Captioning Demo 📝")
st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")

st.sidebar.markdown(
    """
    An image captioning model [ViT-GPT2](https://huggingface.co/flax-community/vit-gpt2) by combining the ViT model with the GPT2 model.
    [Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
    The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' `FlaxVisionEncoderDecoderModel`.
    The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
    The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
    """
)


# ---- Model loading ----------------------------------------------------------
# NOTE(review): the wildcard import is deliberate — `model` does the slow
# load/compile work at import time, and it also supplies `os`, `Image`,
# `predict`, `sample_dir`, and `sample_fns` used below. Importing inside the
# spinner keeps the "loading" feedback on screen for the whole import.
with st.spinner('Loading and compiling ViT-GPT2 model ...'):

    from model import *

# ---- Sample selection -------------------------------------------------------
st.sidebar.title("Select a sample image")

sample_name = st.sidebar.selectbox(
    "Please choose an image",
    sample_fns
)

# Rebuild the canonical COCO val2014 filename from the bare id: the numeric
# part is zero-padded to 12 digits, e.g. "39769.jpg" -> "COCO_val2014_000000039769.jpg".
sample_name = f"COCO_val2014_{sample_name.replace('.jpg', '').zfill(12)}.jpg"
sample_path = os.path.join(sample_dir, sample_name)

# Open the image via a context manager so the file handle is released even if
# prediction or rendering raises (the original only closed it on the happy path).
with Image.open(sample_path) as image:
    # Render a placeholder first, then replace it with a captioned version.
    show = st.image(image, width=480)
    show.image(image, '\n\nSelected Image', width=480)

    # For newline
    st.sidebar.write('\n')

    # ---- Caption generation -------------------------------------------------
    with st.spinner('Generating image caption ...'):

        caption = predict(image)
        st.header(f'**Prediction (in English)**: {caption}')

    st.sidebar.header("ViT-GPT2 predicts:")
    st.sidebar.write(f"**English**: {caption}")