# caif/app.py — commit c3ce809 ("add languages"), author: Балаганский Никита Николаевич
import os
from typing import Tuple
import streamlit as st
import torch
import transformers
from transformers import AutoConfig
import tokenizers
from sampling import CAIFSampler, TopKWithTemperatureSampler
from generator import Generator
# Run models on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Attribute (classifier) models offered in the UI, keyed by UI language.
ATTRIBUTE_MODELS = {
    "Russian": (
        "cointegrated/rubert-tiny-toxicity",
        'tinkoff-ai/response-quality-classifier-tiny',
        'tinkoff-ai/response-quality-classifier-base',
        'tinkoff-ai/response-quality-classifier-large',
        "SkolkovoInstitute/roberta_toxicity_classifier",
        "SkolkovoInstitute/russian_toxicity_classifier",
    ),
    "English": (
        "unitary/toxic-bert",
    ),
}

# Language models offered in the UI, keyed by UI language.
# Keys must match the options of the "Language" selectbox in ``main`` exactly.
# Values must be tuples: a bare string would make ``st.selectbox`` list its
# individual characters as options.
LANGUAGE_MODELS = {
    "Russian": ('sberbank-ai/rugpt3small_based_on_gpt2',),
    "English": ("distilgpt2",),
}

# Localized UI strings, keyed by UI language.
ATTRIBUTE_MODEL_LABEL = {
    "Russian": 'Выберите модель классификации',
    "English": "Choose attribute model",
}
LM_LABEL = {
    "English": "Choose language model",
    "Russian": "Выберите языковую модель",
}
ATTRIBUTE_LABEL = {
    "Russian": "Выберите нужный атрибут текста",
    "English": "Choose desired attribute",
}
TEXT_PROMPT_LABEL = {
    "English": "Text prompt",
    "Russian": "Начало текста",
}
# Default prompt pre-filled in the text input.
PROMPT_EXAMPLE = {
    "English": "Hello, today I",
    "Russian": "Привет, сегодня я",
}
def main():
    """Render the CAIF demo page and display the text generated for the chosen settings.

    The user picks a UI language (which selects the model lists and labels),
    an attribute classifier, a language model, a target attribute label, a text
    prompt, ``alpha`` (forwarded to ``inference`` as the classifier weight) and
    an entropy threshold. The result of ``inference`` is then rendered as Markdown.
    """
    st.header("CAIF")
    language = st.selectbox("Language", ("English", "Russian"))
    cls_model_name = st.selectbox(
        ATTRIBUTE_MODEL_LABEL[language],
        ATTRIBUTE_MODELS[language],
    )
    lm_model_name = st.selectbox(
        LM_LABEL[language],
        LANGUAGE_MODELS[language],
    )

    cls_model_config = AutoConfig.from_pretrained(cls_model_name)
    label2id = cls_model_config.label2id
    if cls_model_config.problem_type == "multi_label_classification":
        # Any label may be targeted; its classifier id becomes the target.
        label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
        target_label_id = label2id[label_key]
    else:
        # Only the last label is offered and the target id is fixed at 0.
        # NOTE(review): the displayed label (last key) and the target id (0)
        # differ — confirm this matches how CAIFSampler scores such models.
        st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id.keys())[-1]])
        target_label_id = 0

    prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
    alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
    entropy_threshold = st.slider(
        "Entropy threshold", min_value=0., max_value=5., step=.1, value=0.
    )
    with st.spinner('Running inference...'):
        text = inference(
            lm_model_name=lm_model_name,
            cls_model_name=cls_model_name,
            prompt=prompt,
            alpha=alpha,
            # Previously omitted, so the user's attribute choice was ignored.
            target_label_id=target_label_id,
            entropy_threshold=entropy_threshold,
        )
    st.subheader("Generated text:")
    st.markdown(text)
@st.cache(hash_funcs={str: lambda lm_model_name: hash(lm_model_name)}, allow_output_mutation=True)
def load_generator(lm_model_name: str) -> Generator:
    """Construct a ``Generator`` for ``lm_model_name`` on the module-level ``device``.

    Memoized with ``st.cache``. ``allow_output_mutation=True`` lets callers
    mutate the cached instance without Streamlit warning about it; ``inference``
    does this when it installs samplers.
    """
    with st.spinner('Loading language model...'):
        return Generator(lm_model_name=lm_model_name, device=device)
def load_sampler(cls_model_name, lm_tokenizer):
    """Construct a ``CAIFSampler`` for classifier ``cls_model_name`` on ``device``.

    ``lm_tokenizer`` is the language model's tokenizer and is forwarded to the
    sampler unchanged.

    NOTE(review): unlike ``load_generator`` this function is not cached, so each
    call builds a new sampler.
    """
    with st.spinner('Loading classifier model...'):
        return CAIFSampler(
            classifier_name=cls_model_name,
            lm_tokenizer=lm_tokenizer,
            device=device,
        )
@st.cache
def inference(
    lm_model_name: str,
    cls_model_name: str,
    prompt: str,
    fp16: bool = True,
    alpha: float = 5,
    target_label_id: int = 0,
    entropy_threshold: float = 0
) -> str:
    """Generate a continuation of ``prompt``, optionally guided by a classifier.

    Args:
        lm_model_name: Pretrained language model id, passed to ``load_generator``.
        cls_model_name: Pretrained classifier id. It is only loaded when
            ``alpha`` is non-zero.
        prompt: Text the generation starts from.
        fp16: Whether to run sampling under ``torch.autocast`` (reduced precision).
        alpha: Classifier weight. ``0`` disables classifier guidance entirely.
        target_label_id: Classifier label id passed to the sampler as the target.
        entropy_threshold: Forwarded to the generator as ``entropy``. Values
            below 0.05, and any value when ``alpha == 0``, are replaced by ``None``.

    Returns:
        The single generated sequence.
    """
    generator = load_generator(lm_model_name=lm_model_name)
    if alpha != 0:
        # The LM tokenizer is only needed to build the CAIF sampler.
        lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
        caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
        if entropy_threshold < 0.05:
            entropy_threshold = None
    else:
        caif_sampler = None
        entropy_threshold = None
    # The generator instance is shared via st.cache, so samplers are
    # reinstalled on every call.
    generator.set_caif_sampler(caif_sampler)
    generator.set_ordinary_sampler(TopKWithTemperatureSampler())

    kwargs = {
        "top_k": 20,
        "temperature": 1.0,
        "top_k_classifier": 100,
        "classifier_weight": alpha,
        "target_cls_id": target_label_id,
    }
    # torch.autocast replaces the deprecated torch.cpu.amp / torch.cuda.amp
    # wrappers. With default dtypes it behaves the same (bfloat16 on CPU,
    # float16 on CUDA).
    with torch.autocast(device_type=device, enabled=fp16):
        print(f"Generating for prompt: {prompt}")
        sequences, tokens = generator.sample_sequences(
            num_samples=1,
            input_prompt=prompt,
            max_length=20,
            caif_period=1,
            entropy=entropy_threshold,
            **kwargs
        )
        print(f"Output for prompt: {sequences}")
    return sequences[0]
# Script entry point, e.g. `streamlit run app.py`.
if __name__ == "__main__":
    main()