# caif/app.py
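"""Streamlit demo of CAIF-guided text generation.

Generates Russian text with a GPT language model while a response-quality
classifier steers sampling through CAIFSampler. (Summary inferred from the
code in this file.)
"""
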
import os

import streamlit as st
import torch
import transformers

from generator import Generator
from sampling import CAIFSampler, TopKWithTemperatureSampler

# Run on GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"


def main():
    st.subheader("CAIF")
    cls_model_name = st.selectbox(
        "Select a classifier model",
        (
            "tinkoff-ai/response-quality-classifier-tiny",
            "tinkoff-ai/response-quality-classifier-base",
            "tinkoff-ai/response-quality-classifier-large",
        ),
    )
    lm_model_name = st.selectbox(
        "Select a language model",
        ("sberbank-ai/rugpt3small_based_on_gpt2",),
    )
    # The default prompt is Russian ("Hello"), matching the Russian LM above.
    prompt = st.text_input("Prompt:", "Привет")
    # Hugging Face auth token for model downloads; `or True` falls back to the
    # token cached by `huggingface-cli login`.
    auth_token = os.environ.get("TOKEN") or True
    with st.spinner("Running inference..."):
        text = inference(
            lm_model_name=lm_model_name,
            cls_model_name=cls_model_name,
            prompt=prompt,
        )
    # Display the generated continuation.
    st.text(text)


@st.cache(hash_funcs={str: lambda lm_model_name: hash(lm_model_name)}, allow_output_mutation=True)
def load_generator(lm_model_name: str) -> Generator:
    with st.spinner("Loading language model..."):
        generator = Generator(lm_model_name=lm_model_name, device=device)
    return generator


def load_sampler(cls_model_name: str, lm_tokenizer) -> CAIFSampler:
    with st.spinner("Loading classifier model..."):
        sampler = CAIFSampler(classifier_name=cls_model_name, lm_tokenizer=lm_tokenizer, device=device)
    return sampler


@st.cache
def inference(lm_model_name: str, cls_model_name: str, prompt: str, fp16: bool = True) -> str:
    generator = load_generator(lm_model_name=lm_model_name)
    lm_tokenizer = transformers.AutoTokenizer.from_pretrained(lm_model_name)
    caif_sampler = load_sampler(cls_model_name=cls_model_name, lm_tokenizer=lm_tokenizer)
    generator.set_caif_sampler(caif_sampler)
    ordinary_sampler = TopKWithTemperatureSampler()
    kwargs = {
        "top_k": 20,              # sample from the 20 highest-probability tokens
        "temperature": 1.0,       # no additional temperature scaling
        "top_k_classifier": 100,  # candidate tokens rescored by the classifier
        "classifier_weight": 5,   # how strongly the classifier steers sampling
    }
    generator.set_ordinary_sampler(ordinary_sampler)
    # Pick the autocast context for the available device; the first argument
    # toggles mixed precision on or off.
    if device == "cpu":
        autocast = torch.cpu.amp.autocast
    else:
        autocast = torch.cuda.amp.autocast
    with autocast(fp16):
        print(f"Generating for prompt: {prompt}")
        sequences, tokens = generator.sample_sequences(
            num_samples=1,
            input_prompt=prompt,
            max_length=20,
            caif_period=1,        # apply the CAIF sampler at every step
            caif_tokens_num=100,  # candidate tokens scored by the classifier
            entropy=None,
            **kwargs,
        )
    print(f"Output for prompt: {sequences}")
    return sequences[0]
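

# A minimal sketch of the CAIF reweighting idea, assuming the usual
# formulation (LM logits shifted by weighted classifier log-probabilities).
# `_caif_reweight_sketch` is a hypothetical illustration, not part of
# CAIFSampler's actual interface; see sampling.py for the real implementation.
def _caif_reweight_sketch(
    lm_logits: torch.Tensor,     # (vocab_size,) next-token logits from the LM
    cls_logprobs: torch.Tensor,  # (vocab_size,) log p(attribute | prefix + token)
    classifier_weight: float = 5.0,
) -> torch.Tensor:
    # Tokens the classifier favors get boosted; the weight controls how strongly
    # the attribute signal overrides the language model's own preferences.
    return lm_logits + classifier_weight * cls_logprobs
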
if __name__ == "__main__":
    main()
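

# To try the demo locally (a sketch; assumes this repo's sampling.py and
# generator.py sit next to app.py):
#
#   pip install streamlit torch transformers
#   TOKEN=<your-hf-token> streamlit run app.py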