import argparse
import html
import time

from extend import spacy_component  # this is needed to register the spacy component

import spacy
import streamlit as st
from annotated_text import annotation
from classy.scripts.model.demo import tabbed_navigation
from classy.utils.streamlit import get_md_200_random_color_generator


def main(
    model_checkpoint_path: str,
    default_inventory_path: str,
    cuda_device: int,
):
    """Streamlit entry point for the ExtEnD entity-disambiguation demo.

    Renders a page header, a "How it works" section and an interactive demo
    that runs a spaCy pipeline (NER + the registered "extend" component) over
    user-provided text and displays the disambiguated entities inline.

    Args:
        model_checkpoint_path: path to the trained ExtEnD model checkpoint.
        default_inventory_path: default mention inventory (tsv: mention \\t
            desc1 \\t desc2 ...), used unless the user uploads a custom one.
        cuda_device: CUDA device index for inference; -1 runs on CPU.
    """

    # setup examples shown in the selectbox and used to warm up the model
    examples = [
        "Italy beat England and won Euro 2021.",
        "Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria in a Group C championship match on Friday.",
        "The project was coded in Java.",
    ]

    # css rules
    # NOTE(review): the original inline <style> rules were lost when the file's
    # formatting was mangled — restore them from the upstream repo if needed.
    st.write(
        """
        """,
        unsafe_allow_html=True,
    )

    # setup header
    # NOTE(review): the original HTML wrapper was mangled; only the visible
    # text ("ExtEnD: Extractive Entity Disambiguation") survived. This markup
    # is a minimal reconstruction — confirm against the upstream repo.
    st.markdown(
        "<div align='center'><h1>ExtEnD: Extractive Entity Disambiguation</h1></div>",
        unsafe_allow_html=True,
    )
    # NOTE(review): originally a row of badges ("Python", "spaCy"); markup lost.
    st.write(
        "<div align='center'>Python &middot; spaCy</div>",
        unsafe_allow_html=True,
    )

    # how it works
    def hiw():
        # Static explanation tab: task formulation image, description, links.
        st.markdown(
            """
            ## How it works

            ExtEnD frames Entity Disambiguation as a text extraction problem:
            """
        )
        st.image(
            "data/repo-assets/extend_formulation.png", caption="ExtEnD Formulation"
        )
        st.markdown(
            """
            Given the sentence *After a long fight Superman saved Metropolis*,
            where *Superman* is the mention to disambiguate, ExtEnD first
            concatenates the descriptions of all the possible candidates of
            *Superman* in the inventory and then selects the span whose
            description best suits the mention in its context.

            To convert this task to end2end entity linking, as we do in
            *Model demo*, we leverage spaCy (more specifically, its NER) and
            run ExtEnD on each named entity spaCy identifies (if the
            corresponding mention is contained in the inventory).

            Links:
            * [full paper](https://www.researchgate.net/publication/359392427_ExtEnD_Extractive_Entity_Disambiguation)
            * [GitHub](https://github.com/SapienzaNLP/extend)
            """
        )

    # demo
    def demo():
        st.markdown("## Demo")

        # Cache the (expensive) pipeline load per inventory path.
        # NOTE(review): st.cache is deprecated in recent Streamlit releases in
        # favour of st.cache_resource — kept as in the original for compatibility.
        @st.cache(allow_output_mutation=True)
        def load_resources(inventory_path):
            # load nlp
            nlp = spacy.load("en_core_web_sm")
            extend_config = dict(
                checkpoint_path=model_checkpoint_path,
                mentions_inventory_path=inventory_path,
                device=cuda_device,
                tokens_per_batch=10_000,
            )
            nlp.add_pipe("extend", after="ner", config=extend_config)

            # mock call to load resources
            nlp(examples[0])

            # return
            return nlp

        # read input
        placeholder = st.selectbox(
            "Examples",
            options=examples,
            index=0,
        )
        input_text = st.text_area("Input text to entity-disambiguate", placeholder)

        # custom inventory: persist the uploaded tsv next to the bundled ones
        uploaded_inventory_path = st.file_uploader(
            "[Optional] Upload custom inventory (tsv file, mention \\t desc1 \\t desc2 \\t)",
            accept_multiple_files=False,
            type=["tsv"],
        )
        if uploaded_inventory_path is not None:
            inventory_path = f"data/inventories/{uploaded_inventory_path.name}"
            with open(inventory_path, "wb") as f:
                f.write(uploaded_inventory_path.getbuffer())
        else:
            inventory_path = default_inventory_path

        # load model and color generator
        nlp = load_resources(inventory_path)
        color_generator = get_md_200_random_color_generator()

        if st.button("Disambiguate", key="classify"):

            # tag sentence
            time_start = time.perf_counter()
            doc = nlp(input_text)
            time_end = time.perf_counter()

            # extract entities: keyed by char offset of the entity start so we
            # can match them back to tokens below
            entities = {}
            for ent in doc.ents:
                if ent._.disambiguated_entity is not None:
                    entities[ent.start_char] = (
                        ent.start_char,
                        ent.end_char,
                        ent.text,
                        ent._.disambiguated_entity,
                    )

            # create annotated html components
            annotated_html_components = []
            # every entity start offset must coincide with some token start
            assert all(any(t.idx == _s for t in doc) for _s in entities)
            it = iter(list(doc))
            while True:
                try:
                    t = next(it)
                except StopIteration:
                    break
                if t.idx in entities:
                    _start, _end, _text, _entity = entities[t.idx]
                    # consume all tokens covered by this entity span
                    while t.idx + len(t) != _end:
                        t = next(it)
                    annotated_html_components.append(
                        str(annotation(*(_text, _entity, color_generator())))
                    )
                else:
                    annotated_html_components.append(str(html.escape(t.text)))

            # NOTE(review): the HTML wrapping the annotated output and the
            # timing line was mangled in the original file; this markup is a
            # minimal reconstruction preserving the visible "Time: Xs" text.
            st.markdown(
                "\n".join(
                    [
                        "<div>",
                        *annotated_html_components,
                        "</div>",
                        f"<div>Time: {(time_end - time_start):.2f}s</div>",
                    ]
                ),
                unsafe_allow_html=True,
            )

    demo()
    hiw()


if __name__ == "__main__":
    main(
        "experiments/extend-longformer-large/2021-10-22/09-11-39/checkpoints/best.ckpt",
        "data/inventories/aida.tsv",
        cuda_device=-1,
    )