# Spaces:
# Sleeping
# Sleeping
# %%
from typing import List, Dict, Any | |
import os | |
from sqlalchemy import create_engine, text | |
import requests | |
from sentence_transformers import SentenceTransformer | |
import streamlit as st | |
# Settings for the database reached through the ``iris://`` SQLAlchemy URL.
# Only the host can be overridden (environment variable IRIS_HOSTNAME,
# default "localhost"); the remaining values are fixed.
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = "1972"
namespace = "USER"
username = "demo"
password = "demo"

CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"

# Module-level engine: the helpers below that take no ``engine`` argument
# use this one.
engine = create_engine(CONNECTION_STRING)
def get_all_diseases_name(engine) -> List[str]:
    """Return the second column of every row in ``Test.EntityEmbeddings``.

    The second column is presumably the disease label (TODO confirm against
    the table definition). Values equal to the literal string ``"nan"`` are
    left out.

    Args:
        engine: SQLAlchemy engine used to open the connection.

    Returns:
        The remaining values, in the order the database returned them.
    """
    # The value is read positionally from ``SELECT *``, so this depends on
    # the column order of the table.
    sql = text("SELECT * FROM Test.EntityEmbeddings")
    with engine.connect() as conn:
        with conn.begin():
            rows = conn.execute(sql).fetchall()
    return [row[1] for row in rows if row[1] != "nan"]
def get_uri_from_name(engine, name: str) -> str:
    """Return the identifier of the entity whose label equals *name*.

    The identifier is the last ``/``-separated segment of the stored URI.

    NOTE: a function with this same name and signature is defined again
    further down in this module; the later definition is the one in effect
    at runtime.

    Args:
        engine: SQLAlchemy engine used to open the connection.
        name: Exact label to look up in ``Test.EntityEmbeddings``.

    Returns:
        The trailing path segment of the first matching row's ``uri``.

    Raises:
        IndexError: If no row has that label.
    """
    # Bind the label instead of formatting it into the statement: labels can
    # contain apostrophes, which would otherwise break (or inject into) the
    # SQL text.
    sql = text("SELECT uri FROM Test.EntityEmbeddings WHERE label = :label")
    with engine.connect() as conn:
        with conn.begin():
            rows = conn.execute(sql, {"label": name}).fetchall()
    if not rows:
        raise IndexError(f"no entity labelled {name!r} in Test.EntityEmbeddings")
    return rows[0][0].split("/")[-1]
def get_most_similar_diseases_from_uri(
    engine, original_disease_uri: str, threshold: float = 0.8
) -> List[tuple[str, str, float]]:
    """Return up to ten entities most similar to the given disease.

    NOTE: a function with this same name and signature is defined again
    further down in this module; the later definition is the one in effect
    at runtime. This body previously ignored both parameters and returned
    every label in the table; it now runs the similarity query its signature
    describes.

    Args:
        engine: SQLAlchemy engine used to open the connection.
        original_disease_uri: Identifier appended to
            ``http://identifiers.org/medgen/`` to form the URI to match.
        threshold: Minimum ``VECTOR_COSINE`` score (exclusive) a candidate
            must reach.

    Returns:
        ``(identifier, label, score)`` tuples ordered by descending score,
        where ``identifier`` is the last path segment of the candidate's
        URI. Candidates whose label is the string ``"nan"`` are dropped.
    """
    sql = text(
        """
        SELECT TOP 10 e1.uri AS uri1, e2.uri AS uri2,
               e1.label AS label1, e2.label AS label2,
               VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri = :uri
          AND VECTOR_COSINE(e1.embedding, e2.embedding) > :threshold
          AND e1.uri != e2.uri
        ORDER BY distance DESC
        """
    )
    # Values are bound rather than formatted into the SQL text.
    params = {
        "uri": f"http://identifiers.org/medgen/{original_disease_uri}",
        "threshold": threshold,
    }
    with engine.connect() as conn:
        with conn.begin():
            rows = conn.execute(sql, params).fetchall()
    # Row layout: uri1, uri2, label1, label2, distance.
    return [
        (row[1].split("/")[-1], row[3], row[4]) for row in rows if row[3] != "nan"
    ]
def get_uri_from_name(engine, name: str) -> str:
    """Return the identifier of the entity whose label equals *name*.

    The identifier is the last ``/``-separated segment of the stored URI.

    NOTE: this is the second definition of this function in the module; it
    replaces the one defined earlier with the same signature.

    Args:
        engine: SQLAlchemy engine used to open the connection.
        name: Exact label to look up in ``Test.EntityEmbeddings``.

    Returns:
        The trailing path segment of the first matching row's ``uri``.

    Raises:
        IndexError: If no row has that label.
    """
    # Bind the label instead of formatting it into the statement: labels can
    # contain apostrophes, which would otherwise break (or inject into) the
    # SQL text.
    sql = text("SELECT uri FROM Test.EntityEmbeddings WHERE label = :label")
    with engine.connect() as conn:
        with conn.begin():
            rows = conn.execute(sql, {"label": name}).fetchall()
    if not rows:
        raise IndexError(f"no entity labelled {name!r} in Test.EntityEmbeddings")
    return rows[0][0].split("/")[-1]
def get_most_similar_diseases_from_uri(
    engine, original_disease_uri: str, threshold: float = 0.8
) -> List[tuple[str, str, float]]:
    """Return up to ten entities most similar to the given disease.

    Args:
        engine: SQLAlchemy engine used to open the connection.
        original_disease_uri: Identifier appended to
            ``http://identifiers.org/medgen/`` to form the URI to match.
        threshold: Minimum ``VECTOR_COSINE`` score (exclusive) a candidate
            must reach.

    Returns:
        ``(identifier, label, score)`` tuples ordered by descending score,
        where ``identifier`` is the last path segment of the candidate's
        URI. Candidates whose label is the string ``"nan"`` are dropped.
    """
    sql = text(
        """
        SELECT TOP 10 e1.uri AS uri1, e2.uri AS uri2,
               e1.label AS label1, e2.label AS label2,
               VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri = :uri
          AND VECTOR_COSINE(e1.embedding, e2.embedding) > :threshold
          AND e1.uri != e2.uri
        ORDER BY distance DESC
        """
    )
    # Values are bound rather than formatted into the SQL text.
    params = {
        "uri": f"http://identifiers.org/medgen/{original_disease_uri}",
        "threshold": threshold,
    }
    with engine.connect() as conn:
        with conn.begin():
            rows = conn.execute(sql, params).fetchall()
    # Row layout: uri1, uri2, label1, label2, distance.
    return [
        (row[1].split("/")[-1], row[3], row[4]) for row in rows if row[3] != "nan"
    ]
def get_clinical_record_info(
    clinical_record_id: str, timeout: float = 30.0
) -> Dict[str, Any]:
    """Fetch one study from the ClinicalTrials.gov v2 API as parsed JSON.

    Args:
        clinical_record_id: Study identifier placed at the end of the
            request URL (e.g. ``"NCT00841061"``).
        timeout: Seconds to wait for the server before ``requests`` gives
            up. Without a timeout a stalled connection would block forever.

    Returns:
        The decoded JSON body of the response. The HTTP status code is not
        checked here.
    """
    # Equivalent request:
    #   curl -X GET "https://clinicaltrials.gov/api/v2/studies/NCT00841061"
    #        -H "accept: application/json"
    request_url = f"https://clinicaltrials.gov/api/v2/studies/{clinical_record_id}"
    response = requests.get(
        request_url,
        headers={"accept": "application/json"},
        timeout=timeout,
    )
    return response.json()
def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str, Any]]:
    """Fetch the record for each id, one request per id, preserving order."""
    return [get_clinical_record_info(record_id) for record_id in clinical_record_ids]
def get_similarities_among_diseases_uris(
    uri_list: List[str],
) -> List[tuple[str, str, float]]:
    """Return the pairwise ``VECTOR_COSINE`` scores among the given URIs.

    Uses the module-level ``engine``.

    Args:
        uri_list: URIs to compare with each other, matched against the
            ``uri`` column of ``Test.EntityEmbeddings``.

    Returns:
        One ``(uri1, uri2, score)`` row per ordered pair of distinct URIs
        found in the table. An empty *uri_list* yields an empty list without
        querying the database.
    """
    if not uri_list:
        # ``IN ()`` is not valid SQL, and there is nothing to compare anyway.
        return []

    # One named bind parameter per URI, so values never end up in the SQL
    # text itself.
    params = {f"uri_{index}": uri for index, uri in enumerate(uri_list)}
    placeholders = ", ".join(f":{key}" for key in params)
    sql = text(
        f"""
        SELECT e1.uri AS uri1, e2.uri AS uri2,
               VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri IN ({placeholders})
          AND e2.uri IN ({placeholders})
          AND e1.uri != e2.uri
        """
    )
    with engine.connect() as conn:
        with conn.begin():
            return conn.execute(sql, params).fetchall()
def augment_the_set_of_diseaces(
    diseases: List[str], target_size: int = 15
) -> List[str]:
    """Grow *diseases* to *target_size* entries with the closest entities.

    Each round asks the database (via the module-level ``engine``) for the
    entity, not yet in the set, with the highest average ``VECTOR_COSINE``
    score against the current set, and adds it.

    Args:
        diseases: Entity URIs as stored in ``Test.EntityEmbeddings.uri``.
            The list is extended in place.
        target_size: Length at which augmentation stops.

    Returns:
        The same list object. Added entries are the last ``/``-separated
        segment of the selected URI, not the full URI.
    """
    # Full URIs of everything selected so far. The entries appended to
    # `diseases` are shortened identifiers, so they cannot be used in the
    # SQL filters; filtering on them is what made the same candidate win
    # every round.
    known_uris = list(diseases)

    # With no seed there is nothing to compare against (and the query would
    # contain ``IN ()`` and a division by zero), so the loop is skipped.
    while known_uris and len(diseases) < target_size:
        params = {f"uri_{index}": uri for index, uri in enumerate(known_uris)}
        placeholders = ", ".join(f":{key}" for key in params)
        # The divisor is an int computed here, so formatting it in is safe.
        # NOTE(review): the query selects e2.uri but groups by e2.label —
        # confirm this is the intended grouping.
        sql = text(
            f"""
            SELECT TOP 1 e2.uri AS new_disease,
                   (SUM(VECTOR_COSINE(e1.embedding, e2.embedding)) / {len(known_uris)}) AS score
            FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
            WHERE e1.uri IN ({placeholders})
              AND e2.uri NOT IN ({placeholders})
              AND e2.label != 'nan'
            GROUP BY e2.label
            ORDER BY score DESC
            """
        )
        with engine.connect() as conn:
            with conn.begin():
                rows = conn.execute(sql, params).fetchall()
        if not rows:
            # No candidate left; return what we have instead of failing.
            break
        new_uri = rows[0][0]
        known_uris.append(new_uri)
        diseases.append(new_uri.split("/")[-1])
    return diseases
def get_embedding(string: str, encoder) -> List[float]:
    """Encode *string* with *encoder* and hand back the result unchanged.

    ``encoder.encode`` is called with its progress bar switched off; the
    encoder is expected to be a sentence-transformers model.
    """
    return encoder.encode(string, show_progress_bar=False)
def get_diseases_related_to_a_textual_description(
    description: str, encoder
) -> List[Dict[str, Any]]:
    """Return the five disease descriptions closest to a free-text description.

    The text is embedded with *encoder* and compared, via ``VECTOR_COSINE``,
    with the ``embedding`` column of ``Test.DiseaseDescriptions`` using the
    module-level ``engine``.

    Args:
        description: Free text to embed.
        encoder: Object passed to ``get_embedding``; its output must offer
            ``tolist()``.

    Returns:
        Up to five ``{"uri": ..., "distance": ...}`` dicts ordered by
        descending score.
    """
    description_embedding = get_embedding(description, encoder)
    # "[0.1, 0.2]" -> "0.1, 0.2": the comma-separated form handed to
    # TO_VECTOR. The text consists only of numbers produced by the encoder,
    # so formatting it into the statement does not expose user input.
    string_representation = str(description_embedding.tolist())[1:-1]
    sql = text(
        f"""
        SELECT TOP 5 d.uri,
               VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
        FROM Test.DiseaseDescriptions d
        ORDER BY distance DESC
        """
    )
    with engine.connect() as conn:
        with conn.begin():
            rows = conn.execute(sql).fetchall()
    return [{"uri": row[0], "distance": row[1]} for row in rows]
def get_clinical_trials_related_to_diseases(
    diseases: List[str], encoder
) -> List[Dict[str, Any]]:
    """Return the ten clinical trials closest to a set of diseases.

    The disease strings are joined with ``", "``, embedded as one text with
    *encoder*, and compared, via ``VECTOR_COSINE``, with the ``embedding``
    column of ``Test.ClinicalTrials`` using the module-level ``engine``.

    Args:
        diseases: Disease strings to combine into the query text.
        encoder: Object passed to ``get_embedding``; its output must offer
            ``tolist()``.

    Returns:
        Up to ten ``{"nct_id": ..., "distance": ...}`` dicts ordered by
        descending score.
    """
    diseases_string = ", ".join(diseases)
    disease_embedding = get_embedding(diseases_string, encoder)
    # "[0.1, 0.2]" -> "0.1, 0.2": the comma-separated form handed to
    # TO_VECTOR. The text consists only of numbers produced by the encoder,
    # so formatting it into the statement does not expose user input.
    string_representation = str(disease_embedding.tolist())[1:-1]
    sql = text(
        f"""
        SELECT TOP 10 d.nct_id,
               VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
        FROM Test.ClinicalTrials d
        ORDER BY distance DESC
        """
    )
    with engine.connect() as conn:
        with conn.begin():
            rows = conn.execute(sql).fetchall()
    return [{"nct_id": row[0], "distance": row[1]} for row in rows]
def to_capitalized_case(string: str) -> str:
    """Turn an ``UPPER_SNAKE_CASE`` token into sentence-style text.

    Underscores become spaces. If the result is entirely upper-case, every
    character after the first is lower-cased (``"NOT_YET_RECRUITING"`` ->
    ``"Not yet recruiting"``). Any other string, including the empty string,
    is returned with only the underscore replacement applied.
    """
    string = string.replace("_", " ")
    if string.isupper():
        return string[0] + string[1:].lower()
    # Previously this path fell off the end and returned None, which broke
    # callers that join or display the result.
    return string
def list_to_capitalized_case(strings: List[str]) -> str:
    """Apply ``to_capitalized_case`` to each item and join them with ``", "``."""
    return ", ".join(to_capitalized_case(item) for item in strings)
# Sentinel for "key not present", so that real values such as False or None
# are still rendered.
_ABSENT = object()


def _dig(record: Any, *path: str) -> Any:
    """Follow *path* through nested dicts; return ``_ABSENT`` if a key is missing."""
    node = record
    for key in path:
        if not isinstance(node, dict) or key not in node:
            return _ABSENT
        node = node[key]
    return node


def _render_table(heading: str, trial: dict, fields) -> None:
    """Write *heading* and a table holding the *fields* that exist in *trial*.

    *fields* is a sequence of ``(label, key_path, formatter_or_None)``.
    Nothing is written when none of the fields is present.
    """
    rows = {}
    for label, path, formatter in fields:
        value = _dig(trial, *path)
        if value is _ABSENT:
            continue
        rows[label] = formatter(value) if formatter else value
    if rows:
        st.write(heading)
        st.table(rows)


def render_trial_details(trial: dict) -> None:
    """Render one clinical-trial record with Streamlit.

    Writes, in order: the official title, the brief summary, a "Status"
    table, a "Design" table and an "Interventions" table. Any value whose
    key is missing from *trial* is left out instead of raising ``KeyError``;
    a table with nothing to show is omitted together with its heading.

    Args:
        trial: Study record as nested dicts, with the keys read below
            (``protocolSection``, ``hasResults``, ...).
    """
    official_title = _dig(
        trial, "protocolSection", "identificationModule", "officialTitle"
    )
    if official_title is not _ABSENT:
        st.write(f"##### {official_title}")

    brief_summary = _dig(trial, "protocolSection", "descriptionModule", "briefSummary")
    if brief_summary is not _ABSENT:
        st.write(brief_summary)

    status = ("protocolSection", "statusModule")
    _render_table(
        "###### Status",
        trial,
        [
            ("Status", status + ("overallStatus",), to_capitalized_case),
            ("Status Date", status + ("statusVerifiedDate",), None),
            ("Has Results", ("hasResults",), None),
        ],
    )

    design = ("protocolSection", "designModule")
    design_info = design + ("designInfo",)
    masking_info = design_info + ("maskingInfo",)
    _render_table(
        "###### Design",
        trial,
        [
            ("Study Type", design + ("studyType",), to_capitalized_case),
            ("Phases", design + ("phases",), list_to_capitalized_case),
            ("Allocation", design_info + ("allocation",), to_capitalized_case),
            ("Primary Purpose", design_info + ("primaryPurpose",), to_capitalized_case),
            ("Participants", design + ("enrollmentInfo", "count"), None),
            ("Masking", masking_info + ("masking",), to_capitalized_case),
            ("Who Masked", masking_info + ("whoMasked",), list_to_capitalized_case),
        ],
    )

    interventions = _dig(
        trial, "protocolSection", "armsInterventionsModule", "interventions"
    )
    interventions_module = {}
    if interventions is not _ABSENT:
        for intervention in interventions:
            name = _dig(intervention, "name")
            desc = _dig(intervention, "description")
            # An entry lacking either field is not rendered.
            if name is _ABSENT or desc is _ABSENT:
                continue
            interventions_module[name] = desc
    if interventions_module:
        st.write("###### Interventions")
        st.table(interventions_module)
if __name__ == "__main__":
    # Manual smoke test of the helpers above. It runs against the
    # module-level ``engine``, which is built from the same connection
    # settings this block used to repeat.
    try:
        # ``engine`` is the first positional argument; it was missing here.
        diseases = get_most_similar_diseases_from_uri(engine, "C1843013")
        for disease in diseases:
            print(disease)
    except Exception as e:
        print(e)

    try:
        print(get_uri_from_name(engine, "Alzheimer disease 3"))
    except Exception as e:
        print(e)

    clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
    print(clinical_record_info)

    textual_description = (
        "A disease that causes memory loss and other cognitive impairments."
    )
    encoder = SentenceTransformer("allenai-specter")
    diseases = get_diseases_related_to_a_textual_description(
        textual_description, encoder
    )
    for disease in diseases:
        print(disease)

    try:
        similarities = get_similarities_among_diseases_uris(
            [
                "http://identifiers.org/medgen/C4553765",
                "http://identifiers.org/medgen/C4553176",
                "http://identifiers.org/medgen/C4024935",
            ]
        )
        for similarity in similarities:
            print(
                f'{similarity[0].split("/")[-1]} and {similarity[1].split("/")[-1]} have a similarity of {similarity[2]}'
            )
    except Exception as e:
        print(e)
# %% | |