# Source: UTC_HandyLab/interfaces/relation_e.py
# Uploaded by BioMike ("Upload 22 files", commit 5571f23, verified); 4.81 kB.
from typing import Dict, Union

import gradio as gr
import torch
from utca.core import RenameAttribute
from utca.implementation.predictors import TokenSearcherPredictor, TokenSearcherPredictorConfig
from utca.implementation.tasks import TokenSearcherNER, TokenSearcherNERPostprocessor, TokenSearcherRelationExtraction, TokenSearcherRelationExtractionPostprocessor
# Sample passage used as the text column of the example row shown in the UI.
# The literal starts and ends with a newline.
text = """
Dr. Paul Hammond, a renowned neurologist at Johns Hopkins University, has recently published a paper in the prestigious journal "Nature Neuroscience".
His research focuses on a rare genetic mutation, found in less than 0.01% of the population, that appears to prevent the development of Alzheimer's disease.
Collaborating with researchers at the University of California, San Francisco, the team is now working to understand the mechanism by which this mutation confers its protective effect.
Funded by the National Institutes of Health, their research could potentially open new avenues for Alzheimer's treatment.
"""
# Prefer the first CUDA GPU, but fall back to the CPU so that building the
# predictor does not depend on a GPU being present on the host.
# NOTE(review): assumes TokenSearcherPredictorConfig accepts "cpu" as a device
# string - confirm against the utca version in use.
_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Single predictor instance; it is handed to both pipeline tasks below.
predictor = TokenSearcherPredictor(
    TokenSearcherPredictorConfig(
        device=_device,
        model="knowledgator/UTC-DeBERTa-large-v2",
    )
)
# Stage 1 - entity recognition. The classified entities end up under the
# "output" key.
_ner_task = TokenSearcherNER(
    predictor=predictor,
    postprocess=TokenSearcherNERPostprocessor(
        threshold=0.5,  # entity threshold
    ),
)

# Stage 2 - expose the NER result as "entities", the key the relation
# extraction task takes its input entities from.
_entities_rename = RenameAttribute("output", "entities")

# Stage 3 - relation extraction over the recognised entities.
_relation_task = TokenSearcherRelationExtraction(
    predictor=predictor,
    postprocess=TokenSearcherRelationExtractionPostprocessor(
        threshold=0.5,  # relation threshold
    ),
)

# Full pipeline: NER -> rename -> relation extraction.
pipe = _ner_task | _entities_rename | _relation_task
def process(
    relation: str, text, distance_threshold: str, pairs_filter: str, labels: str
) -> Dict[str, Union[str, int, float]]:
    """Run the NER + relation-extraction pipeline for one relation.

    Args:
        relation: Name of the relation to extract.
        text: Text to analyse.
        distance_threshold: Raw textbox value. Only a positive decimal integer
            (surrounding whitespace allowed) is forwarded to the pipeline as
            ``"distance_threshold"``. Every other value (empty, ``None``,
            ``"None"``, ``"0"``, non-numeric text) runs the pipeline without
            that key.
        pairs_filter: Comma-separated ``source -> target`` entity-label pairs.
            Blank segments are ignored.
        labels: Comma-separated entity labels.

    Returns:
        The value stored under the ``"output"`` key of the pipeline result.
    """
    relation_spec = {
        "relation": relation,
        "pairs_filter": _parse_pairs_filter(pairs_filter),
    }

    # The key is only added for a usable threshold; a zero or unparsable value
    # must not leave the call without a result.
    max_distance = _parse_distance_threshold(distance_threshold)
    if max_distance is not None:
        relation_spec["distance_threshold"] = max_distance

    result = pipe.run({
        "text": text,
        "labels": [label.strip() for label in labels.split(",")],
        "relations": [relation_spec],
    })
    return result["output"]


def _parse_pairs_filter(raw):
    """Convert ``"a -> b, c -> d"`` into ``[("a", "b"), ("c", "d")]``."""
    if not raw:
        return []
    return [
        tuple(member.strip() for member in segment.split("->"))
        for segment in raw.split(",")
        # Skip blank segments (empty box, trailing comma) instead of emitting
        # a meaningless ("",) tuple.
        if segment.strip()
    ]


def _parse_distance_threshold(raw) -> Union[int, None]:
    """Return the threshold as a positive int, or ``None`` when it is unset."""
    if raw is None:
        return None
    cleaned = str(raw).strip()
    # isdecimal() only accepts characters that int() can convert, unlike
    # isdigit(), which also accepts e.g. superscript digits.
    if not cleaned.isdecimal():
        return None
    value = int(cleaned)
    return value if value > 0 else None
# Example rows offered in the UI. Column order follows the inputs of
# ``process``: relation, text, distance threshold, pairs filter, labels.
relation_e_examples = [
    [
        "worked at",  # relation to extract
        text,  # sample passage defined above
        "None",  # distance threshold: not a number, so none is applied
        "scientist -> university, scientist -> other",  # allowed label pairs
        "scientist, university, city, research, journal",  # entity labels
    ],
]
# Theme for this interface. It has to be handed to the gr.Blocks constructor;
# a plain assignment inside the ``with`` body has no effect on the UI.
theme = gr.themes.Base()

# Relation-extraction UI: five text inputs, one output box, a submit button
# and a cached example row.
with gr.Blocks(title="Open Information Extracting", theme=theme) as relation_e_interface:
    relation = gr.Textbox(label="Relation", placeholder="Enter relation you want to extract here")
    input_text = gr.Textbox(label="Text input", placeholder="Enter your text here")
    labels = gr.Textbox(label="Labels", placeholder="Enter your labels here (comma separated)", scale=2)
    pairs_filter = gr.Textbox(label="Pairs Filter", placeholder="It specifies possible members of relations by their entity labels. Write as: source -> target,..")
    distance_threshold = gr.Textbox(label="Distance Threshold", placeholder="It specifies the max distance in characters between spans in the text")
    output = gr.Textbox(label="Predicted Relation")
    submit_btn = gr.Button("Submit")

    # Component order must match the positional parameters of ``process``.
    process_inputs = [relation, input_text, distance_threshold, pairs_filter, labels]

    examples = gr.Examples(
        relation_e_examples,
        fn=process,
        inputs=process_inputs,
        outputs=output,
        cache_examples=True,
    )

    # Pressing Enter in any input box, or clicking the button, runs the
    # pipeline. ``relation.submit`` is registered last so the handlers that
    # already existed keep their registration order.
    for _trigger in (
        input_text.submit,
        labels.submit,
        pairs_filter.submit,
        submit_btn.click,
        distance_threshold.submit,
        relation.submit,
    ):
        _trigger(fn=process, inputs=process_inputs, outputs=output)

if __name__ == "__main__":
    relation_e_interface.launch()