multi model update
app.py
CHANGED
```diff
@@ -52,15 +52,26 @@ def load_model_and_tokenizer():
     print(type(tokenizer))
     print(type(model))
     return model, tokenizer
+
+@st.cache(hash_funcs={transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None, transformers.models.bert.modeling_bert.BertModel: lambda _: None})
+def load_hu_model_and_tokenizer():
+    multilingual_checkpoint = 'SZTAKI-HLT/hubert-base-cc'
+    tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
+    model = AutoModel.from_pretrained(multilingual_checkpoint)
+    print(type(tokenizer))
+    print(type(model))
+    return model, tokenizer
 
 
 model, tokenizer = load_model_and_tokenizer()
+model_hu, tokenizer_hu = load_hu_model_and_tokenizer()
 raw_text_file = 'joint_text_filtered.md'
 all_sentences = load_raw_sentences(raw_text_file)
 
 embeddings_file = 'multibert_embedded.pt' #alternative: hunbert_embedded.pt
 all_embeddings = load_embeddings(embeddings_file)
+embeddings_file_hu = 'hunbert_embedded.pt'
+all_embeddings_hu = load_embeddings(embeddings_file_hu)
-
 
 st.header('RF szöveg kereső')
 
```
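The `hash_funcs` mapping is what lets `st.cache` accept these loaders at all: Streamlit hashes every object that flows through a cached function, `BertTokenizerFast` and `BertModel` have no usable hash, so both classes are mapped to `lambda _: None`, which skips hashing them entirely (the cache then never invalidates on model changes, which is acceptable for fixed checkpoints). On current Streamlit releases `st.cache` is deprecated in favor of `st.cache_resource`, which caches singletons without any `hash_funcs`. A minimal sketch of the same idea, assuming Streamlit >= 1.18; the parametrized loader name is hypothetical:

```python
import streamlit as st
from transformers import AutoModel, AutoTokenizer

@st.cache_resource  # caches one (model, tokenizer) pair per checkpoint string
def load_cached_model_and_tokenizer(checkpoint: str):
    # Hypothetical replacement for the two near-identical loaders in app.py.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint)
    return model, tokenizer

model_hu, tokenizer_hu = load_cached_model_and_tokenizer('SZTAKI-HLT/hubert-base-cc')
```

Parametrizing on the checkpoint would also remove the need for a separate `load_hu_model_and_tokenizer` copy of the function.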
```diff
@@ -74,6 +85,9 @@ if text_area_input_query:
     query_embedding = calculateEmbeddings([text_area_input_query], tokenizer, model)
     top_pairs = findTopKMostSimilar(query_embedding, all_embeddings, all_sentences, 5)
     st.json(top_pairs)
+    query_embedding = calculateEmbeddings([text_area_input_query], tokenizer_hu, model_hu)
+    top_pairs = findTopKMostSimilar(query_embedding, all_embeddings_hu, all_sentences, 5)
+    st.json(top_pairs)
 
 
 
```
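The helpers used above (`calculateEmbeddings`, `findTopKMostSimilar`, `load_embeddings`) are defined earlier in app.py and are not part of this diff. A hedged sketch of what they plausibly do, assuming mean pooling over the last hidden state, cosine-similarity ranking, and that each .pt file holds a single [num_sentences, hidden_size] tensor; none of this is confirmed by the commit itself:

```python
import torch

def load_embeddings(path):
    # Assumption: each .pt file stores one [num_sentences, hidden_size] tensor.
    return torch.load(path)

def calculateEmbeddings(sentences, tokenizer, model):
    # Tokenize as one padded batch so every sentence shares a tensor shape.
    encoded = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        output = model(**encoded)
    # Mean-pool token vectors, masking out padding positions.
    mask = encoded['attention_mask'].unsqueeze(-1).float()
    summed = (output.last_hidden_state * mask).sum(dim=1)
    return summed / mask.sum(dim=1)

def findTopKMostSimilar(query_embedding, all_embeddings, all_sentences, k):
    # Cosine similarity of the single query vector against every stored vector.
    similarities = torch.nn.functional.cosine_similarity(query_embedding, all_embeddings)
    top = torch.topk(similarities, k)
    return [{'sentence': all_sentences[int(i)], 'similarity': float(s)}
            for s, i in zip(top.values, top.indices)]
```

Note that the new huBERT block reuses `query_embedding` and `top_pairs`, so the second `st.json` panel is distinguishable from the first only by position; separate `_hu` variables or an `st.subheader` label per model would make the two result lists easier to tell apart.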