g8a9 commited on
Commit
a01e989
1 Parent(s): 34d02dd

[image2text] add initial version

Browse files
Files changed (4) hide show
  1. image2text.py +58 -1
  2. requirements.txt +2 -1
  3. text2image.py +16 -11
  4. utils.py +8 -0
image2text.py CHANGED
@@ -1,4 +1,10 @@
1
  import streamlit as st
 
 
 
 
 
 
2
 
3
  def app():
4
  st.title("From Image to Text")
@@ -12,4 +18,55 @@ def app():
12
  🤌 Italian mode on! 🤌
13
 
14
  """
15
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from text2image import get_model, get_tokenizer, get_image_transform
3
+ from utils import text_encoder, image_encoder
4
+ from PIL import Image
5
+ from jax import numpy as jnp
6
+ import pandas as pd
7
+
8
 
9
  def app():
10
  st.title("From Image to Text")
 
18
  🤌 Italian mode on! 🤌
19
 
20
  """
21
+ )
22
+
23
+ filename = st.file_uploader(
24
+ "Choose an image from your computer", type=["jpg", "jpeg", "png"]
25
+ )
26
+
27
+ MAX_CAP = 4
28
+
29
+ col1, col2 = st.beta_columns([3, 1])
30
+
31
+ with col2:
32
+ captions_count = st.selectbox(
33
+ "Number of captions", options=range(1, MAX_CAP + 1)
34
+ )
35
+ compute = st.button("Compute")
36
+
37
+ with col1:
38
+ captions = list()
39
+ for idx in range(min(MAX_CAP, captions_count)):
40
+ captions.append(st.text_input(f"Insert Caption {idx+1}"))
41
+
42
+ if compute:
43
+ captions = [c for c in captions if c != ""]
44
+
45
+ if not captions or not filename:
46
+ st.error("Please choose one image and at least one caption")
47
+ else:
48
+ with st.spinner("Computing..."):
49
+ model = get_model()
50
+ tokenizer = get_tokenizer()
51
+
52
+ text_embeds = list()
53
+ for i, c in enumerate(captions):
54
+ text_embeds.extend(text_encoder(c, model, tokenizer))
55
+
56
+ text_embeds = jnp.array(text_embeds)
57
+
58
+ image = Image.open(filename).convert("RGB")
59
+ transform = get_image_transform(model.config.vision_config.image_size)
60
+ image_embed = image_encoder(transform(image), model)
61
+
62
+ # we could have a softmax here
63
+ cos_similarities = jnp.matmul(image_embed, text_embeds.T)
64
+
65
+ chart_data = pd.Series(cos_similarities[0], index=captions)
66
+
67
+ col1, col2 = st.beta_columns(2)
68
+ with col1:
69
+ st.bar_chart(chart_data)
70
+
71
+ with col2:
72
+ st.image(image)
requirements.txt CHANGED
@@ -4,4 +4,5 @@ transformers
4
  torch
5
  torchvision
6
  natsort
7
- stqdm
 
 
4
  torch
5
  torchvision
6
  natsort
7
+ stqdm
8
+ pandas
text2image.py CHANGED
@@ -81,6 +81,20 @@ def load_urls(dataset_name):
81
  ValueError(f"{dataset_name} not supported here")
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def app():
85
 
86
  st.title("From Text to Image")
@@ -140,18 +154,9 @@ def app():
140
 
141
  if dataset_name == "Unsplash":
142
  image_size = model.config.vision_config.image_size
143
- val_preprocess = Compose(
144
- [
145
- Resize([image_size], interpolation=InterpolationMode.BICUBIC),
146
- CenterCrop(image_size),
147
- ToTensor(),
148
- Normalize(
149
- (0.48145466, 0.4578275, 0.40821073),
150
- (0.26862954, 0.26130258, 0.27577711),
151
- ),
152
- ]
153
  )
154
- dataset = utils.CustomDataSet("photos/", transform=val_preprocess)
155
  elif dataset_name == "CC":
156
  dataset = load_urls(dataset_name)
157
  else:
 
81
  ValueError(f"{dataset_name} not supported here")
82
 
83
 
84
+ def get_image_transform(image_size):
85
+ return Compose(
86
+ [
87
+ Resize([image_size], interpolation=InterpolationMode.BICUBIC),
88
+ CenterCrop(image_size),
89
+ ToTensor(),
90
+ Normalize(
91
+ (0.48145466, 0.4578275, 0.40821073),
92
+ (0.26862954, 0.26130258, 0.27577711),
93
+ ),
94
+ ]
95
+ )
96
+
97
+
98
  def app():
99
 
100
  st.title("From Text to Image")
 
154
 
155
  if dataset_name == "Unsplash":
156
  image_size = model.config.vision_config.image_size
157
+ dataset = utils.CustomDataSet(
158
+ "photos/", transform=get_image_transform(image_size)
 
 
 
 
 
 
 
 
159
  )
 
160
  elif dataset_name == "CC":
161
  dataset = load_urls(dataset_name)
162
  else:
utils.py CHANGED
@@ -41,6 +41,14 @@ def text_encoder(text, model, tokenizer):
41
  return jnp.expand_dims(embedding, axis=0)
42
 
43
 
 
 
 
 
 
 
 
 
44
  def precompute_image_features(model, loader):
45
  image_features = []
46
  for i, (images) in enumerate(tqdm(loader)):
 
41
  return jnp.expand_dims(embedding, axis=0)
42
 
43
 
44
+ def image_encoder(image, model):
45
+ image = image.permute(1, 2, 0).numpy()
46
+ image = jnp.expand_dims(image, axis=0) #  add batch size
47
+ features = model.get_image_features(image,)
48
+ features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
49
+ return features
50
+
51
+
52
  def precompute_image_features(model, loader):
53
  image_features = []
54
  for i, (images) in enumerate(tqdm(loader)):