Spaces:
Sleeping
Sleeping
ACMCMC
commited on
Commit
•
1f35211
1
Parent(s):
2408e3d
WIP app
Browse files
app.py
CHANGED
@@ -3,8 +3,10 @@ from streamlit_agraph import agraph, Node, Edge, Config
|
|
3 |
import os
|
4 |
from sqlalchemy import create_engine, text
|
5 |
import pandas as pd
|
6 |
-
from utils import get_all_diseases_name, get_most_similar_diseases_from_uri, get_uri_from_name, get_diseases_related_to_a_textual_description
|
7 |
import json
|
|
|
|
|
8 |
|
9 |
|
10 |
username = 'demo'
|
@@ -15,11 +17,17 @@ namespace = 'USER'
|
|
15 |
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
|
16 |
engine = create_engine(CONNECTION_STRING)
|
17 |
|
18 |
-
def handle_click_on_analyze_button():
|
19 |
# 1. Embed the textual description that the user entered using the model
|
20 |
-
diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(description_input)
|
21 |
# 2. Get 5 diseases with the highest cosine silimarity from the DB
|
|
|
|
|
|
|
|
|
22 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
|
|
|
|
|
|
23 |
# 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
|
24 |
# 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
|
25 |
# 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
|
@@ -31,7 +39,10 @@ def handle_click_on_analyze_button():
|
|
31 |
|
32 |
st.write("# Klìnic")
|
33 |
|
34 |
-
description_input = st.text_input(label="Enter the disease description 👇")
|
|
|
|
|
|
|
35 |
|
36 |
st.write(":red[Here should be the graph]") # TODO remove
|
37 |
chart_data = pd.DataFrame(
|
|
|
3 |
import os
|
4 |
from sqlalchemy import create_engine, text
|
5 |
import pandas as pd
|
6 |
+
from utils import get_all_diseases_name, get_most_similar_diseases_from_uri, get_uri_from_name, get_diseases_related_to_a_textual_description, get_similarities_among_diseases_uris
|
7 |
import json
|
8 |
+
import numpy as np
|
9 |
+
from sentence_transformers import SentenceTransformer
|
10 |
|
11 |
|
12 |
username = 'demo'
|
|
|
17 |
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
|
18 |
engine = create_engine(CONNECTION_STRING)
|
19 |
|
20 |
+
def handle_click_on_analyze_button(user_text):
|
21 |
# 1. Embed the textual description that the user entered using the model
|
|
|
22 |
# 2. Get 5 diseases with the highest cosine silimarity from the DB
|
23 |
+
encoder = SentenceTransformer("allenai-specter")
|
24 |
+
diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(user_text, encoder)
|
25 |
+
#for disease_label in diseases_related_to_the_user_text:
|
26 |
+
# st.text(disease_label)
|
27 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
28 |
+
diseases_uris = [disease['uri'] for disease in diseases_related_to_the_user_text]
|
29 |
+
get_similarities_among_diseases_uris(diseases_uris)
|
30 |
+
print(diseases_related_to_the_user_text)
|
31 |
# 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
|
32 |
# 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
|
33 |
# 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
|
|
|
39 |
|
40 |
st.write("# Klìnic")
|
41 |
|
42 |
+
description_input = st.text_input(label="Enter the disease description 👇", placeholder='A disease that causes memory loss and other cognitive impairments.')
|
43 |
+
if st.button("Analyze"):
|
44 |
+
handle_click_on_analyze_button(description_input)
|
45 |
+
# TODO: also when user clicks enter
|
46 |
|
47 |
st.write(":red[Here should be the graph]") # TODO remove
|
48 |
chart_data = pd.DataFrame(
|
utils.py
CHANGED
@@ -5,6 +5,15 @@ from sqlalchemy import create_engine, text
|
|
5 |
import requests
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
def get_all_diseases_name(engine) -> List[List[str]]:
|
10 |
with engine.connect() as conn:
|
@@ -98,46 +107,48 @@ def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str
|
|
98 |
return clinical_records
|
99 |
|
100 |
|
101 |
-
def
|
102 |
-
uri_list
|
|
|
|
|
103 |
with engine.connect() as conn:
|
104 |
with conn.begin():
|
105 |
sql = f"""
|
106 |
SELECT e1.uri AS uri1, e2.uri AS uri2, VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
|
107 |
FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
|
108 |
-
WHERE e1.uri IN {uri_list} AND e2.uri IN {uri_list} AND e1.uri != e2.uri
|
109 |
"""
|
110 |
result = conn.execute(text(sql))
|
111 |
data = result.fetchall()
|
112 |
return data
|
113 |
|
114 |
|
115 |
-
encoder
|
116 |
-
|
117 |
-
|
118 |
-
def get_embedding(string: str) -> List[float]:
|
119 |
# Embed the string using sentence-transformers
|
120 |
vector = encoder.encode(string, show_progress_bar=False)
|
121 |
return vector
|
122 |
|
123 |
|
124 |
-
def get_diseases_related_to_a_textual_description(
|
|
|
|
|
125 |
# Embed the description using sentence-transformers
|
126 |
-
description_embedding = get_embedding(description)
|
127 |
-
print(f
|
128 |
string_representation = str(description_embedding.tolist())[1:-1]
|
129 |
-
print(f
|
130 |
|
131 |
with engine.connect() as conn:
|
132 |
with conn.begin():
|
133 |
sql = f"""
|
134 |
-
SELECT TOP 5 uri, VECTOR_COSINE(
|
135 |
-
FROM Test.DiseaseDescriptions
|
136 |
ORDER BY distance DESC
|
137 |
"""
|
138 |
result = conn.execute(text(sql))
|
139 |
data = result.fetchall()
|
140 |
-
|
|
|
141 |
|
142 |
|
143 |
if __name__ == "__main__":
|
@@ -164,9 +175,29 @@ if __name__ == "__main__":
|
|
164 |
clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
|
165 |
print(clinical_record_info)
|
166 |
|
167 |
-
textual_description =
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
169 |
for disease in diseases:
|
170 |
print(disease)
|
171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
# %%
|
|
|
5 |
import requests
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
|
8 |
+
username = "demo"
|
9 |
+
password = "demo"
|
10 |
+
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
|
11 |
+
port = "1972"
|
12 |
+
namespace = "USER"
|
13 |
+
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
|
14 |
+
|
15 |
+
engine = create_engine(CONNECTION_STRING)
|
16 |
+
|
17 |
|
18 |
def get_all_diseases_name(engine) -> List[List[str]]:
|
19 |
with engine.connect() as conn:
|
|
|
107 |
return clinical_records
|
108 |
|
109 |
|
110 |
+
def get_similarities_among_diseases_uris(
|
111 |
+
uri_list: List[str],
|
112 |
+
) -> List[tuple[str, str, float]]:
|
113 |
+
uri_list = ", ".join([f"'{uri}'" for uri in uri_list])
|
114 |
with engine.connect() as conn:
|
115 |
with conn.begin():
|
116 |
sql = f"""
|
117 |
SELECT e1.uri AS uri1, e2.uri AS uri2, VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
|
118 |
FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
|
119 |
+
WHERE e1.uri IN ({uri_list}) AND e2.uri IN ({uri_list}) AND e1.uri != e2.uri
|
120 |
"""
|
121 |
result = conn.execute(text(sql))
|
122 |
data = result.fetchall()
|
123 |
return data
|
124 |
|
125 |
|
126 |
+
def get_embedding(string: str, encoder) -> List[float]:
|
|
|
|
|
|
|
127 |
# Embed the string using sentence-transformers
|
128 |
vector = encoder.encode(string, show_progress_bar=False)
|
129 |
return vector
|
130 |
|
131 |
|
132 |
+
def get_diseases_related_to_a_textual_description(
|
133 |
+
description: str, encoder
|
134 |
+
) -> List[str]:
|
135 |
# Embed the description using sentence-transformers
|
136 |
+
description_embedding = get_embedding(description, encoder)
|
137 |
+
print(f"Size of the embedding: {len(description_embedding)}")
|
138 |
string_representation = str(description_embedding.tolist())[1:-1]
|
139 |
+
print(f"String representation: {string_representation}")
|
140 |
|
141 |
with engine.connect() as conn:
|
142 |
with conn.begin():
|
143 |
sql = f"""
|
144 |
+
SELECT TOP 5 d.uri, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
|
145 |
+
FROM Test.DiseaseDescriptions d
|
146 |
ORDER BY distance DESC
|
147 |
"""
|
148 |
result = conn.execute(text(sql))
|
149 |
data = result.fetchall()
|
150 |
+
|
151 |
+
return [{"uri": row[0], "distance": row[1]} for row in data]
|
152 |
|
153 |
|
154 |
if __name__ == "__main__":
|
|
|
175 |
clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
|
176 |
print(clinical_record_info)
|
177 |
|
178 |
+
textual_description = (
|
179 |
+
"A disease that causes memory loss and other cognitive impairments."
|
180 |
+
)
|
181 |
+
encoder = SentenceTransformer("allenai-specter")
|
182 |
+
diseases = get_diseases_related_to_a_textual_description(
|
183 |
+
textual_description, encoder
|
184 |
+
)
|
185 |
for disease in diseases:
|
186 |
print(disease)
|
187 |
|
188 |
+
try:
|
189 |
+
similarities = get_similarities_among_diseases_uris(
|
190 |
+
[
|
191 |
+
"http://identifiers.org/medgen/C4553765",
|
192 |
+
"http://identifiers.org/medgen/C4553176",
|
193 |
+
"http://identifiers.org/medgen/C4024935",
|
194 |
+
]
|
195 |
+
)
|
196 |
+
for similarity in similarities:
|
197 |
+
print(
|
198 |
+
f'{similarity[0].split("/")[-1]} and {similarity[1].split("/")[-1]} have a similarity of {similarity[2]}'
|
199 |
+
)
|
200 |
+
except Exception as e:
|
201 |
+
print(e)
|
202 |
+
|
203 |
# %%
|