"""Streamlit demo app for CAIF-guided text generation.

The user picks a language, an attribute classifier, a language model, a
target attribute and a prompt; the app then samples a short continuation,
optionally steering it toward the chosen classifier label.
"""
import os
import pickle
from typing import Tuple

import numpy as np
import streamlit as st
import tokenizers
import torch
import transformers
from plotly import graph_objects as go
from transformers import AutoConfig

from generator import Generator
from sampling import CAIFSampler, TopKWithTemperatureSampler

device = "cuda" if torch.cuda.is_available() else "cpu"

# Attribute (classifier) checkpoints offered per UI language.
ATTRIBUTE_MODELS = {
    "Russian": (
        "cointegrated/rubert-tiny-toxicity",
        'tinkoff-ai/response-quality-classifier-tiny',
        'tinkoff-ai/response-quality-classifier-base',
        'tinkoff-ai/response-quality-classifier-large',
        # NOTE(review): this checkpoint appears to be an English classifier;
        # confirm it is intentionally listed under "Russian".
        "SkolkovoInstitute/roberta_toxicity_classifier",
        "SkolkovoInstitute/russian_toxicity_classifier",
    ),
    "English": (
        "unitary/toxic-bert",
    ),
}

# Causal language-model checkpoints offered per UI language.
LANGUAGE_MODELS = {
    "Russian": (
        'sberbank-ai/rugpt3small_based_on_gpt2',
        "sberbank-ai/rugpt3large_based_on_gpt2",
    ),
    "English": ("gpt2", "distilgpt2", "EleutherAI/gpt-neo-1.3B"),
}

# Localized widget labels.
ATTRIBUTE_MODEL_LABEL = {
    "Russian": 'Выберите модель классификации',
    "English": "Choose attribute model",
}

LM_LABEL = {
    "English": "Choose language model",
    "Russian": "Выберите языковую модель",
}

ATTRIBUTE_LABEL = {
    "Russian": "Выберите нужный атрибут текста",
    "English": "Choose desired attribute",
}

TEXT_PROMPT_LABEL = {
    "English": "Text prompt",
    "Russian": "Начало текста",
}

# Default prompt pre-filled in the text box for each language.
PROMPT_EXAMPLE = {
    "English": "Hello, today I want to show you a new method",
    "Russian": "Привет, сегодня я",
}


def main():
    """Render the CAIF demo page and run generation when requested.

    Builds the language / model / attribute selectors, shows the
    entropy-threshold vs. speedup curve loaded from ``entropy_cdf.pkl`` and
    calls :func:`inference` with the selected settings.
    """
    st.header("CAIF")

    # Pre-computed (entropy threshold, speedup) curve shipped next to the app.
    # pickle is acceptable only because this is a local, trusted artifact.
    with open("entropy_cdf.pkl", "rb") as inp:
        x_s, y_s = pickle.load(inp)
    scatter = go.Scatter({
        "x": x_s,
        "y": y_s,
        "name": "GPT2",
        "mode": "lines",
    })
    layout = go.Layout({
        "yaxis": {
            "title": "Speedup",
            "tickvals": [0, 0.5, 0.8, 1],
            "ticktext": ["1x", "2x", "5x", "10x"],
        },
        "xaxis": {"title": "Entropy threshold"},
        "template": "plotly_white",
    })

    language = st.selectbox("Language", ("English", "Russian"))
    cls_model_name = st.selectbox(
        ATTRIBUTE_MODEL_LABEL[language], ATTRIBUTE_MODELS[language]
    )
    lm_model_name = st.selectbox(
        LM_LABEL[language], LANGUAGE_MODELS[language]
    )

    cls_model_config = AutoConfig.from_pretrained(cls_model_name)
    label2id = cls_model_config.label2id
    if cls_model_config.problem_type == "multi_label_classification":
        # Each label is an independent attribute, so any of them may be chosen.
        label_key = st.selectbox(ATTRIBUTE_LABEL[language], label2id.keys())
    else:
        # Single-label classifier: only the highest-id label is offered (id 1
        # for a binary model). The target id is looked up from that same
        # label, so the UI always shows the label actually being targeted.
        label_key = st.selectbox(
            ATTRIBUTE_LABEL[language], [max(label2id, key=label2id.get)]
        )
    target_label_id = label2id[label_key]

    prompt = st.text_input(TEXT_PROMPT_LABEL[language], PROMPT_EXAMPLE[language])

    # NOTE(review): the remainder of main() was lost in the supplied copy of
    # this file (text after "x_{" in the formula was stripped up to the
    # load_generator signature). The widgets below are reconstructed from how
    # `scatter`, `layout` and `inference` are used. Confirm labels, ranges
    # and defaults against the upstream app.
    st.latex(r"p(x_i|x_{<i}, c) \propto p(x_i|x_{<i}) \, p(c|x_{\leq i})^{\alpha}")
    alpha = st.slider(
        "Alpha", min_value=-10.0, max_value=10.0, step=0.5, value=5.0
    )
    # inference() treats thresholds below 0.05 as "no entropy gating".
    entropy_threshold = st.slider(
        "Entropy threshold",
        min_value=0.0,
        max_value=float(np.max(x_s)),
        step=0.05,
        value=0.0,
    )
    st.plotly_chart(
        go.Figure(data=[scatter], layout=layout), use_container_width=True
    )
    fp16 = st.checkbox("FP16", value=True)

    if st.button("Generate"):
        text = inference(
            lm_model_name=lm_model_name,
            cls_model_name=cls_model_name,
            prompt=prompt,
            fp16=fp16,
            alpha=alpha,
            target_label_id=target_label_id,
            entropy_threshold=entropy_threshold,
        )
        st.subheader(text)


# NOTE(review): the decorator line of this function was lost in the supplied
# copy. Caching is restored because loading the LM is expensive.
# allow_output_mutation is required because inference() mutates the returned
# Generator via set_caif_sampler / set_ordinary_sampler.
@st.cache(allow_output_mutation=True)
def load_generator(lm_model_name: str) -> Generator:
    """Construct the :class:`Generator` for *lm_model_name* on ``device``."""
    with st.spinner('Loading language model...'):
        generator = Generator(lm_model_name=lm_model_name, device=device)
    return generator


# Not cached: an st.cache decorator with a tokenizers.Tokenizer hash_func was
# left disabled in the original, so the classifier is rebuilt on every
# (non-cached) inference call.
def load_sampler(cls_model_name, lm_tokenizer):
    """Build a :class:`CAIFSampler` for *cls_model_name* on ``device``.

    Args:
        cls_model_name: Hugging Face id of the attribute classifier.
        lm_tokenizer: Tokenizer of the language model being steered.
    """
    with st.spinner('Loading classifier model...'):
        sampler = CAIFSampler(
            classifier_name=cls_model_name,
            lm_tokenizer=lm_tokenizer,
            device=device,
        )
    return sampler


@st.cache
def inference(
    lm_model_name: str,
    cls_model_name: str,
    prompt: str,
    fp16: bool = True,
    alpha: float = 5,
    target_label_id: int = 0,
    entropy_threshold: float = 0,
) -> str:
    """Sample one continuation of *prompt*, optionally steered by a classifier.

    Args:
        lm_model_name: Language-model checkpoint used for generation.
        cls_model_name: Attribute-classifier checkpoint (used only if alpha != 0).
        prompt: Text to continue.
        fp16: Enable autocast mixed precision on ``device``.
        alpha: Classifier weight; ``0`` disables CAIF and samples plainly.
        target_label_id: Classifier label index to steer toward.
        entropy_threshold: Entropy gate passed to the generator; values below
            0.05 disable it (passed as ``None``).

    Returns:
        The first sampled sequence (``num_samples=1``, ``max_length=20``).
    """
    # Scoped no-grad instead of torch.set_grad_enabled(False), which would
    # leave autograd disabled for the rest of the thread.
    with torch.no_grad():
        generator = load_generator(lm_model_name=lm_model_name)
        if alpha != 0:
            # The tokenizer is only needed by the classifier-guided sampler.
            lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
            caif_sampler = load_sampler(
                cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer
            )
            if entropy_threshold < 0.05:
                entropy_threshold = None
        else:
            caif_sampler = None
            entropy_threshold = None
        generator.set_caif_sampler(caif_sampler)
        generator.set_ordinary_sampler(TopKWithTemperatureSampler())

        kwargs = {
            "top_k": 20,
            "temperature": 1.0,
            "top_k_classifier": 100,
            "classifier_weight": alpha,
            "target_cls_id": target_label_id,
        }

        # Same semantics as torch.cpu.amp.autocast / torch.cuda.amp.autocast
        # with `enabled=fp16` (bfloat16 on CPU, float16 on CUDA by default).
        with torch.autocast(device_type=device, enabled=fp16):
            print(f"Generating for prompt: {prompt}")
            sequences, _tokens = generator.sample_sequences(
                num_samples=1,
                input_prompt=prompt,
                max_length=20,
                caif_period=1,
                entropy=entropy_threshold,
                **kwargs,
            )
            print(f"Output for prompt: {sequences}")

    return sequences[0]


if __name__ == "__main__":
    main()