import streamlit as st
import torch
from torch import nn
from joblib import load

device = 'cpu'


class GenreNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        # network configuration parameters
        self.dropout = config['dropout']
        self.out_range = config['out_range']
        # final fully connected head that predicts the genre score
        self.head = nn.Sequential(
            nn.Linear(312, 256),
            nn.Dropout(self.dropout[0]),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.Dropout(self.dropout[0]),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.Dropout(self.dropout[0]),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, emb):
        # squash the raw score into (0, 1), then rescale it to out_range
        x = torch.sigmoid(self.head(emb))
        x = x * (self.out_range[1] - self.out_range[0]) + self.out_range[0]
        return x


config = {
    'dropout': [.5],
    'out_range': [1., 5.],  # range used to rescale the output score
}

bert = load('./model.joblib')
model = GenreNet(config)
model.load_state_dict(
    torch.load('./pages/weights_los065_ep100_lr0001_lay256_128_64_1.pt',
               map_location=device)
)
tokenizer = load('./tokenizer.joblib')


def embed_bert_cls(text, model, tokenizer):
    """Return the L2-normalized [CLS] embedding of `text` produced by `model`."""
    t = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**{k: v.to(device) for k, v in t.items()})
    embeddings = model_output.last_hidden_state[:, 0, :]
    embeddings = torch.nn.functional.normalize(embeddings)
    return embeddings[0]


genre = {
    1: 'Romance',
    2: 'Poetry',
    3: 'Detective',
    4: 'Adventure',
    5: 'Science fiction',
}

prompt = st.text_input("Let's guess the genre!")
if len(prompt) > 1:
    with torch.inference_mode():
        prompt_embedding = embed_bert_cls([prompt], bert, tokenizer)
        out = model(prompt_embedding).cpu().numpy()
    # round the predicted score to the nearest integer and map it to a genre label
    st.write('Predicted genre:', genre[int(round(out.item(), 0))])
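
# A minimal sketch of how `model.joblib` and `tokenizer.joblib` could be produced.
# Assumptions (not confirmed by this repo): the 312-dimensional input of GenreNet
# suggests a compact Russian BERT such as 'cointegrated/rubert-tiny2'; the checkpoint
# name and the export paths below are illustrative only.
#
#   from joblib import dump
#   from transformers import AutoModel, AutoTokenizer
#
#   checkpoint = 'cointegrated/rubert-tiny2'  # assumed encoder with hidden size 312
#   dump(AutoModel.from_pretrained(checkpoint), './model.joblib')
#   dump(AutoTokenizer.from_pretrained(checkpoint), './tokenizer.joblib')
#
# The page is then served with `streamlit run <entrypoint>.py` (entrypoint name assumed)
# from the project root, so that the relative paths above resolve correctly.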