Spaces:
Runtime error
Runtime error
TuringsSolutions
commited on
Commit
•
6546065
1
Parent(s):
8b72437
Update app.py
Browse files
app.py
CHANGED
@@ -5,23 +5,15 @@ import spacy
|
|
5 |
from spacy.tokens import Doc, Span
|
6 |
from relik import Relik
|
7 |
from relik.inference.data.objects import TaskType, RelikOutput
|
8 |
-
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
|
9 |
from pyvis.network import Network
|
10 |
|
11 |
# RELIK Models Setup
|
12 |
-
|
13 |
-
|
14 |
-
wikidata_retriever = Relik.from_pretrained("relik-ie/encoder-e5-small-v2-wikipedia-relations", device="cuda")
|
15 |
-
wikidata_index = InMemoryDocumentIndex.from_pretrained("relik-ie/encoder-e5-small-v2-wikipedia-relations-index", index_precision="bf16", device="cuda")
|
16 |
|
17 |
relik_models = {
|
18 |
-
"sapienzanlp/relik-entity-linking-large":
|
19 |
-
|
20 |
-
reader_kwargs={"dataset_kwargs": {"use_nme": True}}
|
21 |
-
),
|
22 |
-
"relik-ie/relik-relation-extraction-small": Relik.from_pretrained(
|
23 |
-
"relik-ie/relik-relation-extraction-small", index=wikidata_index, device="cuda", retriever=wikidata_retriever
|
24 |
-
)
|
25 |
}
|
26 |
|
27 |
def get_span_annotations(response, doc):
|
@@ -84,4 +76,4 @@ with gr.Blocks(fill_height=True, css=css, theme=theme) as demo:
|
|
84 |
allow_flagging="never"
|
85 |
)
|
86 |
if __name__ == "__main__":
|
87 |
-
demo.launch()
|
|
|
5 |
from spacy.tokens import Doc, Span
|
6 |
from relik import Relik
|
7 |
from relik.inference.data.objects import TaskType, RelikOutput
|
|
|
8 |
from pyvis.network import Network
|
9 |
|
10 |
# RELIK Models Setup
|
11 |
+
def setup_relik_model(model_name: str, device: str):
|
12 |
+
return Relik.from_pretrained(model_name, device=device)
|
|
|
|
|
13 |
|
14 |
relik_models = {
|
15 |
+
"sapienzanlp/relik-entity-linking-large": setup_relik_model("sapienzanlp/relik-entity-linking-large", "cuda"),
|
16 |
+
"relik-ie/relik-relation-extraction-small": setup_relik_model("relik-ie/relik-relation-extraction-small", "cuda")
|
|
|
|
|
|
|
|
|
|
|
17 |
}
|
18 |
|
19 |
def get_span_annotations(response, doc):
|
|
|
76 |
allow_flagging="never"
|
77 |
)
|
78 |
if __name__ == "__main__":
|
79 |
+
demo.launch()
|