# app.py -- uploaded by xlreator ("Upload app.py", commit 46d4dde, verified); 3.14 kB.
import configparser
import functools

import gradio as gr
import pandas as pd
import torch
from gensim.models import KeyedVectors
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForTokenClassification, AutoTokenizer

from segmentation import segment
from utils import clean_entity
class Linker:
    """Map mention strings to concepts by nearest-neighbour embedding lookup.

    The embedding model and the keyed vectors are loaded lazily, on first
    access of ``embedding_model`` / ``embeddings``, so constructing a
    ``Linker`` performs no file or model loading.
    """

    def __init__(self, config: dict[str, object],
                 context_window_width: int = -1):
        """Store the configuration and resolve the context window width.

        Args:
            config: Mapping providing ``'keyed_vectors_file'`` and
                ``'embedding_model'``, and ``'context_window_width'`` when
                no positive width is passed explicitly.
            context_window_width: Number of characters of context kept on
                each side of a span by ``add_context``. A value <= 0 means
                "take the width from ``config``".

        Raises:
            KeyError: If the width must come from ``config`` but the key is
                missing.
            ValueError: If the configured width is not an integer literal.
        """
        self._vectors = None
        self._emb_model = None
        if context_window_width <= 0:
            # Config values may arrive as strings (a configparser section
            # only yields strings); coerce so the window arithmetic in
            # add_context works on integers.
            context_window_width = int(config['context_window_width'])
        self.context_window_width = context_window_width
        self.config = config

    def add_context(self, row: pd.Series) -> str:
        """Return the cleaned slice of ``row.text`` around the row's span.

        The slice runs from ``row.start - context_window_width`` to
        ``row.end + context_window_width``, clamped to the bounds of
        ``row.text``, and is passed through ``clean_entity``.
        """
        window_start = max(0, row.start - self.context_window_width)
        window_end = min(row.end + self.context_window_width, len(row.text))
        return clean_entity(row.text[window_start:window_end])

    def _load_embeddings(self):
        """Load the keyed vectors from the path given in the config."""
        self._vectors = KeyedVectors.load(self.config['keyed_vectors_file'])

    def _load_model(self):
        """Instantiate the sentence-embedding model named in the config."""
        # Read from the instance's own config rather than a module-level
        # global, so the class works with whatever config it was given.
        self._emb_model = SentenceTransformer(self.config['embedding_model'])

    @property
    def embeddings(self):
        """Keyed vectors, loaded on first access and cached afterwards."""
        if self._vectors is None:
            self._load_embeddings()
        return self._vectors

    @property
    def embedding_model(self):
        """Sentence-embedding model, loaded on first access and cached."""
        if self._emb_model is None:
            self._load_model()
        return self._emb_model

    def link(self, df: pd.DataFrame) -> list[str]:
        """Return the closest concept key for each mention in ``df``.

        Args:
            df: Frame with a string column ``mention``; mentions are
                lower-cased before being embedded.

        Returns:
            One entry per row of ``df``, in row order: the key of the single
            most similar vector in ``embeddings`` (presumably a string
            concept identifier -- depends on how the vectors were built).
        """
        mention_emb = self.embedding_model.encode(df.mention.str.lower().values)
        # most_similar returns (key, similarity) pairs; keep only the key of
        # the top hit.
        return [self.embeddings.most_similar(m, topn=1)[0][0]
                for m in mention_emb]
def highlight_text(spans: pd.DataFrame, text: str) -> list[tuple[str, object]]:
    """Pair every character of ``text`` with the concept covering it.

    Args:
        spans: Frame whose rows expose ``start``, ``end`` and ``concept``;
            each row labels the character positions ``start`` up to, but
            not including, ``end``.
        text: The string the spans index into.

    Returns:
        A list of ``(character, concept)`` tuples, one per character of
        ``text``. Characters outside every span carry ``None``; where spans
        overlap, the row visited last wins.
    """
    labels = [None] * len(text)
    for span in spans.itertuples():
        for position in range(span.start, span.end):
            labels[position] = span.concept
    return list(zip(text, labels))
@functools.lru_cache(maxsize=1)
def _load_segmentation_components():
    """Load the segmentation model and tokenizer named in ``config``, once.

    Cached so that repeated queries reuse the same objects instead of
    calling ``from_pretrained`` on every request.
    """
    seg_model = AutoModelForTokenClassification.from_pretrained(
        config['segmentation_model']
    )
    seg_tokenizer = AutoTokenizer.from_pretrained(
        config['segmentation_tokenizer']
    )
    return seg_model, seg_tokenizer


def entity_link(query: str) -> list[tuple[str, object]]:
    """Segment ``query`` into spans, link each span, and label the text.

    Args:
        query: Raw text to process.

    Returns:
        The output of ``highlight_text``: one ``(character, concept)`` tuple
        per character of ``query``, with ``None`` for characters outside
        any detected span.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    seg_model, seg_tokenizer = _load_segmentation_components()
    thresh = float(config['thresh'])

    # segment() is handed a one-row frame holding the query as note 0.
    query_df = pd.DataFrame({'note_id': [0], 'text': [query]})
    seg = segment(query_df, seg_model, seg_tokenizer, device, thresh)

    if len(seg) > 0:
        seg = seg.sort_values('start')
        # link() returns one concept per row in row order, so the list
        # lines up positionally with the sorted frame.
        seg['concept'] = linker.link(seg)
    return highlight_text(seg, query)
# Settings are read from the DEFAULT section of config.ini in the working
# directory.
config_parser = configparser.ConfigParser()
config_parser.read('config.ini')
config = config_parser['DEFAULT']

# Shared linker instance used by entity_link.
linker = Linker(config)

# Single text box in, highlighted text out; adjacent pieces with the same
# label are merged by the output component.
demo = gr.Interface(
    fn=entity_link,
    inputs=["text"],
    outputs=gr.HighlightedText(label="linking", combine_adjacent=True),
    theme=gr.themes.Base(),
)
demo.launch()