import gradio as gr

import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import AutoTokenizer

from src.modeling.modeling_bert import BertForSequenceClassification
from src.modeling.modeling_electra import ElectraForSequenceClassification
from src.attention_rollout import AttentionRollout


def inference(text, model):
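    """Compute GlobEnc token attributions for `text` and return a heatmap figure."""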
    # Map the dropdown choice to the corresponding fine-tuned SST-2 checkpoint.
    if model == "bert-base-uncased-cls-sst2":
        config = {"MODEL": "TehranNLP-org/bert-base-uncased-cls-sst2"}
    elif model == "bert-large-sst2":
        config = {"MODEL": "TehranNLP-org/bert-large-sst2"}
    else:
        config = {"MODEL": "TehranNLP-org/electra-base-sst2"}

    SENTENCE = text
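    # Tokenize the input and load the matching model. The local modeling_bert /
    # modeling_electra modules extend the Hugging Face classes so that the
    # forward pass can also return norm-based scores (see output_norms below).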
    tokenizer = AutoTokenizer.from_pretrained(config["MODEL"])
    tokenized_sentence = tokenizer.encode_plus(SENTENCE, return_tensors="pt")
    if "bert" in config["MODEL"]:
        model = BertForSequenceClassification.from_pretrained(config["MODEL"])
    elif "electra" in config["MODEL"]:
        model = ElectraForSequenceClassification.from_pretrained(config["MODEL"])
    else:
        raise Exception(f"Not implemented model: {config['MODEL']}")
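    # Forward pass. With output_norms=True each layer yields a tuple of
    # norm-based scores; index 4 is taken here to be the N-Enc (encoder-block
    # level) token attribution matrix used by GlobEnc.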
    with torch.no_grad():
        logits, attentions, norms = model(**tokenized_sentence, output_attentions=True, output_norms=True, return_dict=False)
    num_layers = len(attentions)
    norm_nenc = torch.stack([norms[i][4] for i in range(num_layers)]).squeeze().cpu().numpy()
    print("Single layer N-Enc token attribution:", norm_nenc.shape)
    globenc = AttentionRollout().compute_flows([norm_nenc], output_hidden_states=True)[0]
    globenc = np.array(globenc)
    print("Aggregated N-Enc token attribution (GlobEnc):", globenc.shape)
    tokenized_text = tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][0])
    plt.figure(figsize=(14, 8))
    norm_cls = globenc[:, 0, :]
    norm_cls = np.flip(norm_cls, axis=0)
    row_maxes = norm_cls.max(axis=1)
    norm_cls = norm_cls / row_maxes[:, np.newaxis]
    df = pd.DataFrame(norm_cls, columns=tokenized_text, index=range(len(norm_cls), 0, -1))
    ax = sns.heatmap(df, cmap="Reds", square=True)
    # Workaround for the matplotlib 3.1.x bug that crops the first and last
    # rows of seaborn heatmaps.
    bottom, top = ax.get_ylim()
    ax.set_ylim(bottom + 0.5, top - 0.5)
    plt.title("GlobEnc", fontsize=16)
    plt.ylabel("Layer", fontsize=16)
    plt.xticks(rotation=90, fontsize=16)
    plt.yticks(fontsize=13)
    plt.gcf().subplots_adjust(bottom=0.2)
    print("logits:", logits)

    return plt


demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # GlobEnc
        Enter a sentence, pick a fine-tuned SST-2 model, and press Run to
        visualize the layer-wise GlobEnc token attributions.
        """)
    # The list order matters: Gradio passes [Textbox, Dropdown] positionally
    # as inference(text, model).
    inp = [gr.Textbox(), gr.Dropdown(choices=["bert-base-uncased-cls-sst2", "bert-large-sst2", "electra-base-sst2"])]
    out = gr.Plot()

    button = gr.Button(value="Run")
    button.click(fn=inference,
                 inputs=inp,
                 outputs=out)

demo.launch()