"""Streamlit demo page titled "CAIF".

Generates a short continuation of a user prompt with a causal language model.
When the "Alpha" slider is non-zero, sampling additionally goes through
``CAIFSampler`` (local ``sampling`` module), which receives ``alpha`` as its
classifier weight and the selected attribute label id as its target class.
"""
import os
from typing import Tuple

import streamlit as st
import torch
import transformers
from transformers import AutoConfig
import tokenizers

from sampling import CAIFSampler, TopKWithTemperatureSampler
from generator import Generator

# Device used for both the language model and the attribute classifier.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Attribute classifiers offered per UI language (Hugging Face model ids).
ATTRIBUTE_MODELS = {
    "Russian": (
        "cointegrated/rubert-tiny-toxicity",
        'tinkoff-ai/response-quality-classifier-tiny',
        'tinkoff-ai/response-quality-classifier-base',
        'tinkoff-ai/response-quality-classifier-large',
        "SkolkovoInstitute/roberta_toxicity_classifier",
        "SkolkovoInstitute/russian_toxicity_classifier",
    ),
    "English": (
        "unitary/toxic-bert",
    ),
}

# Language models offered per UI language. Keys must match the options of the
# "Language" selectbox, and values must be sequences of model ids because they
# are passed directly to ``st.selectbox`` as its options.
LANGUAGE_MODELS = {
    "Russian": ('sberbank-ai/rugpt3small_based_on_gpt2',),
    "English": ("distilgpt2",),
}

# Localized widget labels and default prompt, keyed by UI language.
ATTRIBUTE_MODEL_LABEL = {
    "Russian": 'Выберите модель классификации',
    "English": "Choose attribute model",
}

LM_LABEL = {
    "English": "Choose language model",
    "Russian": "Выберите языковую модель",
}

ATTRIBUTE_LABEL = {
    "Russian": "Выберите нужный атрибут текста",
    "English": "Choose desired attribute",
}

TEXT_PROMPT_LABEL = {
    "English": "Text prompt",
    "Russian": "Начало текста",
}

PROMPT_EXAMPLE = {
    "English": "Hello, today I",
    "Russian": "Привет, сегодня я",
}


def main():
    """Render the demo page and display one generated continuation.

    Widgets, in order: UI language, attribute classifier, language model,
    target attribute label, text prompt, "Alpha" (classifier weight) and
    "Entropy threshold". The generated text is rendered as Markdown.
    """
    st.header("CAIF")

    language = st.selectbox("Language", ("English", "Russian"))
    cls_model_name = st.selectbox(
        ATTRIBUTE_MODEL_LABEL[language],
        ATTRIBUTE_MODELS[language],
    )
    lm_model_name = st.selectbox(
        LM_LABEL[language],
        LANGUAGE_MODELS[language],
    )

    # Only the classifier's config is fetched here, to list its labels; the
    # classifier itself is loaded inside ``inference`` (and only if alpha != 0).
    cls_model_config = AutoConfig.from_pretrained(cls_model_name)
    label2id = cls_model_config.label2id
    if cls_model_config.problem_type == "multi_label_classification":
        label_key = st.selectbox(ATTRIBUTE_LABEL[language], list(label2id))
        target_label_id = label2id[label_key]
    else:
        # NOTE(review): the only label offered here is the classifier's last
        # one, yet the target id sent to the sampler is 0. Confirm how
        # CAIFSampler interprets ``target_cls_id`` for single-label classifiers.
        st.selectbox(ATTRIBUTE_LABEL[language], [list(label2id)[-1]])
        target_label_id = 0

    prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])
    alpha = st.slider("Alpha", min_value=-10, max_value=10, step=1, value=0)
    entropy_threshold = st.slider(
        "Entropy threshold", min_value=0., max_value=5., step=.1, value=0.
    )

    with st.spinner('Running inference...'):
        text = inference(
            lm_model_name=lm_model_name,
            cls_model_name=cls_model_name,
            prompt=prompt,
            alpha=alpha,
            # Must be passed explicitly: otherwise the sampler always targets
            # class 0, and the cached result ignores the selected label.
            target_label_id=target_label_id,
            entropy_threshold=entropy_threshold,
        )
    st.subheader("Generated text:")
    st.markdown(text)


@st.cache(hash_funcs={str: lambda lm_model_name: hash(lm_model_name)}, allow_output_mutation=True)
def load_generator(lm_model_name: str) -> Generator:
    """Build a ``Generator`` for ``lm_model_name`` on ``device``.

    Streamlit caches the result per model name. ``allow_output_mutation=True``
    is needed because ``inference`` installs samplers on the returned object.
    """
    with st.spinner('Loading language model...'):
        generator = Generator(lm_model_name=lm_model_name, device=device)
    return generator


def load_sampler(cls_model_name, lm_tokenizer):
    """Build a ``CAIFSampler`` for classifier ``cls_model_name`` on ``device``.

    ``lm_tokenizer`` is the language model's tokenizer. This function is not
    cached itself; it is only called from the cached ``inference``.
    """
    with st.spinner('Loading classifier model...'):
        sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
    return sampler


@st.cache
def inference(
    lm_model_name: str,
    cls_model_name: str,
    prompt: str,
    fp16: bool = True,
    alpha: float = 5,
    target_label_id: int = 0,
    entropy_threshold: float = 0
) -> str:
    """Generate one continuation of ``prompt`` and return it.

    Results are cached by Streamlit on the argument values, so repeating the
    exact same inputs returns the stored text instead of sampling again.

    Args:
        lm_model_name: model id of the language model.
        cls_model_name: model id of the attribute classifier; only loaded
            when ``alpha`` is non-zero.
        prompt: text to continue.
        fp16: passed to the autocast context as its ``enabled`` flag.
        alpha: classifier weight; ``0`` disables the CAIF sampler entirely.
        target_label_id: forwarded to the sampler as ``target_cls_id``.
        entropy_threshold: forwarded to ``sample_sequences`` as ``entropy``;
            replaced by ``None`` when below 0.05 or when ``alpha == 0``.

    Returns:
        The first (and only) generated sequence.
    """
    generator = load_generator(lm_model_name=lm_model_name)

    if alpha != 0:
        # The LM tokenizer is only needed to build the CAIF sampler.
        lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
        caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
        if entropy_threshold < 0.05:
            # Treat near-zero thresholds as "no threshold".
            entropy_threshold = None
    else:
        caif_sampler = None
        entropy_threshold = None
    # Always reset: the cached generator may still hold a sampler from a
    # previous call with a different classifier.
    generator.set_caif_sampler(caif_sampler)

    ordinary_sampler = TopKWithTemperatureSampler()
    generator.set_ordinary_sampler(ordinary_sampler)
    kwargs = {
        "top_k": 20,
        "temperature": 1.0,
        "top_k_classifier": 100,
        "classifier_weight": alpha,
        "target_cls_id": target_label_id,
    }

    # On CPU torch's autocast defaults to bfloat16; on CUDA to float16.
    if device == "cpu":
        autocast = torch.cpu.amp.autocast
    else:
        autocast = torch.cuda.amp.autocast

    with autocast(fp16):
        print(f"Generating for prompt: {prompt}")
        sequences, _ = generator.sample_sequences(
            num_samples=1,
            input_prompt=prompt,
            max_length=20,
            caif_period=1,
            entropy=entropy_threshold,
            **kwargs
        )
        print(f"Output for prompt: {sequences}")

    return sequences[0]


if __name__ == "__main__":
    main()