|
import streamlit as st |
|
import gradio as gr |
|
import shap |
|
import numpy as np |
|
import scipy as sp |
|
import torch |
|
import transformers |
|
from transformers import pipeline |
|
from transformers import RobertaTokenizer, RobertaModel |
|
from transformers import AutoModelForSequenceClassification |
|
from transformers import TFAutoModelForSequenceClassification |
|
from transformers import AutoTokenizer, AutoModelForTokenClassification |
|
from transformers import XLNetTokenizer, XLNetForSequenceClassification |
|
import matplotlib.pyplot as plt |
|
import sys |
|
import csv |
|
import sentencepiece |
|
# Raise the csv per-field size cap as high as the platform allows.
# sys.maxsize does not fit in the C long used by the csv module on platforms
# where long is 32 bits (e.g. 64-bit Windows) and raises OverflowError there,
# so halve the candidate until the call is accepted.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 2
|
|
|
# Target device for the classifier: first CUDA GPU when present, else CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# ADR classifier checkpoint: tokenizer plus sequence-classification model,
# with the model placed on the selected device.
tokenizer = AutoTokenizer.from_pretrained("vikvenk/ADR_Detection")
model = AutoModelForSequenceClassification.from_pretrained("vikvenk/ADR_Detection")
model = model.to(device)

# Text-classification pipeline around the same model/tokenizer pair; it is
# the callable that the SHAP explainer below wraps.
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)

explainer = shap.Explainer(pred)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Biomedical NER checkpoint used to highlight entities in the input text.
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")

# Token-classification pipeline built from the NER model and tokenizer,
# configured with the "simple" aggregation strategy.
ner_pipe = pipeline(
    task="ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
)
|
|
|
|
|
# Highlight colour per NER entity group.
_ENTITY_COLORS = {
    'Severity': 'red',
    'Sign_symptom': 'green',
    'Medication': 'lightblue',
    'Age': 'yellow',
    'Sex': 'yellow',
    'Diagnostic_procedure': 'gray',
    'Biological_structure': 'silver',
}

# Colour used for any entity group that has no entry in _ENTITY_COLORS, so an
# unexpected group is still highlighted instead of raising KeyError.
_DEFAULT_ENTITY_COLOR = 'lightgray'


def _highlight_entities(text, entities):
    """Return ``text`` as HTML with every NER entity wrapped in a ``<mark>``.

    Args:
        text: The string the entities were extracted from.
        entities: Iterable of dicts with ``start``, ``end``, ``word`` and
            ``entity_group`` keys, in left-to-right order.

    Returns:
        An HTML string. Text between entities is copied from ``text``; both
        that text and the entity words are HTML-escaped because they come
        from user input.
    """
    from html import escape  # local import: keeps this block self-contained

    pieces = []
    cursor = 0
    for entity in entities:
        word = entity['word'].replace("##", "")
        color = _ENTITY_COLORS.get(entity['entity_group'], _DEFAULT_ENTITY_COLOR)
        pieces.append(escape(text[cursor:entity['start']]))
        pieces.append(
            f"<mark style='background-color:{color};'>{escape(word)}</mark>"
        )
        cursor = entity['end']
    # Whatever follows the last entity (or the whole text if none were found).
    pieces.append(escape(text[cursor:]))
    return "".join(pieces)


def adr_predict(x):
    """Classify ``x``, explain the prediction and highlight its NER entities.

    Args:
        x: Input text to analyse.

    Returns:
        A 3-tuple ``(label_scores, local_plot, htext)``:

        * ``label_scores`` -- dict mapping ``"Severe Reaction"`` (class
          index 1) and ``"Non-severe Reaction"`` (class index 0) to their
          softmax probabilities as Python floats.
        * ``local_plot`` -- the value returned by ``shap.plots.text`` with
          ``display=False`` for the lower-cased input.
        * ``htext`` -- ``x`` rendered as HTML with NER entities wrapped in
          coloured ``<mark>`` tags.
    """
    # The model may live on a GPU; the tokenizer output must be moved to the
    # same device or the forward pass fails with a device mismatch.
    encoded_input = tokenizer(x, return_tensors='pt').to(model.device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output = model(**encoded_input)

    # Logits of the single input -> probabilities over the class axis, brought
    # back to CPU before conversion to Python floats.
    scores = torch.nn.functional.softmax(output[0][0], dim=-1).cpu()

    shap_values = explainer([str(x).lower()])
    local_plot = shap.plots.text(shap_values[0], display=False)

    htext = _highlight_entities(x, ner_pipe(x))

    label_scores = {
        "Severe Reaction": float(scores[1]),
        "Non-severe Reaction": float(scores[0]),
    }
    return label_scores, local_plot, htext
|
|
|
|
|
|
|
def main(prob1):
    """Gradio callback: lower-case the input and run ``adr_predict`` on it.

    Args:
        prob1: Value of the input textbox; converted with ``str`` first.

    Returns:
        The three results of ``adr_predict`` in order: label scores, SHAP
        text plot, and NER-highlighted HTML.
    """
    label_scores, shap_plot, ner_html = adr_predict(str(prob1).lower())
    return label_scores, shap_plot, ner_html
|
|
|
# Heading text: rendered as a Markdown heading and also passed to gr.Blocks
# as its title.
title = "Welcome to **ADR Detector** 🪐"

# Introductory paragraph shown below the heading.
description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reaction to medications. Please do NOT use for medical diagnosis."""
|
|
|
# Page layout: heading, description, input box and button, then a row with
# the predicted label in one column and the SHAP / NER HTML in the other,
# followed by a row of clickable examples.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")

    prob1 = gr.Textbox(
        label="Enter Your Text Here:",
        lines=2,
        placeholder="Type it here ...",
    )
    submit_btn = gr.Button("Analyze")

    with gr.Row():
        with gr.Column(visible=True) as output_col:
            label = gr.Label(label="Predicted Label")

        with gr.Column(visible=True) as output_col:
            local_plot = gr.HTML(label='Shap:')
            htext = gr.HTML(label="NER")

    # Clicking the button runs `main` on the textbox content and fills the
    # three output components; the endpoint is exposed under the name "adr".
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=[label, local_plot, htext],
        api_name="adr",
    )

    with gr.Row():
        gr.Markdown("### Click on any of the examples below to see how it works:")
        example_texts = [
            ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
            ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
        ]
        gr.Examples(
            examples=example_texts,
            inputs=[prob1],
            outputs=[label, local_plot, htext],
            fn=main,
            cache_examples=True,
        )

demo.launch()