kaushalya committed
Commit a269b46 (parent: 97009c1)

Fix loading the text model

Files changed (1): app.py (+11, -10)
app.py CHANGED
@@ -1,21 +1,22 @@
 import os
+import token
 
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import streamlit as st
-from transformers import CLIPProcessor
+from transformers import CLIPProcessor, AutoTokenizer
 
 from medclip.modeling_hybrid_clip import FlaxHybridCLIP
 
 
-@st.cache(allow_output_mutation=True)
+@st.cache_resource
 def load_model():
-    model, _ = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco", _do_init=False)
-    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-    return model, processor
+    model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco", _do_init=True)
+    tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
+    return model, tokenizer
 
-@st.cache(allow_output_mutation=True)
+@st.cache_resource
 def load_image_embeddings():
     embeddings_df = pd.read_hdf('feature_store/image_embeddings_large.hdf', key='emb')
     image_embeds = np.stack(embeddings_df['image_embedding'])
@@ -64,7 +65,7 @@ elif ex4_button:
 
 
 image_list, image_embeddings = load_image_embeddings()
-model, processor = load_model()
+model, tokenizer = load_model()
 
 query = st.text_input("Enter your query here:", value=text_value)
 dot_prod = None
@@ -78,8 +79,8 @@ if st.button("Search") or k_slider:
     else:
         with st.spinner(f"Searching ROCO test set for {query}..."):
             k = k_slider
-            inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
-
+            inputs = tokenizer(text=[query], return_tensors="jax", padding=True)
+            # st.write(f"Query inputs: {inputs}")
             query_embedding = model.get_text_features(**inputs)
             query_embedding = np.asarray(query_embedding)
             query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
@@ -91,4 +92,4 @@ if st.button("Search") or k_slider:
             for img_path, score in zip(matching_images, top_scores):
                 img = plt.imread(os.path.join(img_dir, img_path))
                 st.image(img, width=300)
-                st.write(f"{img_path} ({score:.2f})", help="score")
+                st.write(f"{img_path} ({score:.2f})")
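Note on the fix: for Flax models in transformers, from_pretrained(..., _do_init=False) returns a (model, params) tuple and leaves the weights unbound to the model object, so the old "model, _ = ..." discarded the pretrained parameters. A minimal sketch of the two loading modes, using the same checkpoint as the app (the explicit params= keyword follows the standard Flax model method signature):

from medclip.modeling_hybrid_clip import FlaxHybridCLIP

# Old mode: _do_init=False returns (model, params); the weights are not
# bound to the model, so every forward call would need params= passed in,
# e.g. model.get_text_features(input_ids, params=params).
model, params = FlaxHybridCLIP.from_pretrained(
    "flax-community/medclip-roco", _do_init=False
)

# New mode: _do_init=True returns one fully initialized model with the
# pretrained weights bound to model.params, which is what the app expects.
model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco", _do_init=True)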
 
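For reference, a minimal sketch of the fixed query path outside Streamlit. The model and tokenizer names are the ones used in this commit; the image_embeddings matrix is a random stand-in for the precomputed embeddings that load_image_embeddings() reads from feature_store/image_embeddings_large.hdf, and the query string is only an example:

import numpy as np
from transformers import AutoTokenizer
from medclip.modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco", _do_init=True)
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")

# Tokenize an example query with the SciBERT tokenizer and embed it.
inputs = tokenizer(text=["CT scan of the abdomen"], return_tensors="jax", padding=True)
query_embedding = np.asarray(model.get_text_features(**inputs))
query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)

# Stand-in for the precomputed image-embedding matrix, rows L2-normalized.
rng = np.random.default_rng(0)
image_embeddings = rng.standard_normal((100, query_embedding.shape[-1]))
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=-1, keepdims=True)

# On normalized vectors, cosine similarity is a dot product; keep the top k.
scores = image_embeddings @ query_embedding[0]
top_k = np.argsort(-scores)[:5]
print(list(zip(top_k.tolist(), scores[top_k].round(2).tolist())))

The decorator change is independent of the loading fix: st.cache is deprecated in current Streamlit, and st.cache_resource is its replacement for unserializable global resources such as models, so the model and tokenizer are created once per process instead of on every rerun.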