File size: 2,804 Bytes
8ff0261
5dce03a
7326e2c
bf20f73
 
 
f1d50b1
 
 
 
 
e4b9c8b
f1d50b1
e4b9c8b
2cf3514
 
a811816
 
 
fedeff8
2cf3514
e4b9c8b
df702af
14261e1
 
df702af
 
7326e2c
83d94a8
 
 
bf9c2d9
a811816
83d94a8
e4b9c8b
83d94a8
 
 
 
 
a811816
83d94a8
 
7326e2c
 
e4b9c8b
83d94a8
 
 
bf9c2d9
 
 
83d94a8
 
bf9c2d9
83d94a8
 
 
 
 
 
 
 
 
a811816
 
 
 
 
 
 
83d94a8
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import io

import jax
import jax.numpy as jnp
import pandas as pd
import requests
import streamlit as st
from PIL import Image

from utils import load_model


def app(model_name):
    """Streamlit page for zero-shot image classification with KoCLIP.

    Loads the ``koclip/{model_name}`` checkpoint, lets the user supply an
    image (URL or file upload) and up to five candidate Korean captions,
    then charts the image-caption similarity scores, optionally
    softmax-normalized into probabilities.

    Args:
        model_name: Checkpoint suffix under the ``koclip/`` namespace,
            passed to ``utils.load_model``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Zero-shot Image Classification")
    st.markdown(
        """
        This demo explores KoCLIP's zero-shot prediction capabilities. The model takes an image and a list of candidate captions from the user and predicts the most likely caption that best describes the given image. 

        ---
        """
    )

    query1 = st.text_input(
        "Enter a URL to an image...",
        value="http://images.cocodataset.org/val2017/000000039769.jpg",
    )
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])

    # NOTE(review): st.beta_columns is the pre-1.0 Streamlit API (replaced by
    # st.columns and removed in later releases). Kept as-is to match the
    # Streamlit version this app was written against — confirm before upgrading.
    col1, col2 = st.beta_columns([3, 1])

    with col2:
        captions_count = st.selectbox("Number of labels", options=range(1, 6), index=2)
        normalize = st.checkbox("Apply Softmax")
        compute = st.button("Classify")

    with col1:
        captions = []
        defaults = ["κ·€μ—¬μš΄ 고양이", "λ©‹μžˆλŠ” 강아지", "ν¬λ™ν¬λ™ν•œ ν–„μŠ€ν„°"]
        for idx in range(captions_count):
            value = defaults[idx] if idx < len(defaults) else ""
            captions.append(st.text_input(f"Insert caption {idx+1}", value=value))

    if compute:
        if not any([query1, query2]):
            st.error("Please upload an image or paste an image URL.")
        else:
            st.markdown("""---""")
            with st.spinner("Computing..."):
                # Uploaded file takes precedence over the URL field.
                if query2 is not None:
                    image_data = query2
                else:
                    # Download the full payload and wrap it in BytesIO rather
                    # than handing PIL the raw urllib3 stream: ``.raw`` does
                    # not decode Content-Encoding (e.g. gzip), which makes
                    # Image.open fail on compressed responses. Also surface
                    # HTTP errors (404, 500, ...) explicitly instead of
                    # letting PIL choke on an error page.
                    response = requests.get(query1)
                    response.raise_for_status()
                    image_data = io.BytesIO(response.content)
                image = Image.open(image_data)

                # CLIP-style prompt template: wrap each bare label in a
                # natural Korean sentence ("This is a {caption}.").
                captions = [f"이것은 {caption.strip()}이닀." for caption in captions]
                inputs = processor(
                    text=captions, images=image, return_tensors="jax", padding=True
                )
                # The processor emits NCHW pixel values; transpose to NHWC
                # for the Flax vision tower.
                inputs["pixel_values"] = jnp.transpose(
                    inputs["pixel_values"], axes=[0, 2, 3, 1]
                )
                outputs = model(**inputs)
                if normalize:
                    name = "normalized prob"
                    probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
                else:
                    name = "cosine sim"
                    probs = outputs.logits_per_image
                # Single image in the batch -> take row 0 of the logits.
                chart_data = pd.Series(probs[0], index=captions, name=name)

                col1, col2 = st.beta_columns(2)
                with col1:
                    st.image(image)
                with col2:
                    st.bar_chart(chart_data)