File size: 2,152 Bytes
eb4714a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from torch import nn
from joblib import load
import textwrap
import streamlit as st

# Inference device: used as map_location when loading the head's weights and
# as the target when moving tokenizer outputs in embed_bert_cls.
device = "cpu"



class GenreNet(nn.Module):
    """MLP head that maps a sentence embedding to a score in ``out_range``.

    The embedding goes through ``in_features -> 256 -> 128 -> 64 -> 1``
    fully connected layers (Dropout + ReLU between them). The result is
    squashed with a sigmoid and linearly rescaled to
    ``[out_range[0], out_range[1]]``.

    Args:
        config: dict with keys
            ``'dropout'``: sequence of dropout probabilities. Only the first
                element is used, for every hidden layer.
            ``'out_range'``: ``(low, high)`` bounds the output is scaled to.
            ``'in_features'`` (optional, default 312): width of the input
                embedding.
    """

    def __init__(self, config):
        super().__init__()
        # network hyper-parameters
        self.dropout = config['dropout']
        self.out_range = config['out_range']
        # 312 was the original hard-coded width; keep it as the default so
        # existing checkpoints still load, but allow other encoders.
        self.in_features = config.get('in_features', 312)

        p = self.dropout[0]
        # Final fully connected head that predicts the score.
        # NOTE: the module order fixes the state_dict keys
        # (head.0, head.3, head.6, head.9). Keep it stable so saved weights
        # remain loadable.
        self.head = nn.Sequential(
            nn.Linear(self.in_features, 256),
            nn.Dropout(p),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.Dropout(p),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.Dropout(p),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, emb):
        """Return the rescaled score for ``emb``.

        Args:
            emb: tensor whose last dimension is ``in_features``.

        Returns:
            Tensor shaped like ``emb`` with its last dimension replaced by 1,
            with values in ``[out_range[0], out_range[1]]``.
        """
        low, high = self.out_range[0], self.out_range[1]
        # sigmoid gives (0, 1); stretch and shift it onto [low, high].
        x = torch.sigmoid(self.head(emb))
        return x * (high - low) + low

config = {
    'dropout': [.5],
    'out_range': [1., 5.],  # range the predicted score is rescaled to
}

# NOTE(review): joblib.load unpickles arbitrary objects. Only load these files
# from a trusted location.
bert = load('./model.joblib')
# A pickled encoder keeps whatever train/eval flag it was saved with. Force
# eval mode so any dropout inside it is disabled at inference time.
if isinstance(bert, nn.Module):
    bert.eval()
model = GenreNet(config)
model.load_state_dict(
    torch.load(
        './pages/weights_los065_ep100_lr0001_lay256_128_64_1.pt',
        map_location=device,
    )
)
# The head contains Dropout(0.5) layers. torch.inference_mode() does NOT
# disable them, so without eval() the predicted genre is random per run.
model.eval()
tokenizer = load('./tokenizer.joblib')

def embed_bert_cls(text, model, tokenizer):
    """Return the L2-normalised position-0 embedding of the first input text.

    Args:
        text: string or list of strings passed straight to ``tokenizer``.
        model: encoder called with the tokenizer's tensors as keyword
            arguments; its output must expose ``last_hidden_state``.
        tokenizer: callable returning a mapping of PyTorch tensors.

    Returns:
        1-D tensor: the hidden state at token position 0 (the [CLS] slot for
        BERT-style tokenizers) of the first sequence, L2-normalised.
    """
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state
    cls_vectors = torch.nn.functional.normalize(hidden[:, 0, :])
    return cls_vectors[0]

# Maps the rounded model score to a genre label.
genre = {1: 'Романтика', 2: 'Поэзия', 3: 'Детектив', 4: 'Приключения', 5: 'Фантастика'}

prompt = st.text_input('Узнаем жанр!')
# Skip empty / whitespace-only input (single characters were already skipped).
if len(prompt.strip()) > 1:
    with torch.inference_mode():
        prompt_embedding = embed_bert_cls([prompt], bert, tokenizer)
        score = model(prompt_embedding).item()
    # Round the continuous score to the nearest genre id. Clamp to the known
    # ids so the lookup can never raise KeyError, even if out_range and the
    # genre table drift apart. (round() uses banker's rounding on exact .5.)
    genre_id = min(max(int(round(score)), min(genre)), max(genre))
    st.write('Предполагаемый жанр:', genre[genre_id])