GlobEnc / app.py
akhaliq's picture
akhaliq HF staff
Update app.py
9ba191d
raw
history blame
4.88 kB
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer
from src.modeling.modeling_bert import BertForSequenceClassification
from src.modeling.modeling_electra import ElectraForSequenceClassification
from src.attention_rollout import AttentionRollout
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.backends.backend_pdf
# Maps each Dropdown choice to its Hugging Face checkpoint. Any value not listed
# here (including ``None`` when nothing is selected) falls back to
# ``_DEFAULT_CHECKPOINT``, exactly like the original if/elif/else chain.
# Other BERT/ELECTRA checkpoints from the same organisation that can be added:
#   BERT-base:  TehranNLP-org/bert-base-uncased-cls-{mnli,hatexplain}
#   BERT-large: TehranNLP-org/bert-large-{mnli,hateXplain}
#   ELECTRA:    TehranNLP-org/electra-base-{mnli,hateXplain}
_MODEL_CHECKPOINTS = {
    "bert-base-uncased-cls-sst2": "TehranNLP-org/bert-base-uncased-cls-sst2",
    "bert-large-sst2": "TehranNLP-org/bert-large-sst2",
    "electra-base-sst2": "TehranNLP-org/electra-base-sst2",
}
_DEFAULT_CHECKPOINT = "TehranNLP-org/electra-base-sst2"

# checkpoint name -> (tokenizer, model). Filled lazily; it holds at most one
# entry per checkpoint that can be requested, so it stays bounded. Loaded
# models remain resident in memory for the life of the process.
_MODEL_CACHE = {}


def _load_tokenizer_and_model(checkpoint):
    """Return ``(tokenizer, model)`` for *checkpoint*, loading it on first use.

    Loading weights is the slowest step of a request, so the pair is stored in
    ``_MODEL_CACHE`` and reused by later calls for the same checkpoint.

    Raises:
        ValueError: if *checkpoint* names neither a BERT nor an ELECTRA model.
    """
    cached = _MODEL_CACHE.get(checkpoint)
    if cached is not None:
        return cached
    if "bert" in checkpoint:
        model_cls = BertForSequenceClassification
    elif "electra" in checkpoint:
        model_cls = ElectraForSequenceClassification
    else:
        raise ValueError(f"Not implemented model: {checkpoint}")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    classifier = model_cls.from_pretrained(checkpoint)
    _MODEL_CACHE[checkpoint] = (tokenizer, classifier)
    return tokenizer, classifier


def _plot_globenc(globenc, tokens):
    """Draw a heatmap of each layer's row-0 attributions from *globenc*.

    Args:
        globenc: aggregated attributions indexed as ``globenc[layer, 0, token]``;
            presumably shape (layers, seq, seq) - TODO confirm against
            ``AttentionRollout.compute_flows``.
        tokens: token strings used as column labels, one per sequence position.

    Returns:
        A matplotlib Figure, already removed from pyplot's figure registry.
    """
    # Row 0 is the first position ([CLS] for these tokenizers). Flip so the last
    # layer is the top heatmap row; the index labels below match that order.
    cls_rows = np.flip(globenc[:, 0, :], axis=0)
    # Scale each layer so its largest attribution is 1; leave all-zero rows
    # as zeros instead of producing NaNs from 0 / 0.
    row_max = cls_rows.max(axis=1, keepdims=True)
    cls_rows = cls_rows / np.where(row_max == 0, 1.0, row_max)
    num_layers = len(cls_rows)
    df = pd.DataFrame(cls_rows, columns=tokens, index=range(num_layers, 0, -1))

    # Use an explicit figure/axes instead of pyplot's global "current figure"
    # so concurrent requests cannot draw onto each other's plots.
    fig, ax = plt.subplots(figsize=(14, 8))
    sns.heatmap(df, cmap="Reds", square=True, ax=ax)
    # NOTE(review): this +/-0.5 ylim tweak is the usual workaround for
    # matplotlib 3.1.1 clipping the first/last heatmap rows; on other versions
    # it adds half a row of padding. Kept to preserve the original rendering.
    bottom, top = ax.get_ylim()
    ax.set_ylim(bottom + 0.5, top - 0.5)
    ax.set_title("GlobEnc", fontsize=16)
    ax.set_ylabel("Layer", fontsize=16)
    plt.setp(ax.get_xticklabels(), rotation=90, fontsize=16)
    plt.setp(ax.get_yticklabels(), fontsize=13)
    fig.subplots_adjust(bottom=0.2)
    # Unregister the figure from pyplot so every request does not leak one;
    # the Figure object can still be rendered (savefig) by the caller.
    plt.close(fig)
    return fig


def inference(text, model):
    """Compute GlobEnc token attributions for *text* and plot them.

    Args:
        text: input sentence to classify and explain.
        model: Dropdown choice naming the checkpoint (a key of
            ``_MODEL_CHECKPOINTS``). Any other value, including ``None`` when
            nothing is selected, falls back to the ELECTRA-base SST-2 model.

    Returns:
        A matplotlib Figure with one heatmap row per layer (last layer on top)
        and one column per input token; each row is scaled so its max is 1.

    Raises:
        ValueError: if the selected checkpoint is not a BERT or ELECTRA model.
    """
    checkpoint = _MODEL_CHECKPOINTS.get(model, _DEFAULT_CHECKPOINT)
    tokenizer, classifier = _load_tokenizer_and_model(checkpoint)
    tokenized_sentence = tokenizer.encode_plus(text, return_tensors="pt")

    # Extract single-layer attributions. ``output_norms`` and the returned
    # ``norms`` come from the project's modified BERT/ELECTRA classes; element 4
    # of each layer's entry is used as the N-Enc attribution.
    with torch.no_grad():
        logits, attentions, norms = classifier(
            **tokenized_sentence,
            output_attentions=True,
            output_norms=True,
            return_dict=False,
        )
    num_layers = len(attentions)
    # squeeze() drops singleton dims (presumably the batch dim of 1 - TODO confirm).
    norm_nenc = torch.stack([norms[i][4] for i in range(num_layers)]).squeeze().cpu().numpy()
    print("Single layer N-Enc token attribution:", norm_nenc.shape)

    # Aggregate the per-layer attributions across layers to obtain GlobEnc.
    globenc = np.array(
        AttentionRollout().compute_flows([norm_nenc], output_hidden_states=True)[0]
    )
    print("Aggregated N-Enc token attribution (GlobEnc):", globenc.shape)

    tokens = tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][0])
    fig = _plot_globenc(globenc, tokens)
    print("logits:", logits)
    return fig
# Gradio UI: a sentence box and a model picker feed ``inference``; the
# resulting heatmap is shown in a Plot component when "Run" is clicked.
# ``demo`` stays at module level so gradio's reload mode can find it.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
# GlobEnc
Enter a sentence, choose a model, and click **Run** to see the GlobEnc token attributions for each layer.
"""
    )
    inp = [
        gr.Textbox(label="Input sentence"),
        # Pre-select a model so the choice shown in the UI is the one that runs;
        # an empty selection would otherwise silently fall back to ELECTRA.
        gr.Dropdown(
            choices=["bert-base-uncased-cls-sst2", "bert-large-sst2", "electra-base-sst2"],
            value="bert-base-uncased-cls-sst2",
            label="Model",
        ),
    ]
    out = gr.Plot()
    button = gr.Button(value="Run")
    button.click(fn=inference, inputs=inp, outputs=out)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()