import os
import pickle

import numpy as np
import streamlit as st
import tokenizers
import torch
import transformers
from plotly import graph_objects as go
from transformers import AutoConfig

from generator import Generator
from sampling import CAIFSampler, TopKWithTemperatureSampler

device = "cuda" if torch.cuda.is_available() else "cpu"

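# Model registries: per-UI-language tuples of attribute (classifier) models
# and base language models, referenced by their Hugging Face Hub names.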
ATTRIBUTE_MODELS = {
    "Russian": (
        "cointegrated/rubert-tiny-toxicity",
        "SkolkovoInstitute/russian_toxicity_classifier",
    ),
    "English": (
        "unitary/toxic-bert",
        "distilbert-base-uncased-finetuned-sst-2-english",
        "cardiffnlp/twitter-roberta-base-sentiment-latest",
    ),
}

LANGUAGE_MODELS = {
    "Russian": (
        "sberbank-ai/rugpt3small_based_on_gpt2",
        "sberbank-ai/rugpt3large_based_on_gpt2",
    ),
    "English": ("gpt2", "distilgpt2", "EleutherAI/gpt-neo-1.3B"),
}

ATTRIBUTE_MODEL_LABEL = {
    "Russian": "Выберите модель классификации",
    "English": "Choose attribute model",
}

LM_LABEL = {
    "English": "Choose language model",
    "Russian": "Выберите языковую модель",
}

ATTRIBUTE_LABEL = {
    "Russian": "Выберите нужный атрибут текста",
    "English": "Choose desired attribute",
}

TEXT_PROMPT_LABEL = {
    "English": "Text prompt",
    "Russian": "Начало текста",
}

PROMPT_EXAMPLE = {
    "English": "Hello there",
    "Russian": "Привет всем",
}

WARNING_TEXT = {
    "English": """
**Warning!**

If you check the box below, positive """ + r"$\alpha$" + """ values become available for CAIF sampling.
This forces the language model to produce toxic and/or abusive text.
This space is only a demonstration of our method for controllable text generation,
and we are not responsible for the content it produces.

**Please use it carefully and with positive intentions!**
""",
    "Russian": """
**Внимание!**

После нажатия на чекбокс ниже положительные """ + r"$\alpha$" + """ станут доступны.
Это означает, что языковая модель будет генерировать токсичные тексты.
Это демо служит лишь демонстрацией нашего метода контролируемой генерации.
Мы не несем ответственности за полученные тексты.

**Используйте этот метод осторожно и с положительными намерениями!**
""",
}

def main():
    st.header("CAIF")
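    # entropy_cdf.pkl holds a precomputed curve for GPT-2: x_s are entropy
    # thresholds, y_s the CDF of token entropies; the y-axis ticks below
    # relabel CDF values as the approximate sampling speedup they imply.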
    with open("entropy_cdf.pkl", "rb") as inp:
        x_s, y_s = pickle.load(inp)
    scatter = go.Scatter({
        "x": x_s,
        "y": y_s,
        "name": "GPT2",
        "mode": "lines",
    })
    layout = go.Layout({
        "yaxis": {
            "title": "Speedup",
            "tickvals": [0, 0.5, 0.8, 1],
            "ticktext": ["1x", "2x", "5x", "10x"],
        },
        "xaxis": {"title": "Entropy threshold"},
        "template": "plotly_white",
    })

    language = st.selectbox("Language", ("English", "Russian"))
    cls_model_name = st.selectbox(
        ATTRIBUTE_MODEL_LABEL[language],
        ATTRIBUTE_MODELS[language],
    )
    lm_model_name = st.selectbox(
        LM_LABEL[language],
        LANGUAGE_MODELS[language],
    )

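    # Inspect the classifier config to pick the target label and decide
    # whether its scores come from a sigmoid head (multi-label / binary)
    # or a softmax head (multi-class).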
    cls_model_config = AutoConfig.from_pretrained(cls_model_name)
    if cls_model_config.problem_type == "multi_label_classification":
        label2id = cls_model_config.label2id
        if "rubert-tiny-toxicity" in cls_model_name:
            # Default the selectbox to the "threat" label for this model.
            idx = 0
            for i, k in enumerate(label2id.keys()):
                if k == "threat":
                    idx = i
            label_key = st.selectbox(ATTRIBUTE_LABEL[language], list(label2id.keys()), index=idx)
        else:
            label_key = st.selectbox(ATTRIBUTE_LABEL[language], list(label2id.keys()))
        target_label_id = label2id[label_key]
        act_type = "sigmoid"
    elif cls_model_config.problem_type == "single_label_classification":
        # Binary classifier: only the positive class is offered.
        label2id = cls_model_config.label2id
        label_key = st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
        target_label_id = 1
        act_type = "sigmoid"
    else:
        label2id = cls_model_config.label2id
        label_key = st.selectbox(ATTRIBUTE_LABEL[language], list(label2id.keys()))
        target_label_id = label2id[label_key]
        act_type = "softmax"
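    # Positive alpha steers generation *toward* the attribute (e.g. toxicity),
    # so it is gated behind an explicit opt-in checkbox.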
    st.write(WARNING_TEXT[language])
    show_pos_alpha = st.checkbox("Show positive alphas", value=False)
    if "sst" in cls_model_name:
        prompt = st.text_input(TEXT_PROMPT_LABEL[language], "The movie")
    elif "rubert-tiny-toxicity" in cls_model_name:
        prompt = st.text_input(TEXT_PROMPT_LABEL[language], "Ну ты и")
    else:
        prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
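    # CAIF sampling rule: next-token probabilities are reweighted by the
    # attribute classifier raised to the power alpha.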
    st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
    if act_type == "softmax":
        alpha = st.slider("α", min_value=-40, max_value=40 if show_pos_alpha else 0, step=1, value=0)
    else:
        alpha = st.slider("α", min_value=-10, max_value=10 if show_pos_alpha else 0, step=1, value=0)
    entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=10., step=.1, value=2.)
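    # Mark the selected threshold on the precomputed speedup curve.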
    plot_idx = np.argmin(np.abs(entropy_threshold - x_s))
    scatter_tip = go.Scatter({
        "x": [x_s[plot_idx]],
        "y": [y_s[plot_idx]],
        "mode": "markers",
    })
    scatter_tip_lines = go.Scatter({
        "x": [0, x_s[plot_idx]],
        "y": [y_s[plot_idx]] * 2,
        "mode": "lines",
        "line": {
            "color": "grey",
            "dash": "dash",
        },
    })
    figure = go.Figure(data=[scatter, scatter_tip, scatter_tip_lines], layout=layout)
    figure.update_layout(paper_bgcolor="#FFFFFF", plot_bgcolor="#FFFFFF", showlegend=False)
    st.plotly_chart(figure, use_container_width=True)
    auth_token = os.environ.get("TOKEN") or True
    fp16 = st.checkbox("FP16", value=True)
    # Initialize once; resetting on every rerun would erase the result
    # written by the on_click callback below.
    if "generated_text" not in st.session_state:
        st.session_state["generated_text"] = None
    st.subheader("Generated text:")

    def generate():
        st.session_state["generated_text"] = inference(
            lm_model_name=lm_model_name,
            cls_model_name=cls_model_name,
            prompt=prompt,
            alpha=alpha,
            target_label_id=target_label_id,
            entropy_threshold=entropy_threshold,
            fp16=fp16,
            act_type=act_type,
        )

    # Pass the callback itself; `on_click=generate()` would run inference on
    # every rerun instead of on click.
    st.button("Generate new", on_click=generate)
    if st.session_state["generated_text"] is not None:
        st.write(st.session_state["generated_text"])

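
# Cache the (large) language model across Streamlit reruns. st.cache cannot
# hash tokenizers.Tokenizer by default, so a custom hash based on its
# serialized string form is supplied.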
@st.cache(
    hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str())},
    allow_output_mutation=True,
)
def load_generator(lm_model_name: str) -> Generator:
    with st.spinner("Loading language model..."):
        return Generator(lm_model_name=lm_model_name, device=device)


def load_sampler(cls_model_name: str, lm_tokenizer) -> CAIFSampler:
    with st.spinner("Loading classifier model..."):
        return CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)

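
# End-to-end generation: load the (cached) LM, optionally attach the CAIF
# sampler, and sample one sequence under autocast.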
def inference(
    lm_model_name: str,
    cls_model_name: str,
    prompt: str,
    fp16: bool = True,
    alpha: float = 5,
    target_label_id: int = 0,
    entropy_threshold: float = 0,
    act_type: str = "sigmoid",
) -> str:
    torch.set_grad_enabled(False)
    generator = load_generator(lm_model_name=lm_model_name)
    lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
    if alpha != 0:
        caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
        # Treat near-zero thresholds as None (no entropy threshold).
        if entropy_threshold < 0.05:
            entropy_threshold = None
    else:
        # alpha == 0 means no classifier guidance, so skip the sampler.
        caif_sampler = None
        entropy_threshold = None

    generator.set_caif_sampler(caif_sampler)
    ordinary_sampler = TopKWithTemperatureSampler()
    kwargs = {
        "top_k": 20,
        "temperature": 1.0,
        "top_k_classifier": 100,
        "classifier_weight": alpha,
        "target_cls_id": target_label_id,
        "act_type": act_type,
    }
    generator.set_ordinary_sampler(ordinary_sampler)
    autocast = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
    with autocast(enabled=fp16):
        print(f"Generating for prompt: {prompt}")
        progress_bar = st.progress(0)
        sequences, tokens = generator.sample_sequences(
            num_samples=1,
            input_prompt=prompt,
            max_length=20,
            caif_period=1,
            entropy=entropy_threshold,
            progress_bar=progress_bar,
            **kwargs,
        )
    print(f"Output for prompt: {sequences}")
    return sequences[0]


if __name__ == "__main__":
    main()