xlreator's picture
initial commit
67eaf9f verified
raw
history blame
No virus
3.13 kB
import configparser
import functools

import gradio as gr
import pandas as pd
import torch
from gensim.models import KeyedVectors
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForTokenClassification, AutoTokenizer

from segmentation import segment
from utils import clean_entity
class Linker:
    """Link entity mentions to concepts via nearest-neighbour embedding search.

    Mentions are encoded with a sentence-transformers model and every
    resulting vector is looked up in a gensim ``KeyedVectors`` index. Both
    resources are loaded lazily on first access, so building a ``Linker`` does
    not touch the disk or download anything.
    """

    def __init__(self, config: dict[str, object],
                 context_window_width: int = -1):
        """Store the configuration and resolve the context window width.

        Args:
            config: Mapping providing ``'keyed_vectors_file'`` and
                ``'embedding_model'``, plus ``'context_window_width'`` when no
                positive width is passed explicitly.
            context_window_width: Number of characters kept on each side of a
                mention by :meth:`add_context`. A value <= 0 means "read it
                from ``config``".

        Raises:
            KeyError: If the width must come from ``config`` but is missing.
            ValueError: If the configured width is not an integer.
        """
        self._vectors = None
        self._emb_model = None
        if context_window_width <= 0:
            # This module passes a configparser section, whose values are
            # always strings; cast so the arithmetic in add_context works.
            context_window_width = int(config['context_window_width'])
        self.context_window_width = context_window_width
        self.config = config

    def add_context(self, row: pd.Series) -> str:
        """Return the mention in ``row`` together with its surrounding text.

        Args:
            row: Record exposing ``start``, ``end`` (character offsets into
                ``text``) and ``text``.

        Returns:
            The slice of ``row.text`` widened by ``context_window_width``
            characters on both sides (clipped to the text bounds), passed
            through ``clean_entity``.
        """
        window_start = max(0, row.start - self.context_window_width)
        window_end = min(row.end + self.context_window_width, len(row.text))
        return clean_entity(row.text[window_start:window_end])

    def _load_embeddings(self):
        """Load the concept vectors from the path given in the config."""
        self._vectors = KeyedVectors.load(self.config['keyed_vectors_file'])

    def _load_model(self):
        """Instantiate the sentence embedding model named in the config."""
        # Use the instance's own config rather than a module-level global so
        # the class works no matter how (or where) it is constructed.
        self._emb_model = SentenceTransformer(self.config['embedding_model'])

    @property
    def embeddings(self):
        """Concept ``KeyedVectors``, loaded on first access."""
        if self._vectors is None:
            self._load_embeddings()
        return self._vectors

    @property
    def embedding_model(self):
        """Sentence embedding model, loaded on first access."""
        if self._emb_model is None:
            self._load_model()
        return self._emb_model

    def link(self, df: pd.DataFrame) -> list[object]:
        """Return the closest concept key for every mention in ``df``.

        Args:
            df: Frame with a string column ``mention``; mentions are
                lower-cased before being encoded.

        Returns:
            One entry per row of ``df``: the key of the single most similar
            vector in :attr:`embeddings`. Empty when ``df`` has no rows.
        """
        if len(df) == 0:
            # Nothing to encode; also avoids loading the models needlessly.
            return []
        mention_emb = self.embedding_model.encode(df.mention.str.lower().values)
        # most_similar returns (key, similarity) pairs; keep the top key only.
        return [self.embeddings.most_similar(m, topn=1)[0][0]
                for m in mention_emb]
def highlight_text(spans: pd.DataFrame, text: str) -> list[tuple[str, object]]:
    """Pair every character of ``text`` with the concept covering it.

    Args:
        spans: Frame whose rows expose ``start``, ``end`` (end-exclusive
            character offsets into ``text``) and ``concept``.
        text: The string the offsets refer to.

    Returns:
        One ``(character, label)`` tuple per character of ``text``. The label
        is the concept of the span covering that position, or ``None`` when no
        span covers it. Where spans overlap, the row that comes later wins.
    """
    labels = [None] * len(text)
    for span in spans.itertuples():
        for position in range(span.start, span.end):
            labels[position] = span.concept
    return [(char, label) for char, label in zip(text, labels)]
@functools.lru_cache(maxsize=1)
def _load_segmenter():
    """Load the segmentation model and tokenizer once and reuse them.

    ``from_pretrained`` reads (and possibly downloads) the weights, which is
    far too slow to repeat for every request, so the pair is cached for the
    lifetime of the process.
    """
    seg_model = AutoModelForTokenClassification.from_pretrained(
        config['segmentation_model']
    )
    seg_tokenizer = AutoTokenizer.from_pretrained(
        config['segmentation_tokenizer']
    )
    return seg_model, seg_tokenizer


def entity_link(query: str) -> list[tuple[str, object]]:
    """Segment ``query`` into mentions, link them, and label each character.

    Args:
        query: Free text entered by the user.

    Returns:
        One ``(character, concept)`` tuple per character of ``query`` as
        produced by ``highlight_text``; the concept is ``None`` for characters
        outside every detected mention.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    seg_model, seg_tokenizer = _load_segmenter()
    thresh = float(config['thresh'])
    # segment() works on a frame of notes; wrap the single query as note 0.
    query_df = pd.DataFrame({'note_id': [0], 'text': [query]})
    seg = segment(query_df, seg_model, seg_tokenizer, device, thresh)
    if len(seg) > 0:
        seg = seg.sort_values('start')
        # link() returns one concept per row, in the (sorted) row order.
        seg['concept'] = linker.link(seg)
    return highlight_text(seg, query)
# Application settings: only the DEFAULT section of config.ini is used.
config_parser = configparser.ConfigParser()
# NOTE(review): ConfigParser.read() silently skips files it cannot open, so a
# missing config.ini only surfaces later as a KeyError on the first lookup.
config_parser.read('config.ini')
config = config_parser['DEFAULT']
# Shared linker used by entity_link(); its models are loaded on first use.
linker = Linker(config)
# Gradio UI: one text input, with entity_link's (character, label) pairs
# rendered as highlighted text.
demo = gr.Interface(
fn=entity_link,
inputs=["text"],
outputs=gr.HighlightedText(
label="linking",
# presumably merges neighbouring characters that share a label into one
# highlighted span - confirm against the Gradio documentation
combine_adjacent=True,
),
theme=gr.themes.Base()
)