Spaces:
Running
Running
import streamlit as st | |
from annotated_text import annotated_text | |
from refined.inference.processor import Refined | |
import requests | |
import json | |
import spacy | |
# The German spaCy + entity-fishing pipeline is built inside load_model() only
# when "German" is selected. It is deliberately not preloaded here: Streamlit
# re-executes this module-level code on every interaction, and the preloaded
# object was never used (load_model creates its own instance).
# Page config: title, favicon, wide layout, collapsed sidebar and the
# "Get Help" / "About" menu entries.
_page_menu = {
    'Get Help': 'https://wordlift.io/book-a-demo/',
    'About': "# This is a demo app for NEL/NED/NER and SEO",
}
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
    page_title="Entity Linking by WordLift",
    page_icon="fav-ico.png",
    menu_items=_page_menu,
)
# Sidebar
st.sidebar.image("logo-wordlift.png")
# Options are tuples, not sets: set iteration order for strings depends on the
# per-process hash seed, so a set would make the option order (and the default,
# i.e. the first option) change between server restarts.
language_options = ("English", "German")
selected_language = st.sidebar.selectbox("Select the Language", list(language_options))
# Model and entity-set choices only apply to the ReFinED (non-German) pipeline.
if selected_language != "German":
    model_options = ("aida_model", "wikipedia_model_with_numbers")
    selected_model_name = st.sidebar.selectbox("Select the Model", list(model_options))
    entity_set_options = ("wikidata", "wikipedia")
    selected_entity_set = st.sidebar.selectbox("Select the Entity Set", list(entity_set_options))
else:
    # The German spaCy pipeline takes no model / entity-set parameters.
    selected_model_name = None
    selected_entity_set = None
# Cache the loaded model: without this decorator every Streamlit rerun
# (any widget interaction) would reload the full model from scratch.
@st.cache_resource
def load_model(selected_language, model_name=None, entity_set=None):
    """Load the entity-linking model for the selected language.

    The result is cached by ``st.cache_resource`` per distinct argument
    combination and reused across reruns instead of being reloaded.

    Args:
        selected_language: Language chosen in the sidebar. "German" selects
            spaCy's ``de_core_news_sm`` with the ``entityfishing`` component;
            any other value selects a pretrained ReFinED model.
        model_name: ReFinED model name (unused for German).
        entity_set: ReFinED entity set (unused for German).

    Returns:
        The spaCy pipeline (called as ``model(text)``) for German, otherwise
        the ReFinED model (used via ``model.process_text(text)``).
    """
    if selected_language == "German":
        nlp_model_de = spacy.load("de_core_news_sm")
        nlp_model_de.add_pipe("entityfishing", config={"language": "de"})
        return nlp_model_de
    return Refined.from_pretrained(model_name=model_name, entity_set=entity_set)


# Resolve the model for the current sidebar selection (served from the cache
# after the first load of each configuration).
model = load_model(selected_language, selected_model_name, selected_entity_set)
# Add citation: BibTeX entry for the ReFinED paper, shown in the sidebar.
# BibTeX separates authors with "and"; comma-separated names would be parsed
# as a single "Last, First" author.
citation = """
@inproceedings{ayoola-etal-2022-refined,
title = "{R}e{F}in{ED}: An Efficient Zero-shot-capable Approach to End-to-End Entity Linking",
author = "Ayoola, Tom and Tyagi, Shubhi and Fisher, Joseph and Christodoulopoulos, Christos and Pierleoni, Andrea",
booktitle = "NAACL",
year = "2022"
}
"""
with st.sidebar.expander('Citations'):
    # st.code keeps the entry's line breaks (Markdown rendering would collapse
    # them into one paragraph) and lets readers copy it verbatim.
    st.code(citation.strip(), language=None)
# Helper functions
# Canonical Wikidata entity URI prefix (the one shown to users as "link").
WIKIDATA_ENTITY_PREFIX = "http://www.wikidata.org/entity/"
# Wikidata page URL prefix; presumably what entity-fishing's url_wikidata
# returns for German entities — NOTE(review): confirm.
WIKIDATA_WIKI_PREFIX = "https://www.wikidata.org/wiki/"
WORDLIFT_ID_ENDPOINT = "https://api.wordlift.io/id/"
# Seconds to wait for the WordLift API before giving up on one entity.
REQUEST_TIMEOUT_SECONDS = 10


def get_wikidata_id(entity_string):
    """Extract a Wikidata QID from a ``key=QID`` fragment.

    Args:
        entity_string: A fragment such as ``"Entity(wikidata_entity_id=Q42"``
            (one comma-separated piece of a ReFinED entity's string form).

    Returns:
        ``{"id": QID, "link": "http://www.wikidata.org/entity/<QID>"}``.

    Raises:
        IndexError: if *entity_string* contains no ``=``.
    """
    # Strip whitespace and a closing parenthesis in case the QID was the last
    # field of the repr.
    entity_id = str(entity_string.split("=")[1]).strip().rstrip(")")
    return {"id": entity_id, "link": WIKIDATA_ENTITY_PREFIX + entity_id}


def wordlift_id_path(entity_link):
    """Return *entity_link* in the path form appended to the WordLift ID endpoint.

    Wikidata page URLs are first mapped to the entity URI, then ``scheme://``
    is written as ``scheme/`` (e.g. ``http/www.wikidata.org/entity/Q42``), the
    form this app sends to the API. Links already in that form are unchanged.
    NOTE(review): the path form is inferred from what this app sends; confirm
    against the WordLift API docs.
    """
    link = entity_link
    if link.startswith(WIKIDATA_WIKI_PREFIX):
        link = WIKIDATA_ENTITY_PREFIX + link[len(WIKIDATA_WIKI_PREFIX):]
    scheme, separator, rest = link.partition("://")
    return f"{scheme}/{rest}" if separator else link


def get_entity_data(entity_link):
    """Fetch WordLift's JSON data for an entity, best effort.

    Args:
        entity_link: Entity URI or Wikidata URL (see ``wordlift_id_path``).

    Returns:
        The decoded JSON response, or ``None`` when the link is empty, the
        request fails or times out, the server answers with an HTTP error
        status, or the body is not valid JSON.
    """
    if not entity_link:
        # Unlinked entities: nothing to look up (avoids requesting ".../id/None").
        return None
    url = WORDLIFT_ID_ENDPOINT + wordlift_id_path(entity_link)
    try:
        response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
        # Error payloads (404, 5xx) must not be treated as entity data.
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures on any requests version.
        print(f"Exception when fetching data for entity: {entity_link}. Exception: {e}")
        return None
# Input form: text area plus "Analyze" button, built via the form object
# rather than a `with` block.
analysis_form = st.form(key='my_form')
text_input = analysis_form.text_area(label='Enter a sentence')
submit_button = analysis_form.form_submit_button(label='Analyze')
# Run entity linking only when the form is submitted with text, so unrelated
# reruns (e.g. changing a sidebar option) don't repeat the model pass and the
# per-entity HTTP lookups.
if submit_button and text_input:
    if selected_language == "German":
        doc_de = model(text_input)
        # Map entities to a format similar to the English output.
        entities = [(ent.text, ent.label_, ent._.kb_qid, ent._.url_wikidata) for ent in doc_de.ents]
        # Debug
        for ent in doc_de.ents:
            st.write(f"Entity: {ent.text}, Label: {ent.label_}, QID: {ent._.kb_qid}, URL: {ent._.url_wikidata}")
    else:
        # ReFinED pipeline (English).
        entities = model.process_text(text_input)

    # mention text -> {"id": QID, "link": entity URI}
    entities_map = {}
    # mention text -> JSON returned by the WordLift API
    entities_data = {}
    for entity in entities:
        if selected_language == "German":
            entity_string, _entity_type, wikidata_id, _wikidata_url = entity
            if not wikidata_id:
                # entity-fishing could not link this mention; nothing to look up.
                continue
            # Reuse the English branch's helper so both languages produce the
            # same id/link format.
            wikidata_info = get_wikidata_id(f"wikidata_entity_id={wikidata_id}")
        else:
            # ReFinED entities are parsed from their string form; field 1 is
            # expected to carry "wikidata_entity_id=<QID>" for linked mentions.
            fields = str(entity).strip('][').replace("\'", "").split(', ')
            if len(fields) < 2 or "wikidata" not in fields[1]:
                continue
            entity_string = fields[0].strip()
            wikidata_info = get_wikidata_id(fields[1])
        entities_map[entity_string] = wikidata_info
        entity_data = get_entity_data(wikidata_info["link"])
        if entity_data is not None:
            entities_data[entity_string] = entity_data

    # mention text -> [id/link info, WordLift data or None]
    combined_entity_info_dictionary = {
        mention: [info, entities_data.get(mention)] for mention, info in entities_map.items()
    }
# Highlight colour per schema.org "@type" returned by the WordLift API; any
# other (or missing / non-string) type uses DEFAULT_ENTITY_COLOR.
ENTITY_TYPE_COLORS = {
    "Place": "#8AC7DB",
    "Organization": "#ADD8E6",
    "Person": "#67B7D1",
    "Product": "#2ea3f2",
    "CreativeWork": "#00BFFF",
    "Event": "#1E90FF",
}
DEFAULT_ENTITY_COLOR = "#8ef"


def build_annotated_segments(text, annotations):
    """Split *text* into plain strings and ``(mention, label, color)`` tuples.

    Each mention in *annotations* (mapping mention -> ``(label, color)``) is
    annotated at its first occurrence that does not overlap a mention already
    placed, in mapping order. Segments are built from match offsets, never by
    rewriting and re-parsing the text, so user text containing braces or
    Python syntax is always treated as plain text.

    Example: ``build_annotated_segments("I love New York and York.",
    {"New York": ("Q60", "#8ef"), "York": ("Q42462", "#8ef")})`` returns
    ``["I love ", ("New York", "Q60", "#8ef"), " and ",
    ("York", "Q42462", "#8ef"), "."]``.
    """
    placed = []  # (start, end, annotation tuple)
    for mention, (label, color) in annotations.items():
        if not mention:
            continue
        start = text.find(mention)
        while start != -1:
            end = start + len(mention)
            if all(end <= other_start or start >= other_end for other_start, other_end, _ in placed):
                placed.append((start, end, (mention, label, color)))
                break
            start = text.find(mention, start + 1)

    segments = []
    cursor = 0
    for start, end, annotation in sorted(placed, key=lambda item: item[0]):
        if start > cursor:
            segments.append(text[cursor:start])
        segments.append(annotation)
        cursor = end
    if cursor < len(text):
        segments.append(text[cursor:])
    return segments


# Guard on text_input as well: the entity maps are only built for non-empty text.
if submit_button and text_input:
    # JSON-LD data
    json_ld_data = {
        "@context": "https://schema.org",
        "@type": "WebPage",
        "mentions": [],
    }
    annotations = {}
    for entity_string, entity_info in entities_map.items():
        entity_data = entities_data.get(entity_string)
        entity_type = entity_data.get("@type") if isinstance(entity_data, dict) else None
        # "@type" may not be a plain string; only strings can key the colour table.
        if isinstance(entity_type, str):
            color = ENTITY_TYPE_COLORS.get(entity_type, DEFAULT_ENTITY_COLOR)
        else:
            color = DEFAULT_ENTITY_COLOR
        annotations[entity_string] = (entity_info["id"], color)
        # Only entities the WordLift API actually returned data for are mentioned.
        if entity_data is not None:
            json_ld_data["mentions"].append(entity_data)

    # Pass the annotated segments to the annotated_text function.
    annotated_text(*build_annotated_segments(text_input, annotations))
    with st.expander("See annotations"):
        st.write(combined_entity_info_dictionary)
    with st.expander("Here is the final JSON-LD"):
        st.json(json_ld_data)  # Output JSON-LD
elif submit_button:
    st.warning("Please enter some text to analyze.")