# caif / app.py
# Author: Балаганский Никита Николаевич
# Last commit: cb16bf9 ("fix"), file size 9.03 kB
import os
from typing import Tuple
import streamlit as st
import torch
import transformers
from transformers import AutoConfig
import tokenizers
from sampling import CAIFSampler, TopKWithTemperatureSampler
from generator import Generator
import pickle
from plotly import graph_objects as go
import numpy as np
# Run models on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face classifier checkpoints offered as attribute models, per UI language.
ATTRIBUTE_MODELS = {
    "Russian": (
        "cointegrated/rubert-tiny-toxicity",
        "SkolkovoInstitute/russian_toxicity_classifier"
    ),
    "English": (
        "unitary/toxic-bert",
        "distilbert-base-uncased-finetuned-sst-2-english",
        "cardiffnlp/twitter-roberta-base-sentiment-latest",
    )
}

# Hugging Face causal language-model checkpoints offered for generation, per UI language.
LANGUAGE_MODELS = {
    "Russian": (
        'sberbank-ai/rugpt3small_based_on_gpt2',
        "sberbank-ai/rugpt3large_based_on_gpt2"
    ),
    "English": ("gpt2", "distilgpt2", "EleutherAI/gpt-neo-1.3B")
}

# Localized widget labels, keyed by UI language.
ATTRIBUTE_MODEL_LABEL = {
    "Russian": 'Выберите модель классификации',
    "English": "Choose attribute model"
}

LM_LABEL = {
    "English": "Choose language model",
    "Russian": "Выберите языковую модель"
}

ATTRIBUTE_LABEL = {
    # Fixed typo: "Веберите" -> "Выберите" ("Choose").
    "Russian": "Выберите нужный атрибут текста",
    "English": "Choose desired attribute",
}

TEXT_PROMPT_LABEL = {
    "English": "Text prompt",
    "Russian": "Начало текста"
}

# Default prompt shown in the text box when no model-specific example applies.
PROMPT_EXAMPLE = {
    "English": "Hello there",
    "Russian": "Привет всем"
}

# Markdown warning rendered above the "Show positive alphas" checkbox.
WARNING_TEXT = {
    "English": """
**Warning!**
If you click the checkbox below, positive """ + r"$\alpha$" + """ values for CAIF sampling become available.
This means that the language model will be forced to produce toxic and/or abusive text.
This space is only a demonstration of our method for controllable text generation,
and we are not responsible for the content produced by this method.
**Please use it carefully and with positive intentions!**
""",
    "Russian": """
**Внимание!**
После нажатия на чекбокс ниже положительные """ + r"$\alpha$" + """ станут доступны.
Это означает, что языковая модель будет генерировать токсичные тексты.
Это демо служит лишь демонстрацией нашего метода контролируемой генерации.
Мы не несем ответственности за полученные тексты.
**Используйте этот метод осторожно и с положительными намерениями!**
"""
}
def _select_target_label(cls_model_name: str, language: str) -> Tuple[int, str]:
    """Render the attribute selector for ``cls_model_name``; return ``(target_label_id, act_type)``.

    The choice depends on the classifier's ``problem_type`` from its Hugging Face config:
    multi-label models use a sigmoid over the selected label, single-label models use a
    sigmoid over label id 1, and anything else uses a softmax over the selected label.
    """
    cls_model_config = AutoConfig.from_pretrained(cls_model_name)
    label2id = cls_model_config.label2id
    label_text = ATTRIBUTE_LABEL[language]

    if cls_model_config.problem_type == "multi_label_classification":
        default_idx = 0
        # For rubert-tiny-toxicity, pre-select the 'threat' label when the config has one.
        if "rubert-tiny-toxicity" in cls_model_name and "threat" in label2id:
            default_idx = list(label2id).index("threat")
        label_key = st.selectbox(label_text, label2id.keys(), index=default_idx)
        return label2id[label_key], "sigmoid"

    if cls_model_config.problem_type == "single_label_classification":
        # Only the last label is offered, and id 1 is hard-coded as the target.
        # NOTE(review): presumably the positive class of a binary classifier;
        # confirm it always matches the displayed (last) label.
        st.selectbox(label_text, [list(label2id.keys())[-1]])
        return 1, "sigmoid"

    label_key = st.selectbox(label_text, label2id.keys())
    return label2id[label_key], "softmax"


def _speedup_figure(x_s, y_s, entropy_threshold: float) -> go.Figure:
    """Build the "speedup vs. entropy threshold" curve with a marker at ``entropy_threshold``.

    ``x_s`` must support vectorized arithmetic (e.g. a NumPy array), because the point
    nearest to the threshold is found with ``np.abs(entropy_threshold - x_s)``.
    """
    curve = go.Scatter({
        "x": x_s,
        "y": y_s,
        "name": "GPT2",
        "mode": "lines",
    })
    layout = go.Layout({
        "yaxis": {
            "title": "Speedup",
            "tickvals": [0, 0.5, 0.8, 1],
            "ticktext": ["1x", "2x", "5x", "10x"]
        },
        "xaxis": {"title": "Entropy threshold"},
        "template": "plotly_white",
    })
    nearest = np.argmin(np.abs(entropy_threshold - x_s))
    tip = go.Scatter({
        "x": [x_s[nearest]],
        "y": [y_s[nearest]],
        "mode": "markers"
    })
    # Dashed horizontal guide from the y-axis to the marker.
    guide = go.Scatter({
        "x": [0, x_s[nearest]],
        "y": [y_s[nearest]] * 2,
        "mode": "lines",
        "line": {
            "color": "grey",
            "dash": "dash"
        }
    })
    figure = go.Figure(data=[curve, tip, guide], layout=layout)
    figure.update_layout(paper_bgcolor="#FFFFFF", plot_bgcolor='#FFFFFF', showlegend=False)
    return figure


def main():
    """Render the CAIF demo page and show one generated sample.

    Generation runs on every script run. Any widget change, or a click on
    "Generate new", reruns the script and therefore draws a new sample.
    """
    st.header("CAIF")
    # Precomputed (entropy threshold, speedup) curve shipped with the app.
    # It is a local, trusted file, so unpickling it is acceptable here.
    with open("entropy_cdf.pkl", "rb") as inp:
        x_s, y_s = pickle.load(inp)

    language = st.selectbox("Language", ("English", "Russian"))
    cls_model_name = st.selectbox(
        ATTRIBUTE_MODEL_LABEL[language],
        ATTRIBUTE_MODELS[language]
    )
    lm_model_name = st.selectbox(
        LM_LABEL[language],
        LANGUAGE_MODELS[language]
    )
    target_label_id, act_type = _select_target_label(cls_model_name, language)

    st.write(WARNING_TEXT[language])
    show_pos_alpha = st.checkbox("Show positive alphas", value=False)

    # Model-specific example prompts, falling back to a generic one.
    if "sst" in cls_model_name:
        default_prompt = "The movie"
    elif "rubert-tiny-toxicity" in cls_model_name:
        default_prompt = "Ну ты и"
    else:
        default_prompt = PROMPT_EXAMPLE[language]
    prompt = st.text_input(TEXT_PROMPT_LABEL[language], default_prompt)

    st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
    # Softmax classifiers get a wider alpha range. Positive alphas stay hidden
    # until the user explicitly opts in with the checkbox above.
    alpha_limit = 40 if act_type == "softmax" else 10
    alpha = st.slider(
        "α",
        min_value=-alpha_limit,
        max_value=alpha_limit if show_pos_alpha else 0,
        step=1,
        value=0,
    )
    entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=10., step=.1, value=2.)
    st.plotly_chart(_speedup_figure(x_s, y_s, entropy_threshold), use_container_width=True)

    fp16 = st.checkbox("FP16", value=True)
    st.subheader("Generated text:")
    # Previously `on_click=generate()` passed None as the callback and the
    # generated text was discarded. Generate explicitly and show the result.
    generated_text = inference(
        lm_model_name=lm_model_name,
        cls_model_name=cls_model_name,
        prompt=prompt,
        alpha=alpha,
        target_label_id=target_label_id,
        entropy_threshold=entropy_threshold,
        fp16=fp16,
        act_type=act_type
    )
    st.session_state["generated_text"] = generated_text
    st.text(generated_text)
    # Clicking the button reruns the script, which draws a fresh sample.
    st.button("Generate new")
# `to_str()` must be called: hashing the bound method `to_str` does not hash the
# tokenizer's serialized contents.
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str())}, allow_output_mutation=True)
def load_generator(lm_model_name: str) -> Generator:
    """Create the ``Generator`` for ``lm_model_name`` on ``device``, memoized by ``st.cache``.

    ``allow_output_mutation=True`` stops Streamlit from checking the cached object
    for mutation. The returned ``Generator`` is mutated by callers, for example via
    ``set_caif_sampler`` and ``set_ordinary_sampler``.
    """
    with st.spinner('Loading language model...'):
        return Generator(lm_model_name=lm_model_name, device=device)
# NOTE: this loader is deliberately not wrapped in `st.cache` (an `st.cache` decorator
# was previously commented out here). A new CAIFSampler is therefore built on every call.
def load_sampler(cls_model_name, lm_tokenizer):
    """Build a ``CAIFSampler`` for classifier ``cls_model_name`` on ``device``.

    ``lm_tokenizer`` is the language model's tokenizer and is passed straight to the
    sampler. A spinner is shown while the sampler is constructed.
    """
    with st.spinner('Loading classifier model...'):
        return CAIFSampler(
            classifier_name=cls_model_name,
            lm_tokenizer=lm_tokenizer,
            device=device,
        )
def inference(
    lm_model_name: str,
    cls_model_name: str,
    prompt: str,
    fp16: bool = True,
    alpha: float = 5,
    target_label_id: int = 0,
    entropy_threshold: float = 0,
    act_type: str = "sigmoid",
    max_length: int = 20,
) -> str:
    """Generate one continuation of ``prompt``, optionally steered by a classifier (CAIF).

    Args:
        lm_model_name: Hugging Face id of the language model.
        cls_model_name: Hugging Face id of the attribute classifier. It is only loaded
            when ``alpha != 0``.
        prompt: Text to continue.
        fp16: Whether autocast (mixed precision) is enabled. On CPU the autocast
            default dtype is bfloat16, not float16.
        alpha: Classifier weight. ``0`` disables CAIF sampling entirely.
        target_label_id: Classifier output index to steer towards.
        entropy_threshold: Values below 0.05 are replaced by ``None`` before being
            passed to the generator (presumably "no entropy gating"; TODO confirm).
        act_type: Classifier activation, ``"sigmoid"`` or ``"softmax"``.
        max_length: Passed to ``Generator.sample_sequences`` as ``max_length``
            (defaults to the previously hard-coded 20).

    Returns:
        The first (and only, since ``num_samples=1``) generated sequence.
    """
    # Globally disables autograd for the current thread; generation needs no gradients.
    torch.set_grad_enabled(False)
    generator = load_generator(lm_model_name=lm_model_name)
    if alpha != 0:
        # The LM tokenizer is only needed to build the classifier-guided sampler,
        # so it is no longer loaded on the plain-sampling path.
        lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
        caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
        if entropy_threshold < 0.05:
            entropy_threshold = None
    else:
        caif_sampler = None
        entropy_threshold = None
    generator.set_caif_sampler(caif_sampler)
    generator.set_ordinary_sampler(TopKWithTemperatureSampler())
    sampler_kwargs = {
        "top_k": 20,
        "temperature": 1.0,
        "top_k_classifier": 100,
        "classifier_weight": alpha,
        "target_cls_id": target_label_id,
        "act_type": act_type
    }
    # torch.autocast is the non-deprecated equivalent of torch.cpu.amp.autocast /
    # torch.cuda.amp.autocast, with the same per-device default dtype.
    with torch.autocast(device_type=device, enabled=fp16):
        print(f"Generating for prompt: {prompt}")
        progress_bar = st.progress(0)
        sequences, _tokens = generator.sample_sequences(
            num_samples=1,
            input_prompt=prompt,
            max_length=max_length,
            caif_period=1,
            entropy=entropy_threshold,
            progress_bar=progress_bar,
            **sampler_kwargs
        )
        print(f"Output for prompt: {sequences}")
    return sequences[0]
# Entry point when the script is executed directly (e.g. with `streamlit run app.py`).
if __name__ == "__main__":
    main()