Kaveh-Workstation commited on
Commit
2e11923
1 Parent(s): c538a72

bug fix for new model

Browse files
Files changed (1) hide show
  1. app.py +19 -17
app.py CHANGED
@@ -40,20 +40,19 @@ def find_similar_images(query_embedding, image_embeddings, k=2):
40
  return closest_indices, scores
41
 
42
 
43
- def main(query, model_id="rclip", k=2):
44
- if model_id=="rclip":
45
  # Load RCLIP model
46
  model = VisionTextDualEncoderModel.from_pretrained(model_path_rclip)
47
  processor = VisionTextDualEncoderProcessor.from_pretrained(model_path_rclip)
48
  # Load image embeddings
49
  image_embeddings = load_embeddings(embeddings_file_rclip)
50
- elif mode_id=="pubmedclip":
51
  model = CLIPModel.from_pretrained(model_path_pubmedclip)
52
  processor = CLIPProcessor.from_pretrained(model_path_pubmedclip)
53
  # Load image embeddings
54
  image_embeddings = load_embeddings(embeddings_file_pubmedclip)
55
 
56
-
57
  # Embed the query
58
  inputs = processor(text=query, images=None, return_tensors="pt", padding=True)
59
  with torch.no_grad():
@@ -75,10 +74,12 @@ def main(query, model_id="rclip", k=2):
75
 
76
  # Define the Gradio interface
77
  examples = [
78
- ["Chest X-ray photos",5],
79
- ["Orthopantogram (OPG)",5],
80
- ["Brain Scan",5],
81
- ["tomography",5]
 
 
82
  ]
83
 
84
  title="RCLIP Image Retrieval"
@@ -91,24 +92,25 @@ with gr.Blocks(title=title) as demo:
91
  gr.Markdown(description)
92
  gr.HTML(value="<img src=\"https://newresults.co.uk/wp-content/uploads/2022/02/teesside-university-logo.png\" alt=\"teesside logo\" width=\"120\" height=\"70\">", show_label=False,scale=1)
93
  #Image.open("./data/teesside university logo.png"), height=70, show_label=False, container=False)
94
- with gr.Column(variant="compact"):
95
- with gr.Row(variant="compact"):
96
- query = gr.Textbox(value="Chest X-Ray Photos", label="Enter your query", show_label=False, placeholder= "Enter your query" , scale=5)
97
- btn = gr.Button("Search query", variant="primary", scale=1)
98
-
99
- n_s = gr.Slider(2, 10, label='Number of Top Results', value=5, step=1.0, show_label=True)
 
100
 
101
-
102
  with gr.Column(variant="compact"):
103
  gr.Markdown("## Results")
104
  gallery = gr.Gallery(label="found images", show_label=True, elem_id="gallery", columns=[2], rows=[4], object_fit="contain", height="400px", preview=True)
105
  gr.Markdown("Information of the found images")
106
  df = gr.DataFrame()
107
- btn.click(main, [query, n_s], [gallery, df])
108
 
109
  with gr.Column(variant="compact"):
110
  gr.Markdown("## Examples")
111
- gr.Examples(examples, [query, n_s])
112
 
113
 
114
  demo.launch(debug='True')
 
40
  return closest_indices, scores
41
 
42
 
43
+ def main(query, model_id="RCLIP", k=2):
44
+ if model_id=="RCLIP":
45
  # Load RCLIP model
46
  model = VisionTextDualEncoderModel.from_pretrained(model_path_rclip)
47
  processor = VisionTextDualEncoderProcessor.from_pretrained(model_path_rclip)
48
  # Load image embeddings
49
  image_embeddings = load_embeddings(embeddings_file_rclip)
50
+ elif model_id=="PubMedCLIP":
51
  model = CLIPModel.from_pretrained(model_path_pubmedclip)
52
  processor = CLIPProcessor.from_pretrained(model_path_pubmedclip)
53
  # Load image embeddings
54
  image_embeddings = load_embeddings(embeddings_file_pubmedclip)
55
 
 
56
  # Embed the query
57
  inputs = processor(text=query, images=None, return_tensors="pt", padding=True)
58
  with torch.no_grad():
 
74
 
75
  # Define the Gradio interface
76
  examples = [
77
+ ["Chest X-ray photos", "RCLIP", 5],
78
+ ["Chest X-ray photos", "PubMedCLIP", 5],
79
+ ["Orthopantogram (OPG)", "RCLIP",5],
80
+ ["Brain Scan", "RCLIP",5],
81
+ ["Chest X-ray photos", "PubMedCLIP", 5],
82
+ ["tomography", "RCLIP",5],
83
  ]
84
 
85
  title="RCLIP Image Retrieval"
 
92
  gr.Markdown(description)
93
  gr.HTML(value="<img src=\"https://newresults.co.uk/wp-content/uploads/2022/02/teesside-university-logo.png\" alt=\"teesside logo\" width=\"120\" height=\"70\">", show_label=False,scale=1)
94
  #Image.open("./data/teesside university logo.png"), height=70, show_label=False, container=False)
95
+ with gr.Row(variant="compact"):
96
+ query = gr.Textbox(value="Chest X-Ray Photos", label="Enter your query", show_label=False, placeholder= "Enter your query" , scale=5)
97
+ btn = gr.Button("Search query", variant="primary", scale=1)
98
+
99
+ with gr.Row(variant="compact"):
100
+ model_id = gr.Dropdown(["RCLIP", "PubMedCLIP"], value="RCLIP", label="Model", type="value", scale=1)
101
+ n_s = gr.Slider(2, 10, label='Number of Top Results', value=5, step=1.0, show_label=True, scale=1)
102
 
103
+
104
  with gr.Column(variant="compact"):
105
  gr.Markdown("## Results")
106
  gallery = gr.Gallery(label="found images", show_label=True, elem_id="gallery", columns=[2], rows=[4], object_fit="contain", height="400px", preview=True)
107
  gr.Markdown("Information of the found images")
108
  df = gr.DataFrame()
109
+ btn.click(main, [query, model_id, n_s], [gallery, df])
110
 
111
  with gr.Column(variant="compact"):
112
  gr.Markdown("## Examples")
113
+ gr.Examples(examples, [query, model_id, n_s])
114
 
115
 
116
  demo.launch(debug='True')