Build error
Build error
File size: 2,804 Bytes
8ff0261 5dce03a 7326e2c bf20f73 f1d50b1 e4b9c8b f1d50b1 e4b9c8b 2cf3514 a811816 fedeff8 2cf3514 e4b9c8b df702af 14261e1 df702af 7326e2c 83d94a8 bf9c2d9 a811816 83d94a8 e4b9c8b 83d94a8 a811816 83d94a8 7326e2c e4b9c8b 83d94a8 bf9c2d9 83d94a8 bf9c2d9 83d94a8 a811816 83d94a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import jax
import jax.numpy as jnp
import pandas as pd
import requests
import streamlit as st
from PIL import Image
from utils import load_model
def app(model_name):
    """Render the zero-shot image classification page of the KoCLIP demo.

    The page asks for an image (URL or upload) and up to five candidate
    captions, scores every caption against the image with the loaded model,
    and shows the image next to a bar chart of the scores.

    Args:
        model_name: Checkpoint name; it is appended to the ``koclip/`` prefix
            and handed to ``load_model``.
    """
    model, processor = load_model(f"koclip/{model_name}")

    # ``beta_columns`` was renamed to ``columns`` in later Streamlit releases;
    # use whichever one the installed version provides.
    make_columns = st.columns if hasattr(st, "columns") else st.beta_columns

    st.title("Zero-shot Image Classification")
    st.markdown(
        "This demo explores KoCLIP's zero-shot prediction capabilities. "
        "The model takes an image and a list of candidate captions from the "
        "user and predicts the most likely caption that best describes the "
        "given image."
    )

    query1 = st.text_input("Enter a URL to an image...")
    query2 = st.file_uploader("or upload an image...", type=["jpg", "jpeg", "png"])

    col1, col2 = make_columns([3, 1])

    with col2:
        captions_count = st.selectbox("Number of labels", options=range(1, 6), index=2)
        normalize = st.checkbox("Apply Softmax")
        compute = st.button("Classify")

    with col1:
        captions = []
        # Korean sample labels: "cute cat", "cool puppy", "chubby hamster".
        defaults = ["귀여운 고양이", "멋있는 강아지", "포동포동한 햄스터"]
        for idx in range(captions_count):
            # Only the first len(defaults) inputs are pre-filled.
            value = defaults[idx] if idx < len(defaults) else ""
            captions.append(st.text_input(f"Insert caption {idx+1}", value=value))

    if not compute:
        return

    if not any([query1, query2]):
        st.error("Please upload an image or paste an image URL.")
        # Stop here: without an image there is nothing to classify.
        return

    with st.spinner("Computing..."):
        # An uploaded file takes precedence over the URL field.
        image_data = (
            query2
            if query2 is not None
            # The timeout keeps a dead host from hanging the app indefinitely.
            else requests.get(query1, stream=True, timeout=10).raw
        )
        image = Image.open(image_data)

        # Wrap each label in the Korean prompt template "This is <label>.".
        prompts = [f"이것은 {caption.strip()}이다." for caption in captions]
        inputs = processor(
            text=prompts, images=image, return_tensors="jax", padding=True
        )
        # NOTE(review): presumably moves the channel axis last (NCHW -> NHWC)
        # for the Flax model — confirm against the model's expected layout.
        inputs["pixel_values"] = jnp.transpose(
            inputs["pixel_values"], axes=[0, 2, 3, 1]
        )
        outputs = model(**inputs)

        if normalize:
            name = "normalized prob"
            probs = jax.nn.softmax(outputs.logits_per_image, axis=1)
        else:
            name = "cosine sim"
            probs = outputs.logits_per_image
        # Only one image is submitted, so row 0 holds its per-caption scores.
        chart_data = pd.Series(probs[0], index=prompts, name=name)

        col1, col2 = make_columns(2)
        with col1:
            st.image(image)
        with col2:
            st.bar_chart(chart_data)