import pandas as pd
import plotly.graph_objects as go
from typing import List, Dict, Any, Tuple, Union
# URL templates for the online version of an article; `{article_id}` is
# replaced by the article identifier when the link is built.
_PUBMED_LINK = "https://pubmed.ncbi.nlm.nih.gov/{article_id}/"
_PMC_LINK = "https://www.ncbi.nlm.nih.gov/pmc/articles/{article_id}/"

# Markdown skeleton: a level-1 heading that links to the article, a fixed
# "Filtered sections" heading, then the already-rendered sections.
_MARKDOWN_TEMPLATE = (
    "# [{article_title}]({article_link})\n"
    "# Filtered sections :\n"
    "{sections_md}"
)
# entities highlighted text
def get_highlighted_text(
    entities: List[Dict[str, Any]],
    original_text: str,
) -> List[Tuple[str, Union[str, None]]]:
    """Turn model entities into ``(text, label)`` pairs for `gradio.HighlightedText`.

    Each entity must provide ``"start"``, ``"end"`` and ``"entity_group"``.
    The text of a pair is the slice ``original_text[start:end]``. The label is
    ``None`` for the ``"O"`` group, ``"primary"`` for ``"PrimaryOutcome"`` and
    ``"secondary"`` for ``"SecondaryOutcome"``.

    Raises:
        KeyError: if an entity lacks one of the expected keys or carries an
            entity group other than the three listed above.
    """
    label_by_group = {"PrimaryOutcome": "primary", "SecondaryOutcome": "secondary"}
    return [
        (
            original_text[span["start"]:span["end"]],
            # "O" marks text outside any outcome, shown without a highlight.
            None if span["entity_group"] == "O" else label_by_group[span["entity_group"]],
        )
        for span in entities
    ]
# article filtered sections markdown output
def get_article_markdown(
    article_id: str,
    article_sections: dict[str, list[str]],
    filtered_sections: dict[str, list[str]],
) -> str:
    """Render an article's filtered sections as a markdown document.

    The document starts with the article title (first item of
    ``article_sections["Title"]``) as a link to the online article, followed
    by one ``##`` heading per filtered section with its content items joined
    by single spaces.

    Identifiers starting with ``"PMC"`` are linked through the PMC URL
    template; every other identifier goes through the PubMed one.
    """
    # Choose the URL template from the identifier prefix.
    link_template = _PUBMED_LINK
    if article_id.startswith("PMC"):
        link_template = _PMC_LINK
    link = link_template.format(article_id=article_id)

    heading_text = article_sections["Title"][0]

    # One "## heading" line plus one line of space-joined content per section.
    rendered_sections = "".join(
        f"## {heading}\n{' '.join(paragraphs)}\n"
        for heading, paragraphs in filtered_sections.items()
    )

    return _MARKDOWN_TEMPLATE.format(
        article_title=heading_text,
        article_link=link,
        sections_md=rendered_sections,
    )
# registry dataframe display
def _highlight_df_rows(row):
    """Return one background-colour CSS declaration per cell of ``row``.

    The colour depends on ``row['type']``: light coral for ``'primary'``,
    light green for ``'secondary'`` and light grey for anything else.
    """
    outcome_type = row['type']
    if outcome_type == 'primary':
        shade = 'lightcoral'
    elif outcome_type == 'secondary':
        shade = 'lightgreen'
    else:
        shade = 'lightgrey'
    # The same declaration is repeated for every cell of the row.
    return [f'background-color: {shade}'] * len(row)
def get_registry_dataframe(
    registry_outcomes: list[dict[str, str]],
) -> "pd.io.formats.style.Styler":
    """Build a row-coloured table of the registry outcomes.

    Args:
        registry_outcomes: one dict per outcome; each dict becomes a table
            row. The row colouring callback reads the ``"type"`` entry.

    Returns:
        A pandas ``Styler`` (not a string) wrapping the DataFrame built from
        ``registry_outcomes``, with ``_highlight_df_rows`` registered as a
        row-wise style function.
    """
    registry_df = pd.DataFrame(registry_outcomes)
    # axis=1 makes the Styler hand each row to the callback.
    return registry_df.style.apply(_highlight_df_rows, axis=1)
# helper functions for the Sankey diagram
def _sent_line_formatting(sentence: str, max_words: int = 10) -> str:
    """Wrap a sentence for display in a Sankey diagram.

    The sentence is split on whitespace and re-joined so that each display
    line holds at most ``max_words`` words.

    Args:
        sentence: text to wrap; runs of whitespace are collapsed.
        max_words: maximum number of words per display line.

    Returns:
        The wrapped text, with lines separated by ``<br>``; an empty string
        when ``sentence`` contains no words.

    Raises:
        ValueError: if ``max_words`` is smaller than 1.
    """
    if max_words < 1:
        # A negative step would yield no batches and silently drop the text.
        raise ValueError("max_words must be a positive integer")
    words = sentence.split()
    lines = [
        " ".join(words[start:start + max_words])
        for start in range(0, len(words), max_words)
    ]
    # Plotly renders "<br>" as a line break in labels; a raw newline
    # character is not displayed as one.
    return "<br>".join(lines)
def _find_entity_score(entity_text, raw_entities):
    """Return the ``"score"`` of the first raw entity whose ``"word"`` equals ``entity_text``.

    Entities are scanned in order; ``None`` is returned when none matches.
    """
    matching_scores = (
        candidate["score"]
        for candidate in raw_entities
        if entity_text == candidate["word"]
    )
    return next(matching_scores, None)
def get_sankey_diagram(
    registry_outcomes: list[tuple[str, str]],
    article_outcomes: list[tuple[str, str]],
    connections: set[tuple[int, int, float]],
    raw_entities: list[Dict[str, Any]],
    cosine_threshold: float = 0.44,
) -> go.Figure:
    """Build a Sankey figure linking registry outcomes to article outcomes.

    Registry outcomes form the first group of nodes and article outcomes the
    second. Every connection becomes a link of value 1 between a registry
    node and an article node.

    Args:
        registry_outcomes: ``(type, sentence)`` pairs; ``type`` must be
            ``"primary"``, ``"secondary"`` or ``"other"``.
        article_outcomes: ``(type, sentence)`` pairs with the same types.
        connections: ``(registry_index, article_index, cosine)`` triples,
            the indices being positions in the two outcome lists.
        raw_entities: dicts with ``"word"`` and ``"score"`` keys, used to
            look up the confidence shown when hovering an article outcome.
        cosine_threshold: links with a cosine strictly above this value are
            highlighted; the others are drawn in light gray.

    Returns:
        The assembled ``go.Figure``.

    Raises:
        KeyError: if an outcome type has no entry in the colour map.
        IndexError: if a connection refers to a missing outcome.
    """
    color_map = {
        "primary": "red",
        "secondary": "green",
        "other": "grey",
    }

    # (formatted sentence, colour) for every node, registry nodes first.
    registry_nodes = [
        (_sent_line_formatting(sent), color_map[typ])
        for typ, sent in registry_outcomes
    ]
    article_nodes = [
        (_sent_line_formatting(sent), color_map[typ])
        for typ, sent in article_outcomes
    ]
    nodes = registry_nodes + article_nodes
    labels = [label for label, _ in nodes]
    colors = [color for _, color in nodes]

    # Materialise the set once so every per-link list shares one order.
    links = list(connections)

    # Node positions are derived from the outcome indices rather than from a
    # label search: two outcomes with the same text (for instance an article
    # outcome worded exactly like a registry one) must stay distinct nodes.
    # Indexing a range keeps list semantics: negative indices are accepted
    # and out-of-range ones raise IndexError.
    n_registry = len(registry_nodes)
    registry_positions = range(n_registry)
    article_positions = range(len(article_nodes))
    sources = [registry_positions[i] for i, _, _ in links]
    targets = [n_registry + article_positions[j] for _, j, _ in links]

    values = [1] * len(links)
    connection_colors = [
        "mediumaquamarine" if cosine > cosine_threshold else "lightgray"
        for _, _, cosine in links
    ]

    # Data shown when hovering a node (outcome); "<br>" is the line break
    # Plotly understands in hover text.
    node_customdata = [
        f"from: registry<br>type:{typ}" for typ, _ in registry_outcomes
    ]
    node_customdata += [
        f"from: article<br>type: {typ}<br>confidence: "
        + str(_find_entity_score(sent, raw_entities))
        for typ, sent in article_outcomes
    ]
    # NOTE(review): the trailing space in both templates may be what is left
    # of a stripped "<extra></extra>" tag -- confirm against the original.
    node_hovertemplate = "outcome: %{label}<br>%{customdata} "

    # Data shown when hovering a link (connection between two outcomes).
    link_customdata = [cosine for _, _, cosine in links]
    link_hovertemplate = "similarity: %{customdata} "

    sankey = go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=labels,
            color=colors,
            customdata=node_customdata,
            hovertemplate=node_hovertemplate,
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            customdata=link_customdata,
            color=connection_colors,
            hovertemplate=link_hovertemplate,
        ),
    )

    fig = go.Figure(data=[sankey])
    fig.update_layout(
        title_text="Registry outcomes (left) connections with article outcomes (right), similarity threshold = " + str(cosine_threshold),
        font_size=10,
        width=1200,
        xaxis=dict(rangeslider=dict(visible=True), type="linear"),
    )
    return fig