Grosy committed on
Commit
0428dd0
1 Parent(s): 7980b9a

Initial app code, based on Endre/SemanticSearch-HU

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import transformers
4
+ import torch
5
+ from sentence_transformers import util
6
+
7
+ # explicit no operation hash functions defined, because raw sentences, embedding, model and tokenizer are not going to change
8
+
9
+
10
@st.cache(hash_funcs={list: lambda _: None})
def load_raw_sentences(filename):
    """Read the sentence corpus and return it as a list of lines.

    The result is cached by Streamlit; the no-op hash function for ``list``
    skips hashing of the returned list.

    Args:
        filename: Path of the text file to read.

    Returns:
        The lines of the file as returned by ``readlines()``, i.e. each
        entry keeps its trailing newline.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    # NOTE(review): assumes the corpus file is stored as UTF-8 - confirm.
    with open(filename, encoding='utf-8') as f:
        return f.readlines()
14
+
15
@st.cache(hash_funcs={torch.Tensor: lambda _: None})
def load_embeddings(filename):
    """Load the precomputed sentence embeddings onto the CPU.

    The result is cached by Streamlit; the no-op hash function for
    ``torch.Tensor`` skips hashing of the returned object.

    Args:
        filename: Path of the file previously written with ``torch.save``.

    Returns:
        Whatever object ``torch.load`` deserializes from ``filename``, with
        tensor storages mapped to the CPU.
    """
    # torch.load opens the path itself, so no separate file handle is needed.
    # NOTE: torch.load unpickles the file; only load embedding files that
    # come from a trusted source.
    return torch.load(filename, map_location=torch.device('cpu'))
19
+
20
+
21
+ #Mean Pooling - Take attention mask into account for correct averaging
22
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, weighted by the mask.

    Args:
        model_output: Model output whose first element holds the token
            embeddings (assumed shape: batch x tokens x dim).
        attention_mask: Mask with one entry per token; positions with value
            0 do not contribute to the average.

    Returns:
        The masked sum over dimension 1 divided by the mask sum over
        dimension 1.
    """
    embeddings = model_output[0]
    # Broadcast the mask across the embedding dimension so it can be
    # multiplied element-wise with the embeddings.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_total = (embeddings * mask).sum(dim=1)
    # The lower bound prevents a division by zero for fully masked rows.
    token_count = mask.sum(dim=1).clamp(min=1e-9)
    return masked_total / token_count
28
+
29
def findTopKMostSimilar(query_embedding, embeddings, all_sentences, k):
    """Return the ``k`` sentences whose embeddings best match the query.

    Args:
        query_embedding: Embedding of a single query (one row).
        embeddings: Embeddings of the corpus, one row per sentence.
        all_sentences: Sentences aligned by index with ``embeddings``.
        k: Maximum number of results to return.

    Returns:
        A list of at most ``k`` dicts, ordered from most to least similar,
        each with a ``'score'`` key (cosine similarity formatted with four
        decimals) and a ``'text'`` key (the matching sentence).
    """
    cosine_scores = util.pytorch_cos_sim(query_embedding, embeddings)
    # The similarity matrix has one row per query; taking row 0 always gives
    # a 1-D tensor, so tolist() yields a list even for a single embedding.
    scores = cosine_scores[0].tolist()
    # zip stops at the shorter input, so embeddings without a matching
    # sentence are skipped.
    # Rank on the numeric score: sorting the formatted strings would order
    # negative similarities lexicographically (i.e. incorrectly).
    ranked = sorted(zip(scores, all_sentences), key=lambda pair: pair[0], reverse=True)
    return [
        {'score': '{:.4f}'.format(score), 'text': text}
        for score, text in ranked[:k]
    ]
38
+
39
def calculateEmbeddings(sentences,tokenizer,model):
    """Embed ``sentences`` with ``model`` and mean-pool the token outputs.

    Args:
        sentences: Sentences to embed, passed directly to ``tokenizer``.
        tokenizer: Callable that tokenizes the sentences into PyTorch
            tensors (padded, truncated to at most 128 tokens).
        model: Callable that receives the tokenizer output as keyword
            arguments.

    Returns:
        The result of ``mean_pooling`` over the model output, using the
        attention mask produced by the tokenizer.
    """
    encoded = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        output = model(**encoded)
    return mean_pooling(output, encoded['attention_mask'])
45
+
46
+ # explicit no operation hash function, because model and tokenizer are not going to change
47
@st.cache(hash_funcs={
    transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None,
    transformers.models.bert.modeling_bert.BertModel: lambda _: None,
})
def load_model_and_tokenizer():
    """Load the multilingual sentence-embedding model and its tokenizer.

    The result is cached by Streamlit; the no-op hash functions skip hashing
    of the tokenizer and model objects.

    Returns:
        A ``(model, tokenizer)`` tuple - note the order.
    """
    checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint)
    # Print the concrete classes that were loaded (tokenizer first, then
    # model); the hash_funcs keys above are keyed on concrete classes.
    for loaded in (tokenizer, model):
        print(type(loaded))
    return model, tokenizer
55
+
56
+
57
# Load the model, tokenizer, corpus and embeddings through the cached
# loaders defined above.
model, tokenizer = load_model_and_tokenizer()

raw_text_file = 'data/data/joint_text_filtered.md'
embeddings_file = 'data/multibert_embedded.pt'
all_sentences = load_raw_sentences(raw_text_file)
all_embeddings = load_embeddings(embeddings_file)

# Page header and short usage hint.
st.header('RF szöveg kereső')
st.caption('[HU] Adjon meg egy tetszőleges kifejezést és a rendszer visszaadja az 5 hozzá legjobban hasonlító szöveget')

# Free-text query box, pre-filled with an example question.
query_text = st.text_area(
    '[HU] Beviteli mező - [EN] Query input',
    value='Mikor van a leadási hataridő?',
)

# Only search when the query box is non-empty.
if query_text:
    best_matches = findTopKMostSimilar(
        calculateEmbeddings([query_text], tokenizer, model),
        all_embeddings,
        all_sentences,
        5,
    )
    st.json(best_matches)
77
+
78
+
79
+