|
import time |
|
import streamlit as st |
|
import torch |
|
import string |
|
|
|
from transformers import BertTokenizer, BertForMaskedLM |
|
|
|
@st.cache()
def load_bert_model(model_name):
    """Load a BERT tokenizer and masked-LM model for ``model_name``.

    The call is memoised by Streamlit's ``st.cache`` decorator, so repeated
    script re-runs with the same ``model_name`` reuse the loaded objects.

    Args:
        model_name: Identifier or path accepted by ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is put in eval mode.

    Raises:
        Whatever ``from_pretrained`` raises (e.g. for an unknown model).
        Errors are deliberately not swallowed: returning ``None`` here only
        moves the failure to the caller's tuple unpacking, where it shows up
        as an unrelated ``TypeError``.
    """
    bert_tokenizer = BertTokenizer.from_pretrained(model_name)
    # eval() turns off training-only behaviour (dropout) for inference.
    bert_model = BertForMaskedLM.from_pretrained(model_name).eval()
    return bert_tokenizer, bert_model
|
|
|
|
|
|
|
|
|
def decode(tokenizer, pred_idx, top_clean):
    """Turn predicted token ids into a newline-separated list of words.

    Args:
        tokenizer: Object with a ``decode(token_id)`` method.
        pred_idx: Iterable of candidate token ids, best first.
        top_clean: Maximum number of words to keep after filtering.

    Returns:
        Up to ``top_clean`` cleaned tokens joined by newlines. Tokens that
        decode to an empty string, a single punctuation character or
        ``[PAD]`` are skipped. WordPiece ``##`` markers are removed from
        the rest.
    """
    # Exact-match set. Testing membership in the *string*
    # ``string.punctuation + '[PAD]'`` is a substring test and wrongly
    # drops real tokens such as "A", "D" or "PAD".
    ignore_tokens = set(string.punctuation) | {'[PAD]', ''}

    tokens = []
    for w in pred_idx:
        # Collapse any whitespace the tokenizer puts inside the decoded piece.
        token = ''.join(tokenizer.decode(w).split())
        if token not in ignore_tokens:
            tokens.append(token.replace('##', ''))
    return '\n'.join(tokens[:top_clean])
|
|
|
def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Tokenize a sentence that contains a ``<mask>`` placeholder.

    Args:
        tokenizer: Tokenizer exposing ``mask_token``, ``mask_token_id`` and
            ``encode(text, add_special_tokens=...)``.
        text_sentence: Input text; ``<mask>`` marks the position to predict.
        add_special_tokens: Forwarded to ``tokenizer.encode``.

    Returns:
        ``(input_ids, mask_idx)``. ``input_ids`` is a ``1 x seq_len`` integer
        tensor. ``mask_idx`` is the position of the first mask token in it.

    Raises:
        ValueError: if the encoded sentence contains no mask token. This
            includes an empty sentence, which previously crashed with an
            ``IndexError`` from ``split()[-1]``.
    """
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)

    # When the mask is the final word, a ' .' is appended. Presumably this
    # gives the model a complete sentence around the blank.
    # NOTE(review): intent inferred.
    words = text_sentence.split()
    if words and words[-1] == tokenizer.mask_token:
        text_sentence += ' .'

    input_ids = torch.tensor(
        [tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)]
    )
    # torch.where on a 2-D boolean tensor returns (row_indices, col_indices);
    # the column indices are the sequence positions of the mask token.
    mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
    if not mask_positions:
        raise ValueError(
            f"no {tokenizer.mask_token!r} token found in the encoded sentence"
        )
    return input_ids, mask_positions[0]
|
|
|
def get_all_predictions(text_sentence, top_clean=5, *, top_k=None,
                        tokenizer=None, model=None):
    """Predict the most likely fillers for the ``<mask>`` in ``text_sentence``.

    Args:
        text_sentence: Sentence containing a ``<mask>`` placeholder.
        top_clean: Number of cleaned predictions to return.
        top_k: Number of raw candidates taken from the model before
            punctuation/padding filtering. Defaults to ``max(top_clean, 10)``.
            This matches the previous fixed value of 10 for existing callers,
            never requests fewer candidates than ``top_clean``, and removes a
            hidden dependency on a module-level ``top_k``.
        tokenizer: Tokenizer to use. Defaults to the module-level
            ``bert_tokenizer``.
        model: Masked-LM model to use. Defaults to the module-level
            ``bert_model``.

    Returns:
        ``{'bert': <newline-separated predictions>}``.
    """
    if tokenizer is None:
        tokenizer = bert_tokenizer
    if model is None:
        model = bert_model
    if top_k is None:
        top_k = max(top_clean, 10)

    input_ids, mask_idx = encode(tokenizer, text_sentence)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        scores = model(input_ids)[0]

    mask_scores = scores[0, mask_idx, :]
    # topk() raises if k exceeds the size of the last dimension, so clamp it.
    k = min(top_k, mask_scores.size(-1))
    bert = decode(tokenizer, mask_scores.topk(k).indices.tolist(), top_clean)
    return {'bert': bert}
|
|
|
def get_bert_prediction(input_text, top_k):
    """Predict the word that follows ``input_text``.

    A ``<mask>`` placeholder is appended to ``input_text``, and the model's
    top predictions for that position are returned.

    Args:
        input_text: Sentence prefix to complete.
        top_k: Number of cleaned predictions to return (any value ``int()``
            accepts).

    Returns:
        ``{'bert': <newline-separated predictions>}``.

    Raises:
        Any error from tokenization or inference. Errors are not swallowed:
        returning ``None`` would make the UI render an empty result instead
        of the actual problem.
    """
    return get_all_predictions(input_text + ' <mask>', top_clean=int(top_k))
|
|
|
# Streamlit page body: runs top to bottom on every interaction.
try:
    st.title("Qualitative evaluation of Pretrained BERT models")

    st.markdown("""
<a href="https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html"><small style="font-size:18px; color: #8f8f8f">This app is used to qualitatively examine the performance of pretrained models to do NER , <b>with no fine tuning</b></small></a>
""", unsafe_allow_html=True)

    st.write("Incomplete. Work in progress...")

    st.write("CLS vectors as well as the model prediction for a blank position are examined")

    # Number of predictions shown. Kept as a module-level global because
    # the prediction helpers may read ``top_k``, ``bert_tokenizer`` and
    # ``bert_model`` from module scope.
    top_k = 10

    bert_tokenizer, bert_model = load_bert_model('ajitrajasekharan/biomedical')

    default_text = "Imatinib is used to treat"

    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )

    start = None

    if st.button("Submit"):
        start = time.time()

        with st.spinner("Computing"):
            try:
                # Predict on what the user typed. Passing ``default_text``
                # here would silently ignore every edit in the text area.
                res = get_bert_prediction(input_text, top_k)

                st.header("JSON:")

                st.json(res)

            except Exception as e:
                st.error("Some error occured!" + str(e))

                st.stop()

    st.write("---")

    if start is not None:
        st.text(f"prediction took {time.time() - start:.2f}s")

except Exception as e:
    # Top-level boundary: keep the server log message, include the cause,
    # and show the failure on the page instead of leaving it blank.
    print("SOME PROBLEM OCCURED:", repr(e))
    st.error(f"An unexpected error occurred: {e}")
|
|
|
|
|
|