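# Streamlit page: generate text continuations for a user prompt with a pre-trained language model.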
import torch
from joblib import load
import textwrap
import streamlit as st

device = 'cpu'

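# Load the tokenizer and model serialized with joblib, then restore the trained weights on CPU.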
tokenizer = load('./pages/tokenizer.joblib')
model = load('./pages/model.joblib')
model.load_state_dict(torch.load('./pages/model_weights.pt', map_location=device))

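# Generation hyperparameters exposed as sliders.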
temperature = st.slider('Craziness level:', min_value=1., max_value=20., value=3.)
num_beams = st.slider('Number of search beams:', min_value=1, max_value=15, value=7)
max_length = st.slider('Maximum generation length:', min_value=50, max_value=150, value=70)

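# Once the user enters a prompt, sample three candidate continuations and display them.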
prompt = st.text_input('Let your imagination run wild!')
if len(prompt) > 1:
    with torch.inference_mode():
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        out = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            num_beams=num_beams,
            do_sample=True,
            temperature=temperature,
            top_k=50,
            top_p=0.6,
            no_repeat_ngram_size=3,
            num_return_sequences=3,
            ).cpu().numpy()
        for out_ in out:
            st.write(textwrap.fill(tokenizer.decode(out_), 40))
            st.write('------------------')