import torch
from torch import nn
from joblib import load
import textwrap
import streamlit as st
device = 'cpu'
class GenreNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        # network hyperparameters
        self.dropout = config['dropout']
        self.out_range = config['out_range']
        # final fully connected head that predicts the score
        self.head = nn.Sequential(
            nn.Linear(312, 256),
            nn.Dropout(self.dropout[0]),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.Dropout(self.dropout[0]),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.Dropout(self.dropout[0]),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, emb):
        # squash the raw prediction with a sigmoid, then rescale it into out_range
        x = torch.sigmoid(self.head(emb))
        x = x * (self.out_range[1] - self.out_range[0]) + self.out_range[0]
        return x
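
# Note (assumption): the head expects a 312-dimensional sentence embedding, which
# matches the CLS-embedding size of compact Russian BERT encoders such as
# rubert-tiny2; the exact encoder used here is only available as ./model.joblib.
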
config = {
    'dropout': [.5],
    'out_range': [1., 5.],  # range used to scale the output scores
}
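
# With out_range = [1., 5.] the sigmoid output is rescaled linearly:
# e.g. a raw sigmoid value of 0.5 becomes 0.5 * (5. - 1.) + 1. = 3.0.
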
bert = load('./model.joblib')
model = GenreNet(config)
model.load_state_dict(torch.load('./pages/weights_los065_ep100_lr0001_lay256_128_64_1.pt', map_location=device))
model.eval()  # switch off dropout for inference
tokenizer = load('./tokenizer.joblib')
def embed_bert_cls(text, model, tokenizer):
    t = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**{k: v.to(device) for k, v in t.items()})
    embeddings = model_output.last_hidden_state[:, 0, :]
    embeddings = torch.nn.functional.normalize(embeddings)
    return embeddings[0]
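
# Usage sketch (assuming `bert` is a Hugging Face-style encoder loaded from
# ./model.joblib and `tokenizer` is its matching tokenizer):
#   emb = embed_bert_cls(['какой-то текст'], bert, tokenizer)  # CLS embedding, shape (312,)
#   score = model(emb)                                         # tensor with a value in [1, 5]
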
genre = {1: 'Романтика', 2: 'Поэзия', 3: 'Детектив', 4: 'Приключения', 5: 'Фантастика'}  # Romance, Poetry, Detective, Adventure, Science fiction
prompt = st.text_input('Узнаем жанр!')  # "Let's find out the genre!"
if len(prompt) > 1:
    with torch.inference_mode():
        prompt_embedding = embed_bert_cls([prompt], bert, tokenizer)
        out = model(prompt_embedding).cpu().numpy()
    # round the regression output to the nearest integer and map it to a genre label
    st.write('Предполагаемый жанр:', genre[int(round(out.item(), 0))])  # "Predicted genre: ..."
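
# To try this page locally (assumption: this file is one page of a Streamlit
# multipage app, given the ./pages/ path above), run from the project root:
#   streamlit run <entrypoint>.py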