File size: 3,994 Bytes
c337225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96a2369
c337225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
The SpEL annotation visualization script. You can use this script as a playground to explore the capabilities and
limitations of the SpEL framework.
"""
import torch
from model import SpELAnnotator
from data_loader import dl_sa
from utils import chunk_annotate_and_merge_to_phrase
from candidate_manager import CandidateManager
import streamlit as st
from annotated_text import annotated_text
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

@st.cache_resource
def load_model():
    """Create the SpEL annotator and, if configured, its candidate manager.

    The configuration is fixed inside the function: the AIDA-finetuned weights
    are requested, the full vocabulary is kept, and the candidate setting is
    "n", which means no ``CandidateManager`` is built. The function is wrapped
    in ``st.cache_resource`` so the returned objects are cached by Streamlit.

    Returns:
        A ``(model, candidates_manager)`` tuple; ``candidates_manager`` is
        ``None`` when the candidate setting is "n".
    """
    use_aida_finetuned = True
    use_full_vocabulary = True
    candidate_mode = "n"

    spel = SpELAnnotator()
    spel.init_model_from_scratch(device=device)

    # Only build a candidate manager when a candidate set was requested.
    if candidate_mode == "n":
        candidates = None
    else:
        candidates = CandidateManager(
            dl_sa.mentions_vocab,
            is_kb_yago=candidate_mode == "k",
            is_ppr_for_ned=candidate_mode.startswith("p"),
            is_context_agnostic=candidate_mode == "pg",
            is_indexed_for_spans=True,
        )

    # Pick the checkpoint step that matches the configuration; the AIDA-only
    # vocabulary additionally needs the classification head shrunk before
    # the weights are loaded.
    if not use_aida_finetuned:
        checkpoint_step = 2
    elif use_full_vocabulary:
        checkpoint_step = 4
    else:
        spel.shrink_classification_head_to_aida(device=device)
        checkpoint_step = 3
    spel.load_checkpoint(None, device=device, load_from_torch_hub=True,
                         finetuned_after_step=checkpoint_step)
    return spel, candidates

def _split_into_segments(text, spans):
    """Split ``text`` into the pieces rendered by ``annotated_text``.

    Args:
        text: The raw input string.
        spans: Iterable of ``(start, end, label)`` tuples, where ``start`` and
            ``end`` are character offsets into ``text`` used as slice bounds.

    Returns:
        A list in reading order whose items are either a plain ``str``
        (unannotated text) or a ``(piece, label)`` tuple (annotated text).
    """
    segments = []
    cursor = 0
    text_len = len(text)
    for start, end, label in sorted(spans, key=lambda span: span[0]):
        if cursor >= text_len:
            break
        # Spans are not expected to overlap, but clamp them anyway so an
        # overlapping or nested span can never repeat already-rendered text
        # or move the cursor backwards.
        start = max(start, cursor)
        end = min(end, text_len)
        if end <= start:
            continue
        if start > cursor:
            segments.append(text[cursor:start])
        segments.append((text[start:end], label))
        cursor = end
    if cursor < text_len:
        segments.append(text[cursor:])
    return segments


# Page set-up: load the (cached) model and draw the input widgets.
annotator, candidates_manager = load_model()
st.title("SpEL Prediction Visualization")
st.caption('Running the \"[SpEL-base-step3-500K.pt](https://vault.sfu.ca/index.php/s/8nw5fFXdz2yBP5z/download)\" model without any hand-crafted candidate sets. For more information please checkout [SpEL\'s github repository](https://github.com/shavarani/SpEL).')
mention = st.text_input("Enter the text:")
process_button = st.button("Annotate")

if process_button and mention:
    phrase_annotations = chunk_annotate_and_merge_to_phrase(
        annotator, mention, k_for_top_k_to_keep=5, normalize_for_chinese_characters=True)
    # Apply candidate-based post-processing BEFORE reading the resolved
    # annotations; doing it afterwards would leave the displayed result
    # unaffected. NOTE(review): this assumes the manager adjusts each phrase
    # in place (its return value is ignored) -- confirm against CandidateManager.
    if candidates_manager:
        for p in phrase_annotations:
            candidates_manager.modify_phrase_annotation_using_candidates(p, mention)
    # Keep only phrases resolved to something other than id 0, as
    # (start offset, end offset, entity name) spans.
    entity_spans = [(p.words[0].token_offsets[0][1][0],
                     p.words[-1].token_offsets[-1][1][-1],
                     dl_sa.mentions_itos[p.resolved_annotation])
                    for p in phrase_annotations if p.resolved_annotation != 0]
    if entity_spans:
        annotated_text(_split_into_segments(mention, entity_spans))
    else:
        annotated_text(mention)