import os

import streamlit as st
import torch
import transformers

from sampling import CAIFSampler, TopKWithTemperatureSampler
from generator import Generator

device = "cuda" if torch.cuda.is_available() else "cpu"

|
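# CAIF demo: a small Streamlit app for classifier-guided sampling. The user
# picks a response-quality classifier and a Russian GPT language model, enters
# a prompt, and receives a continuation whose token choices are reweighted by
# the classifier (see CAIFSampler in the sampling module).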
def main():
    st.subheader("CAIF")
    cls_model_name = st.selectbox(
        "Select a classification model",
        (
            "tinkoff-ai/response-quality-classifier-tiny",
            "tinkoff-ai/response-quality-classifier-base",
            "tinkoff-ai/response-quality-classifier-large",
        ),
    )
    lm_model_name = st.selectbox(
        "Select a language model",
        ("sberbank-ai/rugpt3small_based_on_gpt2",),
    )
    prompt = st.text_input("Start of the text:", "Привет")
    # Hugging Face auth token from the environment; True falls back to the
    # cached CLI login. Read here but not passed on to from_pretrained.
    auth_token = os.environ.get("TOKEN") or True
    with st.spinner("Running inference..."):
        text = inference(lm_model_name=lm_model_name, cls_model_name=cls_model_name, prompt=prompt)
    # Display the sampled continuation in the app.
    st.text(text)

|
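# Loading the checkpoint is expensive, so cache it across Streamlit reruns.
# allow_output_mutation stops st.cache from hashing the returned Generator,
# and the str hash_func keys the cache on the model name.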
@st.cache(hash_funcs={str: lambda lm_model_name: hash(lm_model_name)}, allow_output_mutation=True)
def load_generator(lm_model_name: str) -> Generator:
    with st.spinner("Loading language model..."):
        generator = Generator(lm_model_name=lm_model_name, device=device)
    return generator

|
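# Note: unlike load_generator, this loader is not cached; st.cache would need
# a hash_funcs entry for the tokenizer argument, which is presumably why the
# decorator is omitted here.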
def load_sampler(cls_model_name: str, lm_tokenizer) -> CAIFSampler:
    with st.spinner("Loading classifier model..."):
        sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
    return sampler

|
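# End-to-end generation: build the generator and both samplers, then sample a
# single sequence. Judging by the parameter names, caif_period controls how
# often the classifier-guided step is applied and caif_tokens_num how many
# candidate tokens it rescores.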
@st.cache
def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True) -> str:
    generator = load_generator(lm_model_name=lm_model_name)
    lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)

    # Wire up both samplers: the CAIF sampler applies the classifier's scores,
    # the ordinary sampler is plain top-k with temperature.
    caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
    generator.set_caif_sampler(caif_sampler)
    ordinary_sampler = TopKWithTemperatureSampler()
    generator.set_ordinary_sampler(ordinary_sampler)

    kwargs = {
        "top_k": 20,
        "temperature": 1.0,
        "top_k_classifier": 100,
        "classifier_weight": 5,
    }

    # Run under autocast on either device; enabled=fp16 toggles mixed precision.
    autocast = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
    with autocast(enabled=fp16):
        print(f"Generating for prompt: {prompt}")
        sequences, _tokens = generator.sample_sequences(
            num_samples=1,
            input_prompt=prompt,
            max_length=20,
            caif_period=1,
            caif_tokens_num=100,
            entropy=None,
            **kwargs,
        )
        print(f"Output for prompt: {sequences}")
    return sequences[0]

|
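# Launch the demo with Streamlit (assuming this file is saved as app.py):
#
#   streamlit run app.py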
if __name__ == "__main__":
    main()