Grosy committed
Commit 5bec655
1 Parent(s): d1a4946

multi model update

Files changed (1)
  1. app.py +15 -1
app.py CHANGED
@@ -52,15 +52,26 @@ def load_model_and_tokenizer():
     print(type(tokenizer))
     print(type(model))
     return model, tokenizer
+
+@st.cache(hash_funcs={transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None, transformers.models.bert.modeling_bert.BertModel: lambda _: None})
+def load_hu_model_and_tokenizer():
+    hu_checkpoint = 'SZTAKI-HLT/hubert-base-cc'
+    tokenizer = AutoTokenizer.from_pretrained(hu_checkpoint)
+    model = AutoModel.from_pretrained(hu_checkpoint)
+    print(type(tokenizer))
+    print(type(model))
+    return model, tokenizer


 model, tokenizer = load_model_and_tokenizer()
+model_hu, tokenizer_hu = load_hu_model_and_tokenizer()
 raw_text_file = 'joint_text_filtered.md'
 all_sentences = load_raw_sentences(raw_text_file)

 embeddings_file = 'multibert_embedded.pt' #alternative: hunbert_embedded.pt
 all_embeddings = load_embeddings(embeddings_file)
-
+embeddings_file_hu = 'hunbert_embedded.pt'
+all_embeddings_hu = load_embeddings(embeddings_file_hu)

 st.header('RF szöveg kereső')
@@ -74,6 +85,9 @@ if text_area_input_query:
     query_embedding = calculateEmbeddings([text_area_input_query], tokenizer, model)
     top_pairs = findTopKMostSimilar(query_embedding, all_embeddings, all_sentences, 5)
     st.json(top_pairs)
+    query_embedding_hu = calculateEmbeddings([text_area_input_query], tokenizer_hu, model_hu)
+    top_pairs_hu = findTopKMostSimilar(query_embedding_hu, all_embeddings_hu, all_sentences, 5)
+    st.json(top_pairs_hu)



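A note on the @st.cache(hash_funcs=...) decorator added above: Streamlit's legacy st.cache builds its cache key by hashing a function's inputs and outputs, and it cannot hash transformers tokenizers or torch models. Mapping those exact classes to lambda _: None tells Streamlit to treat such objects as constants instead of raising an unhashable-type error. A minimal self-contained sketch of the same pattern, assuming a pre-1.0 Streamlit where st.cache is still the caching API (the parametrized load_cached_model helper is illustrative, not part of app.py):

import streamlit as st
import transformers
from transformers import AutoModel, AutoTokenizer

# hash_funcs maps a concrete class to a stand-in hash function; returning
# None for the BERT classes removes them from the cache key, so Streamlit
# caches the loaded objects without trying to hash their weights.
@st.cache(hash_funcs={
    transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None,
    transformers.models.bert.modeling_bert.BertModel: lambda _: None,
})
def load_cached_model(checkpoint):  # illustrative helper, not in app.py
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint)
    return model, tokenizer

model_hu, tokenizer_hu = load_cached_model('SZTAKI-HLT/hubert-base-cc')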
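calculateEmbeddings and findTopKMostSimilar are called in the second hunk but defined outside this diff. A plausible minimal sketch of the two helpers, assuming mean-pooled token embeddings and cosine-similarity ranking (the pooling strategy and the returned fields are assumptions inferred from the call sites, not the author's actual code):

import torch
import torch.nn.functional as F

# Hypothetical reconstruction: the real bodies are not in the commit.
def calculateEmbeddings(sentences, tokenizer, model):
    encoded = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        output = model(**encoded)
    token_embeddings = output.last_hidden_state             # (batch, seq, hidden)
    mask = encoded['attention_mask'].unsqueeze(-1).float()  # zero out padding tokens
    return (token_embeddings * mask).sum(1) / mask.sum(1)   # mean pooling

def findTopKMostSimilar(query_embedding, all_embeddings, all_sentences, k):
    # query_embedding: (1, hidden); all_embeddings: (num_sentences, hidden)
    scores = F.cosine_similarity(query_embedding, all_embeddings)
    top = torch.topk(scores, k)
    return [{'sentence': all_sentences[int(i)], 'score': float(s)}
            for s, i in zip(top.values, top.indices)]

Under this reading, each query is embedded twice, once per model, and ranked against the matching precomputed embedding file, so the page shows two st.json result lists: one from the multilingual checkpoint and one from hubert.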