import os

import streamlit as st
import torch
import transformers
from transformers import AutoConfig

from sampling import CAIFSampler, TopKWithTemperatureSampler
from generator import Generator

device = "cuda" if torch.cuda.is_available() else "cpu"


def main():
    st.header("CAIF")
    cls_model_name = st.selectbox(
        "Select a classification model",
        (
            "tinkoff-ai/response-quality-classifier-tiny",
            "tinkoff-ai/response-quality-classifier-base",
            "tinkoff-ai/response-quality-classifier-large",
            "SkolkovoInstitute/roberta_toxicity_classifier",
        ),
    )
    lm_model_name = st.selectbox(
        "Select a language model",
        ("sberbank-ai/rugpt3small_based_on_gpt2",),
    )

    cls_model_config = AutoConfig.from_pretrained(cls_model_name)
    label2id = cls_model_config.label2id
    if cls_model_config.problem_type == "multi_label_classification":
        # Multi-label classifier: let the user pick which attribute to steer towards.
        label_key = st.selectbox("Select the desired text attribute", list(label2id.keys()))
        target_label_id = label2id[label_key]
    else:
        # Single-label classifier: show the last label for reference and steer towards class 0.
        label_key = st.selectbox("Select the desired text attribute", [list(label2id.keys())[-1]])
        target_label_id = 0

    prompt = st.text_input("Prompt start:", "Привет")
    alpha = st.slider("Alpha:", min_value=-10, max_value=10, step=1, value=0)
    entropy_threshold = st.slider("Entropy threshold:", min_value=0.0, max_value=5.0, step=0.1, value=0.0)
    # Hugging Face auth token, read from the environment (not used by the samplers below).
    auth_token = os.environ.get("TOKEN") or True

    with st.spinner("Running inference..."):
        text = inference(
            lm_model_name=lm_model_name,
            cls_model_name=cls_model_name,
            prompt=prompt,
            alpha=alpha,
            target_label_id=target_label_id,
            entropy_threshold=entropy_threshold,
        )
    st.subheader("Generated text:")
    st.markdown(text)


@st.cache(hash_funcs={str: lambda lm_model_name: hash(lm_model_name)}, allow_output_mutation=True)
def load_generator(lm_model_name: str) -> Generator:
    with st.spinner("Loading language model..."):
        generator = Generator(lm_model_name=lm_model_name, device=device)
        return generator


def load_sampler(cls_model_name, lm_tokenizer):
    with st.spinner("Loading classifier model..."):
        sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
        return sampler


@st.cache
def inference(
    lm_model_name: str,
    cls_model_name: str,
    prompt: str,
    fp16: bool = True,
    alpha: float = 5,
    target_label_id: int = 0,
    entropy_threshold: float = 0,
) -> str:
    generator = load_generator(lm_model_name=lm_model_name)
    lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)

    # With alpha == 0 the classifier has no influence, so skip loading it entirely.
    if alpha != 0:
        caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
    else:
        caif_sampler = None
    # Treat near-zero thresholds as "apply CAIF at every generation step".
    if entropy_threshold < 0.05:
        entropy_threshold = None

    generator.set_caif_sampler(caif_sampler)
    ordinary_sampler = TopKWithTemperatureSampler()
    kwargs = {
        "top_k": 20,
        "temperature": 1.0,
        "top_k_classifier": 100,
        "classifier_weight": alpha,
        "target_cls_id": target_label_id,
    }
    generator.set_ordinary_sampler(ordinary_sampler)

    # Mixed-precision autocast: CUDA when available, otherwise CPU autocast.
    if device == "cpu":
        autocast = torch.cpu.amp.autocast
    else:
        autocast = torch.cuda.amp.autocast
    with autocast(enabled=fp16):
        print(f"Generating for prompt: {prompt}")
        sequences, tokens = generator.sample_sequences(
            num_samples=1,
            input_prompt=prompt,
            max_length=20,
            caif_period=1,
            entropy=entropy_threshold,
            **kwargs,
        )
        print(f"Output for prompt: {sequences}")
    return sequences[0]


if __name__ == "__main__":
    main()
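# Usage note (a minimal sketch, assuming this script is saved as app.py in the same
# directory as the sampling.py and generator.py modules it imports):
#
#   pip install streamlit torch transformers
#   streamlit run app.py
#
# Streamlit then serves the CAIF demo locally, by default at http://localhost:8501.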