import pandas as pd import plotly.graph_objects as go from typing import List, Dict, Any, Tuple, Union _PUBMED_LINK= "https://pubmed.ncbi.nlm.nih.gov/{article_id}/" _PMC_LINK = "https://www.ncbi.nlm.nih.gov/pmc/articles/{article_id}/" _MARKDOWN_TEMPLATE = """# [{article_title}]({article_link}) # Filtered sections : {sections_md}""" # entities highlighted text def get_highlighted_text(entities:List[Dict[str,Any]], original_text:str) -> List[Tuple[str,Union[str,None]]] : """Convert the output of the model to a list of tuples (entity, label) for `gradio.HighlightedText`output""" conversion = {"PrimaryOutcome":"primary","SecondaryOutcome":"secondary"} highlighted_text = [] for entity in entities: entity_original_text = original_text[entity["start"]:entity["end"]] if entity["entity_group"] == "O": entity_output = (entity_original_text, None) else: entity_output = (entity_original_text, conversion[entity["entity_group"]]) highlighted_text.append(entity_output) return highlighted_text # article filtered sections markdown output def get_article_markdown( article_id:str, article_sections:dict[str,list[str]], filtered_sections:dict[str,list[str]]) -> str: """Get the markdown of a list of sections""" # link to online article article_link = _PMC_LINK if article_id.startswith("PMC") else _PUBMED_LINK article_link = article_link.format(article_id=article_id) # get title, abstract, and filtered sections article_title = article_sections["Title"][0] sections_md = "" for title, content in filtered_sections.items(): sections_md += f"## {title}\n" sections_md += " ".join(content) + "\n" return _MARKDOWN_TEMPLATE.format( article_link=article_link, article_title=article_title, sections_md=sections_md ) # registry dataframe display def _highlight_df_rows(row): if row['type'] =='primary': return ['background-color: lightcoral'] * len(row) elif row['type'] == 'secondary': return ['background-color: lightgreen'] * len(row) else : return ['background-color: lightgrey'] * len(row) def get_registry_dataframe(registry_outcomes: list[dict[str,str]]) -> str: return pd.DataFrame(registry_outcomes).style.apply(_highlight_df_rows, axis=1) # fcts for sankey diagram def _sent_line_formatting(sentence:str, max_words:int=10) -> str: """format a sentence to be displayed in a sankey diagram so that each line has a maximum of `max_words` words""" words = sentence.split() batchs = [words[i:i+max_words] for i in range(0, len(words), max_words)] return "
".join([" ".join(batch) for batch in batchs]) def _find_entity_score(entity_text, raw_entities): for tc_output in raw_entities: if entity_text == tc_output["word"]: return tc_output["score"] def get_sankey_diagram( registry_outcomes: list[tuple[str,str]], article_outcomes: list[tuple[str,str]], connections: set[tuple[int,int,float]], raw_entities: list[Dict[str,Any]], cosine_threshold: float=0.44, ) -> go.Figure: color_map = { "primary": "red", "secondary": "green", "other": "grey", } # Create lists of formatted sentences and colors for the nodes list1 = [(_sent_line_formatting(sent), color_map[typ]) for typ, sent in registry_outcomes] list2 = [(_sent_line_formatting(sent), color_map[typ]) for typ, sent in article_outcomes] display_connections = [ (list1[i][0],list2[j][0],"mediumaquamarine") if cosine > cosine_threshold else (list1[i][0],list2[j][0],"lightgray") for i,j,cosine in connections ] # Create a list of labels and colors for the nodes labels = [x[0] for x in list1 + list2] colors = [x[1] for x in list1 + list2] # Create lists of sources and targets for the connections sources = [labels.index(x[0]) for x in display_connections] targets = [labels.index(x[1]) for x in display_connections] # Create a list of values and colors for the connections values = [1] * len(display_connections) connection_colors = [x[2] for x in display_connections] # data appearing on hover of each node (outcome) node_customdata = [f"from: registry
type:{t}" for t,_ in registry_outcomes] node_customdata += [f"from: article
type: {t}
confidence: " + str(_find_entity_score(s, raw_entities)) for t,s in article_outcomes] node_hovertemplate = "outcome: %{label}
%{customdata} " # data appearing on hover of each link (node connections) link_customdata = [cosine for _,_,cosine in connections] link_hovertemplate = "similarity: %{customdata} " # sankey diagram data filling sankey = go.Sankey( node=dict( pad=15, thickness=20, line=dict(color="black", width=0.5), label=labels, color=colors, customdata=node_customdata, hovertemplate=node_hovertemplate ), link=dict( source=sources, target=targets, value=values, customdata=link_customdata, color=connection_colors, hovertemplate=link_hovertemplate ) ) # conversion to figure fig = go.Figure(data=[sankey]) fig.update_layout( title_text="Registry outcomes (left) connections with article outcomes (right), similarity threshold = " + str(cosine_threshold), font_size=10, width=1200, xaxis=dict(rangeslider=dict(visible=True),type="linear") ) return fig