# Hugging Face Spaces app (Space status shown when this source was captured: Sleeping).
import csv
import html
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import shap
import streamlit as st
import torch
import transformers
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    RobertaModel,
    RobertaTokenizer,
    TFAutoModelForSequenceClassification,
    pipeline,
)
# Lift the csv module's per-field size cap as far as the platform allows.
# ``sys.maxsize`` does not fit in the C ``long`` that ``csv.field_size_limit``
# accepts on platforms where ``long`` is 32-bit (e.g. 64-bit Windows), where it
# raises OverflowError; halve the value until it is accepted. On platforms
# with a 64-bit ``long`` the first attempt succeeds, as before.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 2
# Hub checkpoints for the severity classifier and the biomedical NER tagger.
_ADR_CHECKPOINT = "WyattMiller/Mod4Team5"
_NER_CHECKPOINT = "d4data/biomedical-ner-all"

# Run the severity classifier on the first CUDA GPU when one is present.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Severity classifier, loaded once at import time and moved to `device`.
tokenizer = AutoTokenizer.from_pretrained(_ADR_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_ADR_CHECKPOINT)
model = model.to(device)

# Text-classification pipeline around the same model/tokenizer pair.
# return_all_scores=True requests a score for every label rather than only
# the top one; the SHAP explainer below wraps this pipeline.
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)
explainer = shap.Explainer(pred)

# Disabled experiment, kept for reference: score whether the text mentions a
# medication / symptom with a QNLI cross-encoder.
# classifier = transformers.pipeline("text-classification", model = "cross-encoder/qnli-electra-base")
# def med_score(x):
#     label = x['label']
#     score_1 = x['score']
#     return round(score_1,3)
# def sym_score(x):
#     label2sym= x['label']
#     score_1sym = x['score']
#     return round(score_1sym,3)

# Biomedical NER tagger; aggregation_strategy="simple" merges sub-word tokens
# into whole entity spans. No device is given here (pass device=0 to use a GPU).
ner_tokenizer = AutoTokenizer.from_pretrained(_NER_CHECKPOINT)
ner_model = AutoModelForTokenClassification.from_pretrained(_NER_CHECKPOINT)
ner_pipe = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
)
# Background colour for each biomedical entity group highlighted in the NER view.
_ENTITY_COLORS = {
    "Severity": "red",
    "Sign_symptom": "green",
    "Medication": "lightblue",
    "Age": "yellow",
    "Sex": "yellow",
    "Diagnostic_procedure": "gray",
    "Biological_structure": "silver",
}
# Colour for entity groups not listed above (previously these raised KeyError).
_DEFAULT_ENTITY_COLOR = "orange"


def _highlight_entities(text, entities):
    """Return *text* as escaped HTML with every entity span wrapped in <mark>.

    *entities* are the dicts produced by ``ner_pipe``: ``start``/``end``
    character offsets into *text*, a ``word`` and an ``entity_group``.
    Assumes they are sorted by ``start`` and non-overlapping — TODO confirm
    for this NER model.
    """
    parts = []
    prev_end = 0
    for entity in entities:
        start, end = entity["start"], entity["end"]
        word = entity["word"].replace("##", "")
        color = _ENTITY_COLORS.get(entity["entity_group"], _DEFAULT_ENTITY_COLOR)
        # The text comes from the user and is rendered as HTML, so escape it.
        parts.append(html.escape(text[prev_end:start]))
        parts.append(
            f"<mark style='background-color:{color};'>{html.escape(word)}</mark>"
        )
        prev_end = end
    parts.append(html.escape(text[prev_end:]))
    return "".join(parts)


def adr_predict(x):
    """Classify *x* for adverse-reaction severity and build its explanations.

    Args:
        x: Input text (``main`` passes it already lower-cased).

    Returns:
        A 3-tuple ``(probs, shap_html, ner_html)``:
        ``probs`` maps "Severe Reaction" / "Non-severe Reaction" to the
        softmax probability of classifier outputs 1 and 0; ``shap_html`` is
        the HTML from ``shap.plots.text`` for the explained instance; and
        ``ner_html`` is *x* as escaped HTML with recognised entities
        highlighted.
    """
    # Send the inputs to whatever device the model's weights are on; leaving
    # them on the CPU fails when the model runs on CUDA.
    encoded_input = tokenizer(x, return_tensors="pt").to(model.device)
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(**encoded_input)[0][0]
    # Softmax over the class dimension, then to CPU so values can be read out.
    scores = torch.nn.functional.softmax(logits, dim=-1).cpu()

    shap_values = explainer([str(x).lower()])
    # # Find the index of the class you want as the default reference (e.g., 'label_1')
    # label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
    # # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
    # shap.plots.text(shap_values[label_1_index][0])
    local_plot = shap.plots.text(shap_values[0], display=False)
    # med = med_score(classifier(x+str(", There is a medication."))[0])
    # sym = sym_score(classifier(x+str(", There is a symptom."))[0])

    htext = _highlight_entities(x, ner_pipe(x))
    probs = {
        "Severe Reaction": float(scores[1]),
        "Non-severe Reaction": float(scores[0]),
    }
    return probs, local_plot, htext
    # ,{"Contains Medication": float(med), "No Medications": float(1-med)} , {"Contains Symptoms": float(sym), "No Symptoms": float(1-sym)}
def main(prob1):
    """Gradio callback for the "Analyze" button and the cached examples.

    Lower-cases the submitted text, forwards it to ``adr_predict`` and
    returns that call's first three outputs as a tuple:
    (class probabilities, SHAP HTML, NER HTML).
    """
    result = adr_predict(str(prob1).lower())
    probabilities, shap_html, ner_html = result[0], result[1], result[2]
    return probabilities, shap_html, ner_html
title = "Welcome to **ADR Detector** 🪐"
description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reactions to medications. Please do NOT use for medical diagnosis."""

# Example inputs shown under the form; their outputs are cached at start-up.
_EXAMPLES = [
    ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
    ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
]

# The browser-tab title is plain text, so strip the markdown bold markers.
with gr.Blocks(title=title.replace("**", "")) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")
    prob1 = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")
    submit_btn = gr.Button("Analyze")

    with gr.Row():
        with gr.Column(visible=True):
            label = gr.Label(label="Predicted Label")
        with gr.Column(visible=True):
            local_plot = gr.HTML(label="Shap:")
            htext = gr.HTML(label="NER")
            # med = gr.Label(label = "Contains Medication")
            # sym = gr.Label(label = "Contains Symptoms")

    # Same order as the tuple returned by `main`.
    adr_outputs = [label, local_plot, htext]  # , med, sym
    submit_btn.click(main, inputs=[prob1], outputs=adr_outputs, api_name="adr")

    with gr.Row():
        gr.Markdown("### Click on any of the examples below to see how it works:")
        gr.Examples(
            examples=_EXAMPLES,
            inputs=[prob1],
            outputs=adr_outputs,
            fn=main,
            cache_examples=True,
        )

demo.launch()