kaushalya commited on
Commit
aa31199
1 Parent(s): 1366c30

Add documentation

Browse files
app.py CHANGED
@@ -3,7 +3,7 @@ import pandas as pd
3
  import numpy as np
4
  import os
5
  import matplotlib.pyplot as plt
6
- from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
7
  from medclip.modeling_hybrid_clip import FlaxHybridCLIP
8
 
9
  @st.cache(allow_output_mutation=True)
@@ -14,7 +14,7 @@ def load_model():
14
 
15
  @st.cache(allow_output_mutation=True)
16
  def load_image_embeddings():
17
- embeddings_df = pd.read_pickle('image_embeddings.pkl')
18
  image_embeds = np.stack(embeddings_df['image_embedding'])
19
  image_files = np.asarray(embeddings_df['files'].tolist())
20
  return image_files, image_embeds
@@ -24,7 +24,11 @@ image_list, image_embeddings = load_image_embeddings()
24
  model, processor = load_model()
25
  img_dir = './images'
26
 
27
- query = st.text_input("Search:")
 
 
 
 
28
 
29
  if st.button("Search"):
30
  st.write(f"Searching our image database for {query}...")
@@ -36,8 +40,6 @@ if st.button("Search"):
36
  query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
37
  dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
38
  matching_images = image_list[dot_prod.argsort()[-k:]]
39
-
40
- # st.write(f"matching images: {matching_images}")
41
  #show images
42
 
43
  for img_path in matching_images:
 
3
  import numpy as np
4
  import os
5
  import matplotlib.pyplot as plt
6
+ from transformers import AutoTokenizer, CLIPProcessor
7
  from medclip.modeling_hybrid_clip import FlaxHybridCLIP
8
 
9
  @st.cache(allow_output_mutation=True)
 
14
 
15
  @st.cache(allow_output_mutation=True)
16
  def load_image_embeddings():
17
+ embeddings_df = pd.read_pickle('feature_store/image_embeddings.pkl')
18
  image_embeds = np.stack(embeddings_df['image_embedding'])
19
  image_files = np.asarray(embeddings_df['files'].tolist())
20
  return image_files, image_embeds
 
24
  model, processor = load_model()
25
  img_dir = './images'
26
 
27
+ st.title("MedCLIP 🩺📎")
28
+ st.markdown("Search for medical images in natural language.")
29
+ st.markdown("""This demo uses a CLIP model finetuned on the
30
+ [Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
31
+ query = st.text_input("Enter your query here:")
32
 
33
  if st.button("Search"):
34
  st.write(f"Searching our image database for {query}...")
 
40
  query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
41
  dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
42
  matching_images = image_list[dot_prod.argsort()[-k:]]
 
 
43
  #show images
44
 
45
  for img_path in matching_images:
feature_store/image_embeddings.pkl ADDED
Binary file (1.88 MB). View file
 
tools/create_embeddings.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import jax
3
+
4
+ from transformers import AutoTokenizer, CLIPProcessor
5
+ from configuration_hybrid_clip import HybridCLIPConfig
6
+ from modeling_hybrid_clip import FlaxHybridCLIP
7
+ from PIL import Image
8
+
9
+ import matplotlib.pyplot as plt
10
+ import torch
11
+ import torchvision
12
+ from torchvision.transforms.functional import InterpolationMode
13
+ from torchvision.transforms import Resize, Normalize, ConvertImageDtype, ToTensor
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+
18
+ def main():
19
+ model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
20
+ vision_model_name = "openai/clip-vit-base-patch32"
21
+ img_dir = "/Users/kaumad/Documents/coding/hf-flax/demo/medclip-roco/images"
22
+
23
+ processor = CLIPProcessor.from_pretrained(vision_model_name)
24
+
25
+ img_list = os.listdir(img_dir)
26
+ embeddings = []
27
+
28
+ for idx, img_path in enumerate(img_list):
29
+ if idx % 10 == 0:
30
+ print(f"{idx} images processed")
31
+ img = Image.open(os.path.join(img_dir, img_path)).convert('RGB')
32
+ inputs = processor(images=img, return_tensors="jax", padding=True)
33
+ inputs['pixel_values'] = inputs['pixel_values'].transpose(0, 2, 3, 1)
34
+ img_vec = model.get_image_features(**inputs)
35
+ img_vec = np.array(img_vec).reshape(-1).tolist()
36
+ embeddings.append(img_vec)
37
+
38
+ if __name__=='__main__':
39
+ main()