File size: 14,919 Bytes
a6bd112
93e1b64
27d40b9
a6bd112
 
 
 
 
 
 
 
 
 
27d40b9
a6bd112
 
27d40b9
a6bd112
 
27d40b9
a6bd112
27d40b9
a6bd112
 
e7d7b51
a6bd112
27d40b9
93e1b64
9b5c4aa
 
 
 
 
4bb7c94
27d40b9
9b5c4aa
27d40b9
 
 
 
 
93e1b64
 
 
27d40b9
ee05396
 
a6bd112
 
 
ee05396
a6bd112
9b5c4aa
 
 
a6bd112
 
 
 
9b5c4aa
a6bd112
 
 
9b5c4aa
 
 
 
 
 
 
 
 
1e2e3b8
9b5c4aa
a6bd112
 
 
 
 
 
 
9b5c4aa
4021316
47c6369
9b5c4aa
a6bd112
 
 
 
 
 
47c6369
a6bd112
 
 
47c6369
a6bd112
 
 
e7d7b51
a6bd112
 
 
 
 
 
e7d7b51
47c6369
9b5c4aa
 
a6bd112
 
 
9b5c4aa
a6bd112
 
 
 
 
 
 
 
 
 
 
 
 
e7d7b51
 
9b5c4aa
1e2e3b8
a6bd112
 
 
 
 
 
 
9b5c4aa
47c6369
 
1e2e3b8
9b5c4aa
 
 
a6bd112
ee05396
47c6369
1e2e3b8
4bb7c94
 
 
 
 
 
a6bd112
 
 
 
1b1c01c
a6bd112
1b1c01c
 
4bb7c94
 
a6bd112
4bb7c94
a6bd112
 
1b1c01c
a6bd112
 
 
 
 
 
9b5c4aa
 
4021316
9b5c4aa
e7d7b51
9b5c4aa
 
 
 
 
3cfdd19
 
 
47c6369
3cfdd19
47c6369
3cfdd19
 
a6bd112
 
9b5c4aa
a6bd112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b5c4aa
 
 
 
 
4bb7c94
 
 
 
 
a6bd112
4bb7c94
 
 
 
 
 
 
 
a6bd112
4bb7c94
 
 
 
 
 
 
 
a6bd112
4bb7c94
 
9b5c4aa
 
 
 
 
 
 
a6bd112
 
 
 
796a53b
 
 
 
 
ec6a815
76919d3
1b1c01c
 
 
 
 
 
 
 
a6bd112
76919d3
1b1c01c
 
76919d3
1b1c01c
 
76919d3
a6bd112
 
 
 
76919d3
 
a6bd112
 
 
1b1c01c
 
a6bd112
1b1c01c
 
a6bd112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b1c01c
a6bd112
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
import json
import os
import time

import matplotlib
import numpy as np
import pandas as pd
import streamlit as st
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine, text
from streamlit_agraph import Config, Edge, Node, agraph

from llm_res import get_short_summary_out_of_json_files, tagging_insights_from_json
from utils import (
    augment_the_set_of_diseaces,
    filter_out_less_promising_diseases,
    get_all_diseases_name,
    get_clinical_records_by_ids,
    get_clinical_trials_related_to_diseases,
    get_diseases_related_to_a_textual_description,
    get_most_similar_diseases_from_uri,
    get_similarities_among_diseases_uris,
    get_similarities_df,
    get_uri_from_name,
    render_trial_details,
    get_labels_of_diseases_from_uris,
)

# Flags that progressively reveal the sections of the page. Every script run
# starts with all of them off; the "Analyze" button turns on the analysis, and
# each section switches on the next one when it is done.
show_graph = False
show_analyze_status = False
show_overview = False
show_details = False
show_metrics = False

# IRIS connection. Every setting can be overridden through an environment
# variable (as the hostname already could), so credentials do not have to live
# in the source. The defaults are the values this script has always used.
# NOTE(review): the values are interpolated verbatim into the URL; a password
# containing URL-reserved characters such as "@" or "/" must be percent-encoded.
username = os.getenv("IRIS_USERNAME", "demo")
password = os.getenv("IRIS_PASSWORD", "demo")
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = os.getenv("IRIS_PORT", "1972")
namespace = os.getenv("IRIS_NAMESPACE", "USER")
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
engine = create_engine(CONNECTION_STRING)


# Page header: banner image, title and tagline.
st.image("img_klinic.jpeg", caption="(AI-generated image)", use_column_width=True)
st.title("Klìnic", help="AI-powered clinical trial search engine")
st.subheader(
    "Find clinical trials in a scoped domain of biomedical research, guiding your research with AI-powered insights."
)

# User input: a wide column for the description next to a narrow one that
# holds the "Analyze" button.
with st.container():
    text_column, button_column = st.columns((6, 1))

    with text_column:
        description_input = st.text_area(
            label="Enter a disease description 👇",
            placeholder="A disorder manifested in memory loss and other cognitive impairments among elderly patients (60+ years old), especially women.",
        )

    with button_column:
        # Three empty text elements push the button down so that it sits
        # roughly at the vertical centre of the text area.
        for _ in range(3):
            st.text("")
        show_analyze_status = st.button("Analyze 🔎")


# analyze
# Runs the whole pipeline when the "Analyze" button is pressed and reports each
# intermediate result inside a status box. On completion it sets `show_graph`
# and leaves the data that the later sections read in module-level names
# (`df_similarities_augmented_set`, `disease_overview`, the three metrics and
# `trials`).
with st.container():
    if show_analyze_status and not (description_input or "").strip():
        # The button was pressed with a blank description: there is no text to
        # embed, so ask for input instead of running the pipeline.
        st.warning("Please enter a disease description before analyzing.")
    elif show_analyze_status:
        with st.status("Analyzing...") as status:
            # 1. Embed the textual description that the user entered using the model
            # 2. Retrieve the diseases that best match that description
            status.write("Analyzing the description that you wrote...")
            encoder = SentenceTransformer("allenai-specter")
            diseases_related_to_the_user_text = (
                get_diseases_related_to_a_textual_description(
                    description_input, encoder
                )
            )
            status.info(
                f"Selected {len(diseases_related_to_the_user_text)} diseases related to the description you entered."
            )
            status.json(diseases_related_to_the_user_text, expanded=False)
            status.divider()
            # 3. Get the pairwise similarities of the embeddings of those diseases
            status.write(
                "Getting the similarities among the diseases to filter out less promising ones..."
            )
            diseases_uris = [
                disease["uri"] for disease in diseases_related_to_the_user_text
            ]
            similarities = get_similarities_among_diseases_uris(diseases_uris)
            status.info(
                "Obtained similarity information among the diseases by measuring the cosine similarity of their embeddings."
            )
            status.json(similarities, expanded=False)
            # 4. Filter out the diseases that are not similar enough
            filtered_diseases_uris, df_similarities = (
                filter_out_less_promising_diseases(similarities)
            )
            # Apply a colormap to the table
            status.table(
                df_similarities.style.background_gradient(cmap="viridis", axis=None)
            )
            status.info(
                f"Filtered out less promising diseases, keeping {len(filtered_diseases_uris)} diseases."
            )
            status.json(filtered_diseases_uris, expanded=False)
            status.divider()
            # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set
            # NOTE(review): the augmentation starts from the unfiltered
            # `diseases_uris`; `filtered_diseases_uris` is only displayed above.
            # Confirm whether the filtered list was meant to be used here.
            status.write(
                "Augmenting the set of diseases by finding others with related embeddings..."
            )
            augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
            similarities_of_augmented_set_of_diseases = (
                get_similarities_among_diseases_uris(augmented_set_of_diseases)
            )
            df_similarities_augmented_set = get_similarities_df(
                similarities_of_augmented_set_of_diseases
            )
            status.table(
                df_similarities_augmented_set.style.background_gradient(cmap="viridis", axis=None)
            )
            status.json(similarities_of_augmented_set_of_diseases, expanded=True)
            status.info(
                f"Augmented set of diseases: {len(augmented_set_of_diseases)} diseases."
            )
            status.json(augmented_set_of_diseases, expanded=False)
            status.divider()
            # 6. Get the clinical trials related to the augmented set of diseases
            status.write("Getting the clinical trials related to the diseases found...")
            clinical_trials_related_to_the_diseases = (
                get_clinical_trials_related_to_diseases(
                    augmented_set_of_diseases, encoder
                )
            )
            status.info(
                f"Selected {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases."
            )
            status.json(clinical_trials_related_to_the_diseases, expanded=False)
            status.divider()
            status.write("Getting the details of the clinical trials...")
            json_of_clinical_trials = get_clinical_records_by_ids(
                [trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
            )
            status.success("Details of the clinical trials obtained.")
            status.json(json_of_clinical_trials, expanded=False)
            status.divider()
            # 7. Use an LLM to get a summary of the clinical trials, in plain text format.
            # Best effort: on failure `disease_overview` stays unbound and the
            # rest of the analysis carries on.
            try:
                status.write("Getting a summary of the clinical trials...")
                response = get_short_summary_out_of_json_files(json_of_clinical_trials)
                status.success("Summary of the clinical trials obtained.")
                disease_overview = response
            except Exception as e:
                print(f"Error while getting a summary of the clinical trials: {e}")
                status.warning(
                    "Error while getting a summary of the clinical trials. This information will not be shown."
                )
            # 8. Use an LLM to extract summary statistics from the clinical trials.
            # Best effort as well: on failure some or all of the three metric
            # names stay unbound.
            try:
                status.write("Getting summary statistics of the clinical trials...")
                response = tagging_insights_from_json(json_of_clinical_trials)
                average_minimum_age = response["avg_min_age"]
                average_maximum_age = response["avg_max_age"]
                most_common_gender = response["most_common_gender"]

                print(f"Response from LLM tagging: {response}")
                status.success("Summary statistics of the clinical trials obtained.")
            except Exception as e:
                print(
                    f"Error while extracting numerical data from the clinical trials: {e}"
                )
                status.warning(
                    "Error while extracting numerical data from the clinical trials. This information will not be shown."
                )
            # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
            status.update(label="Done!", state="complete")
            status.balloons()
            show_graph = True
            trials = json_of_clinical_trials


# graph
# Draws the augmented set of diseases as a graph: one node per disease and one
# edge per ordered pair of distinct diseases, coloured and labelled by their
# similarity. Whatever happens, the overview section is enabled afterwards.
with st.container():
    if show_graph:
        st.info(
            """This is a graph of the relevant diseases that we found, based on the description that you entered. The diseases are connected by edges if they are similar to each other. The color of the edges represents the similarity of the diseases.
                
We use the embeddings of the diseases to determine the similarity between them. The embeddings are generated using a Representation Learning algorithm that learns the topological relations among the nodes in the graph, depending on how they are connected. We utilize the [PyKeen](https://github.com/pykeen/pykeen) implementation of TransH to train an embedding model.

[TransH](https://ojs.aaai.org/index.php/AAAI/article/view/8870) utilizes hyperplanes to model relations between entities. It is a multi-relational model that can handle many-to-many relations between entities. The model is trained on the triples of the graph, where the triples are the subject, relation, and object of the graph. The model learns the embeddings of the entities and the relations, such that the embeddings of the subject and object are close to each other when the relation is true.

Specifically, it optimizes the following cost function:
$\\text{minimize} \\sum_{(h, r, t) \\in S} \\max(0, \\gamma + f(h, r, t) - f(h, r, t')) + \\sum_{(h, r, t) \\in S'} f(h, r, t)$
"""
        )
        try:
            edges_to_show = []
            labels_of_diseases = get_labels_of_diseases_from_uris(
                df_similarities_augmented_set.index
            )
            uris_and_labels_of_diseases = dict(
                zip(df_similarities_augmented_set.index, labels_of_diseases)
            )
            # `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and
            # removed in 3.9; the `matplotlib.colormaps` registry is the
            # supported lookup. Fall back to the old call only on releases
            # that predate the registry.
            try:
                color_mapper = matplotlib.colormaps["viridis"]
            except AttributeError:
                color_mapper = matplotlib.cm.get_cmap("viridis")
            for source in df_similarities_augmented_set.index:
                for target in df_similarities_augmented_set.columns:
                    if source != target:
                        weight = df_similarities_augmented_set.loc[source, target]
                        # Dynamic color based on the weight (rgba -> hex)
                        color = matplotlib.colors.to_hex(color_mapper(weight))
                        edges_to_show.append(
                            Edge(
                                source=source,
                                target=target,
                                color=color,
                                weight=weight**10,
                                type="CURVE_SMOOTH",
                                label=f"{weight:.2f}",
                            )
                        )
            graph_of_diseases = agraph(
                nodes=[
                    Node(
                        id=disease,
                        # NOTE(review): nodes are labelled with the raw index
                        # value; `uris_and_labels_of_diseases` is built above
                        # but currently unused.
                        label=disease,
                        size=25,
                        shape="circular",
                    )
                    for disease in df_similarities_augmented_set.index
                ],
                edges=edges_to_show,
                config=Config(height=500, width=500),
            )
            time.sleep(2)
        except Exception as e:
            print(f"Error while showing the graph of the diseases: {e}")
            st.error("Error while showing the graph of the diseases.")
        finally:
            # The following sections do not depend on the graph.
            show_overview = True


# overview
# Shows the LLM-written summary of the trials, then enables the metrics
# section regardless of the outcome.
with st.container():
    if show_overview:
        try:
            # `disease_overview` is only bound when the summary step of the
            # analysis succeeded. Resolve it before rendering anything, so a
            # failed summary skips the section instead of leaving a heading
            # with nothing under it (the NameError is logged below).
            overview_text = disease_overview
            st.write("## Overview of Related Clinical Trials")
            st.write(overview_text)
            time.sleep(1)
        except Exception as e:
            print(f"Error while showing the overview of the clinical trials: {e}")
        finally:
            show_metrics = True


# metrics
# Shows the three summary statistics side by side, then enables the details
# section regardless of the outcome.
with st.container():
    if show_metrics:
        try:
            # The three values are only bound when the statistics step of the
            # analysis succeeded. Resolve all of them before rendering
            # anything, so a failed step skips the section instead of leaving
            # a heading with empty or partial columns (the NameError is
            # logged below).
            metrics_to_show = (
                ("Average Minimum Age", average_minimum_age),
                ("Average Maximum Age", average_maximum_age),
                ("Most Common Gender", most_common_gender),
            )
            st.write("## Metrics of the Clinical Trials")
            for column, (metric_label, metric_value) in zip(
                st.columns(3), metrics_to_show
            ):
                with column:
                    st.metric(metric_label, metric_value)
            time.sleep(2)
        except Exception as e:
            print(f"Error while showing the metrics: {e}")
        finally:
            show_details = True


# details
# Renders one tab per clinical trial, titled with the trial's NCT id.
with st.container():
    if show_details:
        st.write("## Clinical Trials Details")

        if len(trials) == 0:
            # st.tabs raises when it is given no labels, so an analysis that
            # found no trials gets a message instead of a crash.
            st.info("No clinical trials were found for this description.")
        else:
            tab_titles = [
                str(trial["protocolSection"]["identificationModule"]["nctId"])
                for trial in trials
            ]

            tabs = st.tabs(tab_titles)

            for tab, trial in zip(tabs, trials):
                with tab:
                    render_trial_details(trial)

# Optional explorer: choose any disease and plot the diseases most similar to
# it. Switched off by default; set the flag to True to enable it.
show_graph_of_all_diseases = False
if show_graph_of_all_diseases:
    # Load the disease names only if the session does not hold them yet.
    if "disease_names" not in st.session_state:
        st.session_state.disease_names = get_all_diseases_name(engine)

    chosen_disease_name = st.selectbox(
        "Choose a disease",
        st.session_state.disease_names,
    )
    st.write("You selected:", chosen_disease_name)
    chosen_disease_uri = get_uri_from_name(engine, chosen_disease_name)

    similar_diseases = get_most_similar_diseases_from_uri(
        engine, chosen_disease_uri, threshold=0.6
    )
    print(similar_diseases)

    # The chosen disease is the hub; every similar disease becomes a node
    # linked to it.
    nodes = [
        Node(
            id=chosen_disease_uri, label=chosen_disease_name, size=25, shape="circular"
        )
    ]
    edges = []
    for uri, name, weight in similar_diseases:
        similarity = float(weight)
        is_strong_link = similarity > 0.7

        nodes.append(Node(id=uri, label=name, size=25, shape="circular"))
        print(is_strong_link)
        edges.append(
            Edge(
                source=chosen_disease_uri,
                target=uri,
                # Red for similarities above 0.7, blue otherwise.
                color="red" if is_strong_link else "blue",
                weight=similarity**10,
                type="CURVE_SMOOTH",
            )
        )

    config = Config(
        width=750,
        height=950,
        directed=False,
        physics=True,
        hierarchical=False,
        collapsible=False,
    )

    return_value = agraph(nodes=nodes, edges=edges, config=config)