book-search / pages /annotation_generator.py
EgorDan's picture
deleted assignmetn model to weights
5963bd7
raw
history blame
1.24 kB
import torch
from joblib import load
import textwrap
import streamlit as st
device = 'cpu'
tokenizer = load('./pages/tokenizer.joblib')
model = load('./pages/model.joblib')
model.load_state_dict(torch.load('./pages/model_weights.pt', map_location=device))
temperature = st.slider('Градус дичи:', min_value = 1., max_value = 20., value = 3.)
num_beams = st.slider('Число веток для поиска:', min_value = 1, max_value = 15, value = 7)
max_length = st.slider('Максимальная длина генерации:', min_value = 50, max_value = 150, value = 70)
prompt = st.text_input('Дайте волю фантазии!',)
if len(prompt) > 1:
with torch.inference_mode():
prompt = tokenizer.encode(prompt, return_tensors='pt').to(device)
out = model.generate(
input_ids=prompt,
max_length=max_length,
num_beams=num_beams,
do_sample=True,
temperature=temperature,
top_k=50,
top_p=0.6,
no_repeat_ngram_size=3,
num_return_sequences=3,
).cpu().numpy()
for out_ in out:
st.write(textwrap.fill(tokenizer.decode(out_), 40), end='\n------------------\n')