|
import os |
|
from typing import Tuple |
|
|
|
import streamlit as st |
|
|
|
import torch |
|
|
|
import transformers |
|
|
|
from transformers import AutoConfig |
|
import tokenizers |
|
|
|
from sampling import CAIFSampler, TopKWithTemperatureSampler |
|
from generator import Generator |
|
|
|
import pickle |
|
|
|
from plotly import graph_objects as go |
|
|
|
import numpy as np |
|
|
|
# Torch device string shared by the model loaders and the autocast choice:
# use CUDA whenever torch can see a GPU, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
ATTRIBUTE_MODELS = { |
|
"Russian": ( |
|
"cointegrated/rubert-tiny-toxicity", |
|
'tinkoff-ai/response-quality-classifier-tiny', |
|
'tinkoff-ai/response-quality-classifier-base', |
|
'tinkoff-ai/response-quality-classifier-large', |
|
"SkolkovoInstitute/roberta_toxicity_classifier", |
|
"SkolkovoInstitute/russian_toxicity_classifier" |
|
), |
|
"English": ( |
|
"unitary/toxic-bert", |
|
) |
|
} |
|
|
|
LANGUAGE_MODELS = { |
|
"Russian": ( |
|
'sberbank-ai/rugpt3small_based_on_gpt2', |
|
"sberbank-ai/rugpt3large_based_on_gpt2" |
|
), |
|
"English": ("gpt2", "distilgpt2", "EleutherAI/gpt-neo-1.3B") |
|
} |
|
|
|
# Localized UI strings, keyed by the language chosen in the "Language" selectbox.

# Label of the attribute-model selectbox.
ATTRIBUTE_MODEL_LABEL = {
    "Russian": 'Выберите модель классификации',
    "English": "Choose attribute model"
}

# Label of the language-model selectbox.
LM_LABEL = {
    "English": "Choose language model",
    "Russian": "Выберите языковую модель"
}

# Label of the target-attribute selectbox.
# Russian text fixed: "Веберите" was a typo for "Выберите" ("Choose").
ATTRIBUTE_LABEL = {
    "Russian": "Выберите нужный атрибут текста",
    "English": "Choose desired attribute",
}

# Label of the prompt text input.
TEXT_PROMPT_LABEL = {
    "English": "Text prompt",
    "Russian": "Начало текста"
}

# Default prompt pre-filled in the text input.
PROMPT_EXAMPLE = {
    "English": "Hello, today I want to show you a new method",
    "Russian": "Привет, сегодня я"
}
|
|
|
|
|
def main():
    """Render the CAIF demo page and show text generated with the chosen settings.

    Widgets appear in this order:

    1. Language.
    2. Attribute (classifier) model.
    3. Language model.
    4. Target attribute.
    5. Prompt.
    6. The ``alpha`` slider.
    7. The entropy-threshold slider.
    8. The speedup-vs-threshold chart.
    9. The FP16 checkbox.

    The selections are passed to ``inference`` and its result is written
    under "Generated text:".
    """
    st.header("CAIF")

    # (entropy threshold, speedup) curve read from the working directory.
    # SECURITY: pickle.load executes arbitrary code from the file -- this must
    # only ever point at the trusted local artifact, never at user input.
    with open("entropy_cdf.pkl", "rb") as inp:
        x_s, y_s = pickle.load(inp)
    # The nearest-point lookup needs element-wise subtraction, so accept
    # either a list or an array here.
    x_s = np.asarray(x_s)

    language = st.selectbox("Language", ("English", "Russian"))
    cls_model_name = st.selectbox(
        ATTRIBUTE_MODEL_LABEL[language],
        ATTRIBUTE_MODELS[language],
    )
    lm_model_name = st.selectbox(
        LM_LABEL[language],
        LANGUAGE_MODELS[language],
    )
    target_label_id = _select_target_label_id(cls_model_name, language)

    prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
    st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i})p(c|x_{\leq i})^{\alpha}")
    alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
    entropy_threshold = st.slider("Entropy threshold", min_value=0., max_value=10., step=.1, value=2.)

    st.plotly_chart(_speedup_figure(x_s, y_s, entropy_threshold), use_container_width=True)

    fp16 = st.checkbox("FP16", value=True)

    with st.spinner('Running inference...'):
        text = inference(
            lm_model_name=lm_model_name,
            cls_model_name=cls_model_name,
            prompt=prompt,
            alpha=alpha,
            target_label_id=target_label_id,
            entropy_threshold=entropy_threshold,
            fp16=fp16,
        )

    st.subheader("Generated text:")
    st.write(text)


def _select_target_label_id(cls_model_name: str, language: str) -> int:
    """Let the user pick the desired attribute and return its classifier class id.

    Multi-label classifiers offer every label in their config. Any other
    classifier offers only its last configured label.
    """
    cls_model_config = AutoConfig.from_pretrained(cls_model_name)
    label2id = cls_model_config.label2id
    labels = list(label2id)
    if cls_model_config.problem_type != "multi_label_classification":
        labels = labels[-1:]
    label_key = st.selectbox(ATTRIBUTE_LABEL[language], labels)
    # Look up the id of the label actually displayed instead of hard-coding 1,
    # so the class we steer towards always matches what the user sees.
    return label2id[label_key]


def _speedup_figure(x_s, y_s, entropy_threshold: float) -> go.Figure:
    """Plot the speedup curve, marking the point whose x is closest to *entropy_threshold*."""
    scatter = go.Scatter({
        "x": x_s,
        "y": y_s,
        "name": "GPT2",
        "mode": "lines",
    })
    # Tick labels relabel the y values as speedup factors.
    layout = go.Layout({
        "yaxis": {
            "title": "Speedup",
            "tickvals": [0, 0.5, 0.8, 1],
            "ticktext": ["1x", "2x", "5x", "10x"]
        },
        "xaxis": {"title": "Entropy threshold"},
        "template": "plotly_white",
    })

    plot_idx = np.argmin(np.abs(entropy_threshold - x_s))
    scatter_tip = go.Scatter({
        "x": [x_s[plot_idx]],
        "y": [y_s[plot_idx]],
        "mode": "markers"
    })
    # Dashed horizontal guide from the y axis to the selected point.
    scatter_tip_lines = go.Scatter({
        "x": [0, x_s[plot_idx]],
        "y": [y_s[plot_idx]] * 2,
        "mode": "lines",
        "line": {
            "color": "grey",
            "dash": "dash"
        }
    })

    figure = go.Figure(data=[scatter, scatter_tip, scatter_tip_lines], layout=layout)
    figure.update_layout(paper_bgcolor="#FFFFFF", plot_bgcolor='#FFFFFF', showlegend=False)
    return figure
|
|
|
# NOTE(review): `to_str` is referenced but not called, so this hashes the bound
# method object rather than the serialized tokenizer -- presumably an
# identity-style hash chosen to avoid serializing the tokenizer; confirm intent.
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda lm_tokenizer: hash(lm_tokenizer.to_str)}, allow_output_mutation=True)
def load_generator(lm_model_name: str) -> Generator:
    """Return a Generator for *lm_model_name* on the module-level ``device``.

    Cached by Streamlit, keyed on the model name. The returned object is
    allowed to be mutated by callers (``allow_output_mutation=True``).
    A spinner is shown while the model is being constructed.
    """
    with st.spinner('Loading language model...'):
        return Generator(lm_model_name=lm_model_name, device=device)
|
|
|
|
|
def load_sampler(cls_model_name, lm_tokenizer):
    """Build a CAIFSampler for the classifier *cls_model_name* on ``device``.

    *lm_tokenizer* is the language model's tokenizer, forwarded to the
    sampler unchanged. This helper is not cached, so every call constructs
    a fresh sampler. A spinner is shown while it loads.
    """
    with st.spinner('Loading classifier model...'):
        return CAIFSampler(
            classifier_name=cls_model_name,
            lm_tokenizer=lm_tokenizer,
            device=device,
        )
|
|
|
|
|
@st.cache
def inference(
    lm_model_name: str,
    cls_model_name: str,
    prompt: str,
    fp16: bool = True,
    alpha: float = 5,
    target_label_id: int = 0,
    entropy_threshold: float = 0
) -> str:
    """Generate one continuation of *prompt*, optionally steered by a classifier (CAIF).

    Args:
        lm_model_name: Language-model name passed to ``from_pretrained``.
        cls_model_name: Classifier name. It is only loaded when ``alpha != 0``.
        prompt: Text to continue.
        fp16: Enable autocast mixed precision during sampling. By torch's
            defaults this is float16 on CUDA and bfloat16 on CPU.
        alpha: Classifier weight. ``0`` disables CAIF sampling, and then
            the entropy threshold is also ignored.
        target_label_id: Classifier class id to steer towards.
        entropy_threshold: Values below 0.05 are treated as "no threshold"
            (passed on as ``None``).

    Returns:
        The first generated sequence (up to 20 tokens, sampled with
        top-k = 20 and temperature 1.0).
    """
    # Pure generation: gradients are never needed.
    torch.set_grad_enabled(False)

    generator = load_generator(lm_model_name=lm_model_name)

    if alpha != 0:
        # The LM tokenizer is only consumed by the CAIF sampler, so it is
        # loaded here rather than unconditionally.
        lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
        caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
        if entropy_threshold < 0.05:
            entropy_threshold = None
    else:
        caif_sampler = None
        entropy_threshold = None

    generator.set_caif_sampler(caif_sampler)
    generator.set_ordinary_sampler(TopKWithTemperatureSampler())

    kwargs = {
        "top_k": 20,
        "temperature": 1.0,
        "top_k_classifier": 100,
        "classifier_weight": alpha,
        "target_cls_id": target_label_id
    }

    # torch.autocast replaces the deprecated torch.cpu.amp / torch.cuda.amp
    # wrappers; its default dtype per device matches theirs.
    with torch.autocast(device_type=device, enabled=fp16):
        print(f"Generating for prompt: {prompt}")
        sequences, _tokens = generator.sample_sequences(
            num_samples=1,
            input_prompt=prompt,
            max_length=20,
            caif_period=1,
            entropy=entropy_threshold,
            **kwargs
        )
        print(f"Output for prompt: {sequences}")

    return sequences[0]
|
|
|
|
|
# Build the page only when this file is executed as the entry script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|