"""Streamlit demo app for ExtEnD (Extractive Entity Disambiguation).

Launch with ``streamlit run <this file>``. Optional overrides for the model
checkpoint, mentions inventory and device can be passed after ``--``, e.g.
``streamlit run <this file> -- --cuda-device 0``. Without arguments, the
defaults below are used.
"""

import argparse
import html
import time

from extend import spacy_component  # noqa: F401  # this is needed to register the spacy component

import spacy
import streamlit as st
from annotated_text import annotation
from classy.scripts.model.demo import tabbed_navigation  # noqa: F401  # not used in this file
from classy.utils.streamlit import get_md_200_random_color_generator

# Defaults previously hard-coded in the ``__main__`` guard.
DEFAULT_MODEL_CHECKPOINT_PATH = (
    "experiments/extend-longformer-large/2021-10-22/09-11-39/checkpoints/best.ckpt"
)
DEFAULT_INVENTORY_PATH = "data/inventories/aida.tsv"
DEFAULT_CUDA_DEVICE = -1

# NOTE(review): the HTML/CSS markup inside the following literals (and in the
# timing line rendered by ``demo``) appears to have been stripped from this
# copy of the file; only the text content survived. That text is kept as-is
# here. Restore the original markup from version control.
_CSS_RULES = " "
_HEADER_HTML = "\n\nExtEnD: Extractive Entity Disambiguation\n\n"
_BADGES_HTML = "\nPython spaCy\n"

# Markdown for the "How it works" section. These are defined at column 0 so that
# leading indentation never turns lines into markdown code blocks.
_HIW_INTRO_MD = """
## How it works

ExtEnD frames Entity Disambiguation as a text extraction problem:
"""

_HIW_DETAILS_MD = """
Given the sentence *After a long fight Superman saved Metropolis*, where *Superman* is the mention
to disambiguate, ExtEnD first concatenates the descriptions of all the possible candidates of
*Superman* in the inventory and then selects the span whose description best suits the mention
in its context.

To use ExtEnD for full end2end entity linking, as we do in *Demo*, we just need to leverage a
mention identifier. Here [we use spaCy](https://github.com/SapienzaNLP/extend#spacy) (more
specifically, its NER) and run ExtEnD on each named entity spaCy identifies (if the corresponding
mention is contained in the inventory).

##### Links:
* [Full Paper](https://www.researchgate.net/publication/359392427_ExtEnD_Extractive_Entity_Disambiguation)
* [GitHub](https://github.com/SapienzaNLP/extend)
"""


def _render_annotated_tokens(doc, color_generator):
    """Return one HTML fragment per plain token or disambiguated entity of ``doc``.

    Fragments are returned in document order. An entity in ``doc.ents`` whose
    ``_.disambiguated_entity`` extension is not ``None`` becomes a single
    ``annotation(text, entity, color)`` fragment, with one ``color_generator()``
    call per such entity. Every other token is emitted HTML-escaped.

    Args:
        doc: a spaCy ``Doc`` produced by the pipeline returned from ``load_resources``.
        color_generator: zero-argument callable returning a background colour.

    Returns:
        list of str: the HTML fragments.
    """
    # Key entities by their start *token* index; ``ent.end`` is the exclusive
    # end token index. Walking by index avoids realigning character offsets
    # with a shared token iterator.
    entity_spans = {
        ent.start: (ent.end, ent.text, ent._.disambiguated_entity)
        for ent in doc.ents
        if ent._.disambiguated_entity is not None
    }

    components = []
    i = 0
    while i < len(doc):
        span = entity_spans.get(i)
        if span is None:
            components.append(html.escape(doc[i].text))
            i += 1
            continue
        end, text, entity = span
        # NOTE(review): entity text/label are passed to annotation() unescaped
        # (as before). Confirm that annotated_text escapes them; otherwise
        # escape here.
        components.append(f"{annotation(text, entity, color_generator())}")
        # Jump past every token covered by the entity.
        i = end
    return components


def main(
    model_checkpoint_path: str,
    inventory_path: str,
    cuda_device: int,
):
    """Render the ExtEnD demo page.

    Args:
        model_checkpoint_path: checkpoint passed as ``checkpoint_path`` to the
            ``extend`` spaCy pipe.
        inventory_path: passed as ``mentions_inventory_path`` to the ``extend`` pipe.
        cuda_device: passed as ``device`` to the ``extend`` pipe (this file's
            default is ``-1``).
    """
    # setup examples
    examples = [
        "Japan began the defence of their title with a lucky 2-1 win against Syria in a Group C championship match on Friday.",
        "The project was coded in Java.",
        "Rome is in Italy",
    ]

    # define load_resources
    # NOTE(review): newer Streamlit releases deprecate st.cache in favour of
    # st.cache_resource. It is left unchanged because the pinned Streamlit
    # version is not visible here.
    @st.cache(allow_output_mutation=True)
    def load_resources(inventory_path):
        """Build en_core_web_sm with the ``extend`` pipe placed after ``ner``."""
        nlp = spacy.load("en_core_web_sm")
        extend_config = dict(
            checkpoint_path=model_checkpoint_path,
            mentions_inventory_path=inventory_path,
            device=cuda_device,
            tokens_per_batch=10_000,
        )
        nlp.add_pipe("extend", after="ner", config=extend_config)

        # mock call to load resources
        nlp(examples[0])

        return nlp

    # preload default resources
    load_resources(inventory_path)

    # css rules
    st.write(_CSS_RULES, unsafe_allow_html=True)

    # setup header
    st.markdown(_HEADER_HTML, unsafe_allow_html=True)
    st.write(_BADGES_HTML, unsafe_allow_html=True)

    def hiw():
        """Render the "How it works" section."""
        st.markdown(_HIW_INTRO_MD)
        st.image(
            "data/repo-assets/extend_formulation.png", caption="ExtEnD Formulation"
        )
        st.markdown(_HIW_DETAILS_MD)

    def demo():
        """Render the demo: choose or edit a text, disambiguate it, show the annotated result."""
        st.markdown("## Demo")

        # read input: the selected example is the text area's default value
        selected_example = st.selectbox(
            "Examples",
            options=examples,
            index=0,
        )
        input_text = st.text_area("Input text to entity-disambiguate", selected_example)

        # button
        should_disambiguate = st.button("Disambiguate", key="classify")

        # load model and color generator
        nlp = load_resources(inventory_path)
        color_generator = get_md_200_random_color_generator()

        if not should_disambiguate:
            return

        # tag sentence
        time_start = time.perf_counter()
        doc = nlp(input_text)
        elapsed = time.perf_counter() - time_start

        annotated_html_components = _render_annotated_tokens(doc, color_generator)

        st.markdown(
            "\n".join(
                [
                    "\n",
                    *annotated_html_components,
                    f"\n\n\n\nTime: {elapsed:.2f}s\n\n\n",
                ]
            ),
            unsafe_allow_html=True,
        )

    demo()
    hiw()


def _parse_args(argv=None):
    """Parse optional CLI overrides for ``main``.

    Unknown arguments are ignored, so launching the script without arguments
    keeps working exactly as before.
    """
    parser = argparse.ArgumentParser(description="ExtEnD Streamlit demo")
    parser.add_argument("--model-checkpoint-path", default=DEFAULT_MODEL_CHECKPOINT_PATH)
    parser.add_argument("--inventory-path", default=DEFAULT_INVENTORY_PATH)
    parser.add_argument("--cuda-device", type=int, default=DEFAULT_CUDA_DEVICE)
    args, _unknown = parser.parse_known_args(argv)
    return args


if __name__ == "__main__":
    _args = _parse_args()
    main(
        _args.model_checkpoint_path,
        _args.inventory_path,
        cuda_device=_args.cuda_device,
    )