File size: 10,575 Bytes
67a2b13 ddb3f9b bb04844 ddb3f9b 67a2b13 48e06d4 67a2b13 2b81d2f 67a2b13 bb04844 67a2b13 22e4917 67a2b13 2c67aa0 67a2b13 2c67aa0 67a2b13 2c67aa0 bb04844 2c67aa0 22e4917 2c67aa0 d1f912d 2c67aa0 67a2b13 2c67aa0 d1f912d bb04844 2c67aa0 bb04844 2c67aa0 2b81d2f 2c67aa0 67a2b13 ddb3f9b 1484edd ddb3f9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import argparse
import html
import time
from extend import spacy_component # this is needed to register the spacy component
import spacy
import streamlit as st
from annotated_text import annotation
from classy.utils.streamlit import get_md_200_random_color_generator
def _wikipedia_url(entity):
    """Return the English Wikipedia URL for a disambiguated entity title.

    Spaces become underscores and the first character is upper-cased (MediaWiki
    title convention). The remaining characters keep their case, since Wikipedia
    titles are case-sensitive after the first letter. The title is
    percent-encoded so characters such as ``"``, ``&``, ``?`` or ``#`` cannot
    break the URL or the ``href`` attribute it is placed in.
    """
    from urllib.parse import quote

    title = entity.replace(" ", "_")
    title = title[:1].upper() + title[1:]
    return "https://en.wikipedia.org/wiki/" + quote(title, safe="/()")


def _annotated_html_components(doc, color_generator):
    """Render *doc* token by token as a list of HTML fragments.

    Each entity with a non-None ``_.disambiguated_entity`` is emitted as one
    colored ``annotation`` (mention text + entity) wrapped in a link to the
    entity's Wikipedia page. Every other token is emitted HTML-escaped.
    ``color_generator`` is called once per linked entity, in document order.
    """
    # token index where each disambiguated entity starts -> that entity span
    linked = {
        ent.start: ent
        for ent in doc.ents
        if ent._.disambiguated_entity is not None
    }
    components = []
    i = 0
    while i < len(doc):
        ent = linked.get(i)
        if ent is None:
            components.append(html.escape(doc[i].text))
            i += 1
            continue
        entity = ent._.disambiguated_entity
        href = _wikipedia_url(entity)
        components.append(
            f'<a href="{href}">{annotation(ent.text, entity, color_generator())}</a>'
        )
        # skip the rest of the span; spaCy spans are token-aligned, so ent.end
        # (exclusive token index) is always a valid place to resume
        i = ent.end
    return components


def main(
    model_checkpoint_path: str,
    inventory_path: str,
    cuda_device: int,
):
    """Run the ExtEnD streamlit demo page.

    Builds a spaCy ``en_core_web_sm`` pipeline with the ``extend`` pipe added
    after ``ner``, then renders a header, an interactive demo and a
    "How it works" section.

    Args:
        model_checkpoint_path: passed as ``checkpoint_path`` in the extend pipe config.
        inventory_path: passed as ``mentions_inventory_path`` in the extend pipe config.
        cuda_device: passed as ``device`` in the extend pipe config
            (the script entry point uses -1, presumably CPU — confirm).
    """
    # setup examples
    examples = [
        "Japan began the defence of their title with a lucky 2-1 win against Syria in a championship match on Friday.",
        "The project was coded in Java.",
        "Rome is in Italy",
    ]

    # allow_output_mutation: the returned pipeline is a mutable object that
    # streamlit should not try to hash between reruns
    @st.cache(allow_output_mutation=True)
    def load_resources(inventory_path):
        """Build (once per inventory path, via st.cache) the spaCy + ExtEnD pipeline."""
        nlp = spacy.load("en_core_web_sm")
        extend_config = dict(
            checkpoint_path=model_checkpoint_path,
            mentions_inventory_path=inventory_path,
            device=cuda_device,
            tokens_per_batch=10_000,
        )
        nlp.add_pipe("extend", after="ner", config=extend_config)
        # mock call so resources are loaded up front rather than on the first request
        nlp(examples[0])
        return nlp

    # preload default resources
    load_resources(inventory_path)

    # css rules
    st.write(
        """
<style type="text/css">
a {
text-decoration: none !important;
}
</style>
""",
        unsafe_allow_html=True,
    )

    # setup header
    st.markdown(
        "<h1 style='text-align: center;'>ExtEnD: Extractive Entity Disambiguation</h1>",
        unsafe_allow_html=True,
    )
    # NOTE(review): both badge <img> tags have an empty src here — confirm the image URLs
    st.write(
        """
<div align="center">
<a href="https://sunglasses-ai.github.io/classy/">
<img alt="Python" style="height: 3em; margin: 0em 1em 2em 1em;" src="">
</a>
<a href="https://spacy.io/" style="text-decoration: none">
<img alt="spaCy" style="height: 3em; margin: 0em 1em 2em 1em;" src="">
</a>
</div>
""",
        unsafe_allow_html=True,
    )

    def hiw():
        """Render the static "How it works" section."""
        st.markdown("""
## How it works
ExtEnD frames Entity Disambiguation as a text extraction problem:
""")
        st.image(
            "data/repo-assets/extend_formulation.png", caption="ExtEnD Formulation"
        )
        st.markdown(
            """
Given the sentence *After a long fight Superman saved Metropolis*, where *Superman* is the mention
to disambiguate, ExtEnD first concatenates the descriptions of all the possible candidates of *Superman* in the
inventory and then selects the span whose description best suits the mention in its context.
To use ExtEnD for full end2end entity linking, as we do in *Demo*, we just need to leverage a mention
identifier. Here [we use spaCy](https://github.com/SapienzaNLP/extend#spacy) (more specifically, its NER) and run ExtEnD on each named
entity spaCy identifies (if the corresponding mention is contained in the inventory).
##### Links:
* [Full Paper](https://www.researchgate.net/publication/359392427_ExtEnD_Extractive_Entity_Disambiguation)
* [GitHub](https://github.com/SapienzaNLP/extend)
"""
        )

    def demo():
        """Render the demo: pick/edit a text, run the pipeline, show linked entities."""
        st.markdown("## Demo")
        # read input
        selected_example = st.selectbox(
            "Examples",
            options=examples,
            index=0,
        )
        input_text = st.text_area("Input text to entity-disambiguate", selected_example)
        should_disambiguate = st.button("Disambiguate", key="classify")
        # load model (cached) and color generator
        nlp = load_resources(inventory_path)
        color_generator = get_md_200_random_color_generator()
        if not should_disambiguate:
            return

        time_start = time.perf_counter()
        doc = nlp(input_text)
        time_end = time.perf_counter()

        components = _annotated_html_components(doc, color_generator)
        st.markdown(
            "\n".join(
                [
                    "<div>",
                    *components,
                    "<p></p>",
                    f'<div style="text-align: right"><p style="color: gray">Time: {(time_end - time_start):.2f}s</p></div>',
                    "</div>",
                ]
            ),
            unsafe_allow_html=True,
        )

    demo()
    hiw()
if __name__ == "__main__":
    # The defaults reproduce the previously hard-coded configuration, so running
    # the script without arguments behaves exactly as before. Under streamlit,
    # script options go after `--`, e.g. `streamlit run <script> -- --cuda-device 0`.
    parser = argparse.ArgumentParser(description="ExtEnD entity disambiguation demo")
    parser.add_argument(
        "--model-checkpoint-path",
        default="experiments/extend-longformer-large/2021-10-22/09-11-39/checkpoints/best.ckpt",
    )
    parser.add_argument(
        "--inventory-path",
        default="data/inventories/le-and-titov-2018-inventory.min-count-2.sqlite3",
    )
    parser.add_argument("--cuda-device", type=int, default=-1)
    args = parser.parse_args()
    main(
        args.model_checkpoint_path,
        args.inventory_path,
        cuda_device=args.cuda_device,
    )
|