import os |
from typing import Tuple |
import streamlit as st |
import torch |
import transformers |
from transformers import AutoConfig |
import tokenizers |
from sampling import CAIFSampler, TopKWithTemperatureSampler |
from generator import Generator |
import pickle |
from plotly import graph_objects as go |
import numpy as np |
device = "cuda" if torch.cuda.is_available() else "cpu"

# Attribute classifiers offered in the UI, keyed by interface language.
ATTRIBUTE_MODELS = {
    "Russian": (
        "cointegrated/rubert-tiny-toxicity",
        "SkolkovoInstitute/russian_toxicity_classifier",
    ),
    "English": (
        "unitary/toxic-bert",
        "distilbert-base-uncased-finetuned-sst-2-english",
        "cardiffnlp/twitter-roberta-base-sentiment-latest",
    ),
}

# Causal language models offered in the UI, keyed by interface language.
LANGUAGE_MODELS = {
    "Russian": (
        "sberbank-ai/rugpt3small_based_on_gpt2",
        "sberbank-ai/rugpt3large_based_on_gpt2",
    ),
    "English": ("gpt2", "distilgpt2", "EleutherAI/gpt-neo-1.3B"),
}

# Localized UI strings, keyed by interface language.
ATTRIBUTE_MODEL_LABEL = {
    "Russian": "Выберите модель классификации",
    "English": "Choose attribute model",
}

LM_LABEL = {
    "English": "Choose language model",
    "Russian": "Выберите языковую модель",
}

ATTRIBUTE_LABEL = {
    "Russian": "Выберите нужный атрибут текста",
    "English": "Choose desired attribute",
}

TEXT_PROMPT_LABEL = {
    "English": "Text prompt",
    "Russian": "Начало текста",
}

PROMPT_EXAMPLE = {
    "English": "Hello there",
    "Russian": "Привет всем",
}

# Markdown shown right above the "Show positive alphas" checkbox.
WARNING_TEXT = {
    "English": """
**Warning!**
If you click the checkbox below, positive """ + r"$\alpha$" + """ values for CAIF sampling become available.
It means that the language model will be forced to produce toxic and/or abusive text.
This space is only a demonstration of our method for controllable text generation,
and we are not responsible for the content produced by this method.
**Please use it carefully and with positive intentions!**
""",
    "Russian": """
**Внимание!**
После нажатия на чекбокс ниже положительные """ + r"$\alpha$" + """ станут доступны.
Это означает, что языковая модель будет генерировать токсичные тексты.
Это демо служит лишь демонстрацией нашего метода контролируемой генерации.
Мы не несем ответственности за полученные тексты.
**Используйте этот метод осторожно и с положительными намерениями!**
""",
}


def main():
    """Render the CAIF demo page.

    The page works as follows:
    - The user chooses the interface language, an attribute classifier, a
      language model, the target attribute, the CAIF weight ``alpha`` and the
      entropy threshold.
    - It plots the precomputed entropy-threshold/speedup curve.
    - It calls ``inference`` only when the "Generate new" button is pressed.
    """
    st.header("CAIF")

    # Precomputed (entropy threshold -> speedup) curve bundled with the app.
    # NOTE: pickle.load can execute arbitrary code; only ever point this at
    # the app's own trusted file.
    with open("entropy_cdf.pkl", "rb") as inp:
        x_s, y_s = pickle.load(inp)
    # Arrays keep the nearest-threshold lookup below working even if the
    # pickle holds plain lists.
    x_s = np.asarray(x_s)
    y_s = np.asarray(y_s)
    scatter = go.Scatter({
        "x": x_s,
        "y": y_s,
        "name": "GPT2",
        "mode": "lines",
    })
    layout = go.Layout({
        "yaxis": {
            "title": "Speedup",
            "tickvals": [0, 0.5, 0.8, 1],
            "ticktext": ["1x", "2x", "5x", "10x"],
        },
        "xaxis": {"title": "Entropy threshold"},
        "template": "plotly_white",
    })

    language = st.selectbox("Language", ("English", "Russian"))
    cls_model_name = st.selectbox(
        ATTRIBUTE_MODEL_LABEL[language],
        ATTRIBUTE_MODELS[language],
    )
    lm_model_name = st.selectbox(
        LM_LABEL[language],
        LANGUAGE_MODELS[language],
    )

    # The classifier config decides which labels can be targeted and which
    # activation the sampler applies to the classifier output.
    cls_model_config = AutoConfig.from_pretrained(cls_model_name)
    label2id = cls_model_config.label2id
    label_keys = list(label2id.keys())
    if cls_model_config.problem_type == "multi_label_classification":
        default_idx = 0
        # Preselect the "threat" attribute for the tiny Russian toxicity model.
        if "rubert-tiny-toxicity" in cls_model_name and "threat" in label_keys:
            default_idx = label_keys.index("threat")
        label_key = st.selectbox(ATTRIBUTE_LABEL[language], label_keys, index=default_idx)
        target_label_id = label2id[label_key]
        act_type = "sigmoid"
    elif cls_model_config.problem_type == "single_label_classification":
        # Only the last label is offered, and class id 1 is always targeted.
        # NOTE(review): assumes the last label maps to id 1 — confirm before
        # adding other single-label classifiers.
        st.selectbox(ATTRIBUTE_LABEL[language], [label_keys[-1]])
        target_label_id = 1
        act_type = "sigmoid"
    else:
        label_key = st.selectbox(ATTRIBUTE_LABEL[language], label_keys)
        target_label_id = label2id[label_key]
        act_type = "softmax"

    st.write(WARNING_TEXT[language])
    show_pos_alpha = st.checkbox("Show positive alphas", value=False)

    # Model-specific default prompts; otherwise the generic per-language example.
    if "sst" in cls_model_name:
        default_prompt = "The movie"
    elif "rubert-tiny-toxicity" in cls_model_name:
        default_prompt = "Ну ты и"
    else:
        default_prompt = PROMPT_EXAMPLE[language]
    prompt = st.text_input(TEXT_PROMPT_LABEL[language], default_prompt)

    st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
    # Softmax classifiers get a wider alpha range. Positive values (which raise
    # the weight of p(c|x) in the formula above) are only exposed after the
    # user opts in via the checkbox.
    alpha_bound = 40 if act_type == "softmax" else 10
    alpha = st.slider(
        "α",
        min_value=-alpha_bound,
        max_value=alpha_bound if show_pos_alpha else 0,
        step=1,
        value=0,
    )
    entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=10., step=.1, value=2.)

    # Highlight the curve point closest to the chosen threshold, with a dashed
    # guide line to the y axis.
    plot_idx = np.argmin(np.abs(entropy_threshold - x_s))
    scatter_tip = go.Scatter({
        "x": [x_s[plot_idx]],
        "y": [y_s[plot_idx]],
        "mode": "markers",
    })
    scatter_tip_lines = go.Scatter({
        "x": [0, x_s[plot_idx]],
        "y": [y_s[plot_idx]] * 2,
        "mode": "lines",
        "line": {
            "color": "grey",
            "dash": "dash",
        },
    })
    figure = go.Figure(data=[scatter, scatter_tip, scatter_tip_lines], layout=layout)
    figure.update_layout(paper_bgcolor="#FFFFFF", plot_bgcolor='#FFFFFF', showlegend=False)
    st.plotly_chart(figure, use_container_width=True)

    fp16 = st.checkbox("FP16", value=True)
    st.subheader("Generated text:")
    # Generation is expensive, so it runs only on the rerun triggered by the
    # button. Passing ``on_click=generate()`` would instead call it on every
    # rerun and register None as the callback. The result is kept in session
    # state so it stays visible while other widgets are adjusted.
    if st.button("Generate new"):
        st.session_state["generated_text"] = inference(
            lm_model_name=lm_model_name,
            cls_model_name=cls_model_name,
            prompt=prompt,
            alpha=alpha,
            target_label_id=target_label_id,
            entropy_threshold=entropy_threshold,
            fp16=fp16,
            act_type=act_type,
        )
    generated_text = st.session_state.get("generated_text")
    if generated_text is not None:
        st.write(generated_text)
@st.cache(
    # Hash tokenizers by their serialized content. The previous
    # ``hash(lm_tokenizer.to_str)`` hashed the bound method object itself
    # rather than the tokenizer's content.
    hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str())},
    allow_output_mutation=True,
)
def load_generator(lm_model_name: str) -> Generator:
    """Build the ``Generator`` for ``lm_model_name`` on ``device``, cached by ``st.cache``.

    ``allow_output_mutation=True`` is needed because callers mutate the
    returned object: ``inference`` calls ``set_caif_sampler`` and
    ``set_ordinary_sampler`` on it.
    """
    with st.spinner('Loading language model...'):
        return Generator(lm_model_name=lm_model_name, device=device)
@st.cache(
    # ``inference`` builds a fresh tokenizer object on every call, so the
    # tokenizer must be hashed by its serialized content for cache hits.
    hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str())},
    allow_output_mutation=True,
)
def load_sampler(cls_model_name: str, lm_tokenizer) -> CAIFSampler:
    """Build a ``CAIFSampler`` for ``cls_model_name`` paired with ``lm_tokenizer``.

    The result is cached the same way as ``load_generator``, so the classifier is
    not reloaded on every generation. ``allow_output_mutation=True`` lets the
    cached sampler be attached to the (mutable) generator without tripping
    st.cache's mutation check.
    """
    with st.spinner('Loading classifier model...'):
        return CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
@torch.no_grad()
def inference(
    lm_model_name: str,
    cls_model_name: str,
    prompt: str,
    fp16: bool = True,
    alpha: float = 5,
    target_label_id: int = 0,
    entropy_threshold: float = 0,
    act_type: str = "sigmoid"
) -> str:
    """Generate one continuation of ``prompt``, optionally steered by CAIF.

    Gradient tracking is disabled only for the duration of this call via the
    ``torch.no_grad()`` decorator, rather than globally.

    Args:
        lm_model_name: Language model id, passed to ``load_generator`` and
            ``transformers.AutoTokenizer.from_pretrained``.
        cls_model_name: Attribute classifier id; only loaded when ``alpha != 0``.
        prompt: Text to continue.
        fp16: Enable ``torch.autocast`` mixed precision on ``device`` (using
            torch's default autocast dtype for that device type).
        alpha: Classifier weight (sent as ``classifier_weight``); ``0`` means
            no CAIF sampler is set.
        target_label_id: Classifier output index to steer towards.
        entropy_threshold: Forwarded as ``entropy`` to ``sample_sequences``;
            replaced by ``None`` when ``alpha == 0`` or when it is below 0.05.
        act_type: ``"sigmoid"`` or ``"softmax"``, forwarded to the sampler.

    Returns:
        The first generated sequence (exactly one sample is requested).
    """
    generator = load_generator(lm_model_name=lm_model_name)
    lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
    if alpha != 0:
        caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
        # Near-zero thresholds are passed as None.
        # NOTE(review): presumably None means "no entropy gating" in
        # Generator.sample_sequences — confirm.
        if entropy_threshold < 0.05:
            entropy_threshold = None
    else:
        # No CAIF sampler: generation relies on the ordinary sampler only.
        caif_sampler = None
        entropy_threshold = None
    generator.set_caif_sampler(caif_sampler)
    generator.set_ordinary_sampler(TopKWithTemperatureSampler())

    kwargs = {
        "top_k": 20,
        "temperature": 1.0,
        "top_k_classifier": 100,
        "classifier_weight": alpha,
        "target_cls_id": target_label_id,
        "act_type": act_type,
    }

    print(f"Generating for prompt: {prompt}")
    progress_bar = st.progress(0)
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast /
    # torch.cpu.amp.autocast and takes ``enabled`` explicitly.
    with torch.autocast(device_type=device, enabled=fp16):
        sequences, _tokens = generator.sample_sequences(
            num_samples=1,
            input_prompt=prompt,
            max_length=20,
            caif_period=1,
            entropy=entropy_threshold,
            progress_bar=progress_bar,
            **kwargs
        )
    print(f"Output for prompt: {sequences}")
    return sequences[0]
if __name__ == '__main__':
    # Run the app when this module is executed as the main script
    # (e.g. via `streamlit run`).
    main()