Spaces:
Sleeping
Sleeping
import streamlit as st | |
from streamlit_agraph import agraph, Node, Edge, Config | |
import os | |
from sqlalchemy import create_engine, text | |
import pandas as pd | |
import time | |
from utils import ( | |
get_all_diseases_name, | |
get_most_similar_diseases_from_uri, | |
get_uri_from_name, | |
get_diseases_related_to_a_textual_description, | |
get_similarities_among_diseases_uris, | |
augment_the_set_of_diseaces, | |
get_clinical_trials_related_to_diseases, | |
get_clinical_records_by_ids | |
) | |
from llm_res import process_dictionaty_with_llm_to_generate_response | |
import json | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
# Progressive-disclosure switches. Each section below renders only once its
# flag is True; the section before it flips the flag when it finishes.
show_graph = show_analyze_status = show_overview = show_details = False
# IRIS connection.
# Every setting can be overridden through an environment variable (mirroring
# IRIS_HOSTNAME); the defaults reproduce the original hard-coded demo values,
# so behaviour is unchanged when no variable is set.
from urllib.parse import quote_plus

username = os.getenv("IRIS_USERNAME", "demo")
password = os.getenv("IRIS_PASSWORD", "demo")
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = os.getenv("IRIS_PORT", "1972")
namespace = os.getenv("IRIS_NAMESPACE", "USER")
# Percent-encode the credentials: characters such as "@", ":" or "/" in an
# env-supplied password would otherwise corrupt the URL structure.
# SQLAlchemy unquotes them again when parsing the URL.
CONNECTION_STRING = (
    f"iris://{quote_plus(username)}:{quote_plus(password)}"
    f"@{hostname}:{port}/{namespace}"
)
engine = create_engine(CONNECTION_STRING)
# User input: a wide text area for the description and a narrow column
# holding the "Analyze" button.
with st.container():
    text_col, button_col = st.columns((6, 1))
    with text_col:
        description_input = st.text_area(
            label="Enter the disease description 👇",
            placeholder='A disease that causes memory loss and other cognitive impairments.',
        )
    with button_col:
        # Three empty text rows push the button down so it lines up
        # vertically with the text area.
        for _ in range(3):
            st.text('')
        # True only on the rerun triggered by the click.
        show_analyze_status = st.button("Analyze 🔎")
# analyze
@st.cache_resource
def _load_encoder():
    """Return the "allenai-specter" SentenceTransformer, loaded once per process.

    Streamlit re-executes this whole script on every interaction; without
    ``st.cache_resource`` the model would be re-instantiated on every click
    of "Analyze".
    """
    return SentenceTransformer("allenai-specter")


with st.container():
    if show_analyze_status and not description_input.strip():
        # Embedding a blank description is meaningless; skip the whole
        # (DB + LLM) pipeline and tell the user why nothing happened.
        st.warning("Please enter a disease description before analyzing.")
    elif show_analyze_status:
        with st.status("Analyzing...") as status:
            # 1. Embed the textual description that the user entered using the model
            # 2. Get 5 diseases with the highest cosine similarity from the DB
            status.write("Analyzing the description that you wrote...")
            encoder = _load_encoder()
            diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
                description_input, encoder
            )
            # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
            status.write("Getting the similarities among the diseases to filter out less promising ones...")
            diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
            # NOTE(review): the return value is currently discarded because
            # step 4 (filtering) is not implemented yet.
            get_similarities_among_diseases_uris(diseases_uris)
            # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
            # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
            status.write("Augmenting the set of diseases by finding others with related embeddings...")
            augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
            # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
            status.write("Getting the clinical trials related to the diseases found...")
            clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
                augmented_set_of_diseases, encoder
            )
            status.write("Getting the details of the clinical trials...")
            json_of_clinical_trials = get_clinical_records_by_ids(
                [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
            )
            status.json(json_of_clinical_trials)
            # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
            status.write("Getting a summary of the clinical trials...")
            response = process_dictionaty_with_llm_to_generate_response(json_of_clinical_trials)
            # NOTE(review): the summary only goes to the server log; it is
            # not yet shown in the UI.
            print(f'Response from LLM: {response}')
            # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
            # TODO: step 8 is not implemented; only the status message is shown.
            status.write("Getting summary statistics of the clinical trials...")
            # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
            status.update(label="Done!", state="complete")
            time.sleep(1)
            show_graph = True
# graph
with st.container():
    if show_graph:
        # TODO actual graph: for now a placeholder chain A -> B -> ... -> J.
        placeholder_ids = list("ABCDEFGHIJ")
        placeholder_nodes = [
            Node(id=node_id, label=f"Node {node_id}", size=10)
            for node_id in placeholder_ids
        ]
        # Link every node to its successor in the list.
        placeholder_edges = [
            Edge(source=src, target=dst)
            for src, dst in zip(placeholder_ids, placeholder_ids[1:])
        ]
        graph_of_diseases = agraph(
            nodes=placeholder_nodes,
            edges=placeholder_edges,
            config=Config(height=500, width=500),
        )
        # Deliberate pause before revealing the overview section.
        time.sleep(2)
        show_overview = True
# overview
with st.container():
    if show_overview:
        st.write("## Disease Overview")
        # TODO: replace this placeholder with a real overview of the diseases.
        disease_overview = ":red[lorem ipsum]"
        st.write(disease_overview)
        # Deliberate pause before revealing the details section.
        time.sleep(2)
        show_details = True
# details
def _lookup(record, *keys, default="N/A"):
    """Follow *keys* down the nested dict *record* and return the value found.

    Returns *default* as soon as a level is missing or is not a dict, so a
    trial record lacking an optional module (e.g. no ``designInfo``) renders
    as "N/A" instead of aborting the page with a ``KeyError``.
    """
    node = record
    for key in keys:
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node


with st.container():
    if show_details:
        st.write("## Clinical Trials Details")
        # TODO replace mock data: the same mock record is currently shown 5 times.
        with open("mock_trial.json", encoding="utf-8") as f:
            mock_trial = json.load(f)
        trials = [mock_trial] * 5
        for trial in trials:
            nct_id = _lookup(trial, "protocolSection", "identificationModule", "nctId")
            with st.expander(f"{nct_id}"):
                official_title = _lookup(
                    trial, "protocolSection", "identificationModule", "officialTitle"
                )
                st.write(f"##### {official_title}")
                brief_summary = _lookup(
                    trial, "protocolSection", "descriptionModule", "briefSummary"
                )
                st.write(brief_summary)
                status_module = {
                    "Status": _lookup(
                        trial, "protocolSection", "statusModule", "overallStatus"
                    ),
                    "Status Date": _lookup(
                        trial, "protocolSection", "statusModule", "statusVerifiedDate"
                    ),
                }
                st.write("###### Status")
                st.table(status_module)
                # "phases" is an array; join it into one string so st.table
                # receives only scalar values (a list value breaks the layout).
                phases = _lookup(
                    trial, "protocolSection", "designModule", "phases", default=[]
                )
                if isinstance(phases, (list, tuple)):
                    phases = ", ".join(str(phase) for phase in phases) or "N/A"
                design_module = {
                    "Study Type": _lookup(
                        trial, "protocolSection", "designModule", "studyType"
                    ),
                    "Phases": phases,
                    "Allocation": _lookup(
                        trial, "protocolSection", "designModule", "designInfo", "allocation"
                    ),
                    "Participants": _lookup(
                        trial, "protocolSection", "designModule", "enrollmentInfo", "count"
                    ),
                }
                st.write("###### Design")
                st.table(design_module)
                # TODO more modules?