Spaces:
Runtime error
Runtime error
File size: 2,348 Bytes
777cb96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
from joblib import load
### pip install faiss-cpu
import faiss
### pip install datasets
from datasets import Dataset
import torch
import pandas as pd
import streamlit as st
device = 'cpu'
### подгрузка всех компонентов - модель, токенайзер и датасет с эмбеддингами
embeddings_dataset = load('./embeddings_dataset.joblib')
tokenizer = load('./tokenizer.joblib')
model = load('./model.joblib')
### функция возвращающая от БЕРТа только [CLS] опиывающий общий смысл всего предложения
def embed_bert_cls(text, model, tokenizer):
t = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
model_output = model(**{k: v.to(device) for k, v in t.items()})
embeddings = model_output.last_hidden_state[:, 0, :]
embeddings = torch.nn.functional.normalize(embeddings)
return embeddings[0].cpu().numpy()
### функция ниже отдает готовый датасет с рекомендациями книг
def recommend(input_string,n_neighbors=5):
### input_string - то, что вводит пользователь в аннотации, эмбеддинг пользовательского текста
question_embedding = embed_bert_cls([input_string], model, tokenizer)
### n_neighbors - число предлагаемых системой книг, вводит пользователь,
### поиск похожих книг по запросу
scores, samples = embeddings_dataset.get_nearest_examples(
"embeddings", question_embedding, k=n_neighbors
)
### для корректной работы требуется формат таблиц huggingface, поэтому в конце
### происходит перевод в пандас для удобства
samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)
return samples_df
### конечный датасет: samples_df
user_input = st.text_input('Your text here:', )
number = st.number_input('Insert a number', min_value = 1, max_value = 5, value = 3)
if len(user_input) > 1:
st.write(recommend(user_input, number)) |