endre sukosd committed on
Commit
eb6656d
1 Parent(s): 31f8e76

Streamlit app fixup

README.md CHANGED
@@ -1,18 +1,11 @@
 ---
 title: SemanticSearch HU
-
 emoji: 💻
-
 colorFrom: green
-
 colorTo: white
-
 sdk: streamlit
-
 app_file: src/app.py
-
 pinned: false
-
 ---
 
 # Huggingface Course Project - 2021 November
notebooks/QA_retrieval_precalculate_embeddings.ipynb CHANGED
@@ -1 +1 @@
- {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"QA_retrieval_huggingface_couser_2021_Nov.ipynb","provenance":[],"collapsed_sections":[],"mount_file_id":"1e_NcpgIuSh8rfI_Xf16ltcybK8TbgJWB","authorship_tag":"ABX9TyN3TvKBRyS+wRVSLWNFgC+f"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"GI4Sz98ItJW7"},"source":["# TPU\n","# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.10-cp37-cp37m-linux_x86_64.whl"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"97-OsdFhlD20","executionInfo":{"status":"ok","timestamp":1637680969592,"user_tz":-60,"elapsed":3348,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}},"outputId":"c47a98a7-f016-4a4f-827b-edc9229c5eca"},"source":["!pip install transformers sentence_transformers"],"execution_count":9,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.12.5)\n","Requirement already satisfied: sentence_transformers in /usr/local/lib/python3.7/dist-packages (2.1.0)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n","Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.10.3)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.1.2)\n","Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers) (0.0.46)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.4.0)\n","Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.8.2)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.3)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (3.10.0.2)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.6)\n","Requirement already satisfied: nltk in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (3.2.5)\n","Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.0.1)\n","Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.4.1)\n","Requirement already satisfied: sentencepiece in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.1.96)\n","Requirement already satisfied: torchvision 
in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.11.1+cu111)\n","Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.10.0+cu111)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.6.0)\n","Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from nltk->sentence_transformers) (1.15.0)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.10.8)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n","Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n","Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.1.0)\n","Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->sentence_transformers) (3.0.0)\n","Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->sentence_transformers) (7.1.2)\n"]}]},{"cell_type":"code","metadata":{"id":"3-jkyQkdkdPQ","executionInfo":{"status":"ok","timestamp":1637680970023,"user_tz":-60,"elapsed":3,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["from transformers import AutoTokenizer, AutoModel\n","import torch\n","import pickle\n","from sentence_transformers import util\n","from datetime import datetime"],"execution_count":10,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"kA2h5mH8m-n8","executionInfo":{"status":"ok","timestamp":1637654036646,"user_tz":-60,"elapsed":26589,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}},"outputId":"88fcd97f-276c-4f70-de60-d1c5c9810443"},"source":["from google.colab import drive\n","drive.mount('/content/drive')\n","#drive.mount('/content/drive', force_remount=True)"],"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"markdown","metadata":{"id":"b8SkQGWuB1z7"},"source":["# Load pretrained \n","\n","- multilingual sentence transformers from checkpoint\n","- tokenizer from checkpoint"]},{"cell_type":"code","metadata":{"id":"1R83LLVAk98K","executionInfo":{"status":"ok","timestamp":1637655426545,"user_tz":-60,"elapsed":6237,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'\n","tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)\n","model = 
AutoModel.from_pretrained(multilingual_checkpoint)"],"execution_count":3,"outputs":[]},{"cell_type":"code","metadata":{"id":"wcdik3tQpkyi"},"source":["# GPU\n","device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n","model.to(device)\n","print(device)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-YzAkemLsrC9"},"source":["# TPU\n","# unfortunately incompatible wheel package for pytorch-xla 1.10 version\n","#import torch_xla.core.xla_model as xm\n","#device = xm.xla_device()\n","#print(device)\n","#pip list | grep torch"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"dfeEQJOglxdw","executionInfo":{"status":"ok","timestamp":1637682096594,"user_tz":-60,"elapsed":362,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["#Mean Pooling - Take attention mask into account for correct averaging\n","def mean_pooling(model_output, attention_mask):\n"," token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n"," input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n"," sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)\n"," sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n"," return sum_embeddings / sum_mask\n","\n","def calculateEmbeddings(sentences,tokenizer,model,device=\"cpu\"):\n"," tokenized_sentences = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')\n"," tokenized_sentences.to(device)\n"," with torch.no_grad():\n"," model_output = model(**tokenized_sentences)\n"," sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])\n"," del tokenized_sentences\n"," torch.cuda.empty_cache()\n"," return sentence_embeddings\n","\n","def findTopKMostSimilar(query_embedding, embeddings, k):\n"," cosine_scores = util.pytorch_cos_sim(query_embedding, embeddings)\n"," cosine_scores_list = cosine_scores.squeeze().tolist()\n"," pairs = []\n"," for idx,score in enumerate(cosine_scores_list):\n"," pairs.append({'index': idx, 'score': score})\n"," pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)\n"," return pairs[0:k]\n","\n","def saveToDisc(embeddings, output_filename):\n"," with open(output_filename, \"ab\") as f:\n"," pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)"],"execution_count":23,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MddjkKfMCH81"},"source":["# Create sentence embeddings\n","\n","\n","* Load sentences from raw text file\n","* Precalculate in batches of 1000, to avoid running out of memory\n","* Save to disc/files incrementally, to be able to reuse later (in total 5 files of 100.000 embedding each)\n","\n"]},{"cell_type":"code","metadata":{"id":"yfOsCAVImIAl"},"source":["batch_size = 1000\n","\n","raw_text_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01.txt'\n","datetime_formatted = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')\n","output_embeddings_file_batched = f'/content/drive/MyDrive/huggingface/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'\n","output_embeddings_file = f'/content/drive/MyDrive/huggingface/embeddings_at_{datetime_formatted}.pkl'\n","\n","print(datetime.now())\n","concated_sentence_embeddings = None\n","all_sentences = []\n","line = 'init'\n","total_read = 0\n","total_read_limit = 500000\n","skip_index 
= 400000\n","with open(raw_text_file) as f:\n"," while line and total_read < total_read_limit:\n"," count = 0\n"," sentence_batch = []\n"," while line and count < batch_size:\n"," line = f.readline()\n"," sentence_batch.append(line)\n"," count += 1\n"," \n"," all_sentences.extend(sentence_batch)\n"," \n"," if total_read >= skip_index:\n"," sentence_embeddings = calculateEmbeddings(sentence_batch,tokenizer,model,device)\n"," if concated_sentence_embeddings == None:\n"," concated_sentence_embeddings = sentence_embeddings\n"," else:\n"," concated_sentence_embeddings = torch.cat([concated_sentence_embeddings, sentence_embeddings], dim=0)\n"," print(concated_sentence_embeddings.size())\n"," saveToDisc(sentence_embeddings,output_embeddings_file_batched)\n"," total_read += count\n","print(datetime.now())"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1rGQc9GRCuNy"},"source":["# Test: Query embeddings"]},{"cell_type":"code","metadata":{"id":"FT7CwpM0Bwhi"},"source":["query_embedding = calculateEmbeddings(['Melyik a legnépesebb város a világon?'],tokenizer,model,device)\n","top_pairs = findTopKMostSimilar(query_embedding, concated_sentence_embeddings, 5)\n","\n","for pair in top_pairs:\n"," i = pair['index']\n"," score = pair['score']\n"," print(\"{} \\t\\t Score: {:.4f}\".format(all_sentences[skip_index+i], score))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"6Hdu_5FiDYJr"},"source":["# Test: Load pre-calculated embeddings\n","\n","* Load embedding from files and stitch them together\n","* Save into one file\n"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"gkWt0Uj_Ddsp","executionInfo":{"status":"ok","timestamp":1637682006152,"user_tz":-60,"elapsed":1722,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}},"outputId":"1921456e-1fd6-4218-9ebb-cbe503f402b1"},"source":["def concatTensors(new_tensor, acc_tensor='None'):\n"," if acc_tensor == None:\n"," acc_tensor = new_tensor\n"," else:\n"," acc_tensor = torch.cat([acc_tensor, new_tensor], dim=0)\n"," return acc_tensor\n","\n","def loadFromDisc(batch_size, number_of_batches, filename):\n"," concated_sentence_embeddings = None\n"," count = 0\n"," batches = 0\n"," with open(filename, \"rb\") as f:\n"," loaded_embeddings = torch.empty([batch_size])\n"," while count < number_of_batches and loaded_embeddings.size()[0]==batch_size:\n"," loaded_embeddings = pickle.load(f)\n"," count += 1\n"," concated_sentence_embeddings = concatTensors(loaded_embeddings,concated_sentence_embeddings)\n"," print(f'Read file using {count} number of read+unpickle operations')\n"," print(concated_sentence_embeddings.size())\n"," return concated_sentence_embeddings\n","\n","\n","output_embeddings_file = 'data/processed/DBpedia_shortened_abstracts_hu_embeddings.pkl'\n","\n","embeddings_files = [\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:17:17.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:28:46.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:40:54.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:56:26.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_09:31:47.pkl'\n","]\n","\n","all_embeddings = None\n","for idx,emb_file in enumerate(embeddings_files):\n"," 
print(f'Processing file {idx}')\n"," file_embeddings = loadFromDisc(1000, 100, emb_file)\n"," all_embeddings = concatTensors(file_embeddings,all_embeddings)\n","\n","print(all_embeddings.size())"],"execution_count":20,"outputs":[{"output_type":"stream","name":"stdout","text":["Processing file 0\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 1\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 2\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 3\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 4\n","Read file using 67 number of read+unpickle operations\n","torch.Size([66529, 384])\n","torch.Size([466529, 384])\n"]}]},{"cell_type":"code","metadata":{"id":"M_8RHpNnIU7o","executionInfo":{"status":"ok","timestamp":1637683739951,"user_tz":-60,"elapsed":384,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["all_embeddings_output_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01_embedded.pt'\n","#saveToDisc(all_embeddings, all_embeddings_output_file)\n","torch.save(all_embeddings,all_embeddings_output_file)"],"execution_count":28,"outputs":[]},{"cell_type":"code","metadata":{"id":"LYCwyDpMjsXg"},"source":[""],"execution_count":null,"outputs":[]}]}
 
+ {"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"GI4Sz98ItJW7"},"outputs":[],"source":["# TPU\n","# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.10-cp37-cp37m-linux_x86_64.whl"]},{"cell_type":"code","execution_count":9,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3348,"status":"ok","timestamp":1637680969592,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"97-OsdFhlD20","outputId":"c47a98a7-f016-4a4f-827b-edc9229c5eca"},"outputs":[{"name":"stdout","output_type":"stream","text":["Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.12.5)\n","Requirement already satisfied: sentence_transformers in /usr/local/lib/python3.7/dist-packages (2.1.0)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n","Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.10.3)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.1.2)\n","Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers) (0.0.46)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.4.0)\n","Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.8.2)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.3)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (3.10.0.2)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.6)\n","Requirement already satisfied: nltk in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (3.2.5)\n","Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.0.1)\n","Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.4.1)\n","Requirement already satisfied: sentencepiece in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.1.96)\n","Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.11.1+cu111)\n","Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.10.0+cu111)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.6.0)\n","Requirement already satisfied: six in 
/usr/local/lib/python3.7/dist-packages (from nltk->sentence_transformers) (1.15.0)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.10.8)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n","Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n","Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.1.0)\n","Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->sentence_transformers) (3.0.0)\n","Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->sentence_transformers) (7.1.2)\n"]}],"source":["!pip install transformers sentence_transformers"]},{"cell_type":"code","execution_count":10,"metadata":{"executionInfo":{"elapsed":3,"status":"ok","timestamp":1637680970023,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"3-jkyQkdkdPQ"},"outputs":[],"source":["from transformers import AutoTokenizer, AutoModel\n","import torch\n","import pickle\n","from sentence_transformers import util\n","from datetime import datetime"]},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":26589,"status":"ok","timestamp":1637654036646,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"kA2h5mH8m-n8","outputId":"88fcd97f-276c-4f70-de60-d1c5c9810443"},"outputs":[{"name":"stdout","output_type":"stream","text":["Mounted at /content/drive\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')\n","#drive.mount('/content/drive', force_remount=True)"]},{"cell_type":"markdown","metadata":{"id":"b8SkQGWuB1z7"},"source":["# Load pretrained \n","\n","- multilingual sentence transformers from checkpoint\n","- tokenizer from checkpoint"]},{"cell_type":"code","execution_count":3,"metadata":{"executionInfo":{"elapsed":6237,"status":"ok","timestamp":1637655426545,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"1R83LLVAk98K"},"outputs":[],"source":["multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'\n","tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)\n","model = AutoModel.from_pretrained(multilingual_checkpoint)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"wcdik3tQpkyi"},"outputs":[],"source":["# GPU\n","device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n","model.to(device)\n","print(device)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"-YzAkemLsrC9"},"outputs":[],"source":["# TPU\n","# unfortunately 
incompatible wheel package for pytorch-xla 1.10 version\n","#import torch_xla.core.xla_model as xm\n","#device = xm.xla_device()\n","#print(device)\n","#pip list | grep torch"]},{"cell_type":"code","execution_count":23,"metadata":{"executionInfo":{"elapsed":362,"status":"ok","timestamp":1637682096594,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"dfeEQJOglxdw"},"outputs":[],"source":["#Mean Pooling - Take attention mask into account for correct averaging\n","def mean_pooling(model_output, attention_mask):\n"," token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n"," input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n"," sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)\n"," sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n"," return sum_embeddings / sum_mask\n","\n","def calculateEmbeddings(sentences,tokenizer,model,device=\"cpu\"):\n"," tokenized_sentences = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')\n"," tokenized_sentences.to(device)\n"," with torch.no_grad():\n"," model_output = model(**tokenized_sentences)\n"," sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])\n"," del tokenized_sentences\n"," torch.cuda.empty_cache()\n"," return sentence_embeddings\n","\n","def findTopKMostSimilar(query_embedding, embeddings, k):\n"," cosine_scores = util.pytorch_cos_sim(query_embedding, embeddings)\n"," cosine_scores_list = cosine_scores.squeeze().tolist()\n"," pairs = []\n"," for idx,score in enumerate(cosine_scores_list):\n"," pairs.append({'index': idx, 'score': score})\n"," pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)\n"," return pairs[0:k]\n","\n","def saveToDisc(embeddings, output_filename):\n"," with open(output_filename, \"ab\") as f:\n"," pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)"]},{"cell_type":"markdown","metadata":{"id":"MddjkKfMCH81"},"source":["# Create sentence embeddings\n","\n","\n","* Load sentences from raw text file\n","* Precalculate in batches of 1000, to avoid running out of memory\n","* Save to disc/files incrementally, to be able to reuse later (in total 5 files of 100.000 embedding each)\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"yfOsCAVImIAl"},"outputs":[],"source":["batch_size = 1000\n","\n","raw_text_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01.txt'\n","datetime_formatted = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')\n","output_embeddings_file_batched = f'/content/drive/MyDrive/huggingface/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'\n","output_embeddings_file = f'/content/drive/MyDrive/huggingface/embeddings_at_{datetime_formatted}.pkl'\n","\n","print(datetime.now())\n","concated_sentence_embeddings = None\n","all_sentences = []\n","line = 'init'\n","total_read = 0\n","total_read_limit = 500000\n","skip_index = 400000\n","with open(raw_text_file) as f:\n"," while line and total_read < total_read_limit:\n"," count = 0\n"," sentence_batch = []\n"," while line and count < batch_size:\n"," line = f.readline()\n"," sentence_batch.append(line)\n"," count += 1\n"," \n"," all_sentences.extend(sentence_batch)\n"," \n"," if total_read >= skip_index:\n"," sentence_embeddings = 
calculateEmbeddings(sentence_batch,tokenizer,model,device)\n"," if concated_sentence_embeddings == None:\n"," concated_sentence_embeddings = sentence_embeddings\n"," else:\n"," concated_sentence_embeddings = torch.cat([concated_sentence_embeddings, sentence_embeddings], dim=0)\n"," print(concated_sentence_embeddings.size())\n"," saveToDisc(sentence_embeddings,output_embeddings_file_batched)\n"," total_read += count\n","print(datetime.now())"]},{"cell_type":"markdown","metadata":{"id":"1rGQc9GRCuNy"},"source":["# Test: Query embeddings"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"FT7CwpM0Bwhi"},"outputs":[],"source":["query_embedding = calculateEmbeddings(['Melyik a legnépesebb város a világon?'],tokenizer,model,device)\n","top_pairs = findTopKMostSimilar(query_embedding, concated_sentence_embeddings, 5)\n","\n","for pair in top_pairs:\n"," i = pair['index']\n"," score = pair['score']\n"," print(\"{} \\t\\t Score: {:.4f}\".format(all_sentences[skip_index+i], score))"]},{"cell_type":"markdown","metadata":{"id":"6Hdu_5FiDYJr"},"source":["# Test: Load pre-calculated embeddings\n","\n","* Load embedding from files and stitch them together\n","* Save into one file\n"]},{"cell_type":"code","execution_count":20,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":1722,"status":"ok","timestamp":1637682006152,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"gkWt0Uj_Ddsp","outputId":"1921456e-1fd6-4218-9ebb-cbe503f402b1"},"outputs":[{"name":"stdout","output_type":"stream","text":["Processing file 0\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 1\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 2\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 3\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 4\n","Read file using 67 number of read+unpickle operations\n","torch.Size([66529, 384])\n","torch.Size([466529, 384])\n"]}],"source":["def concatTensors(new_tensor, acc_tensor='None'):\n"," if acc_tensor == None:\n"," acc_tensor = new_tensor\n"," else:\n"," acc_tensor = torch.cat([acc_tensor, new_tensor], dim=0)\n"," return acc_tensor\n","\n","def loadFromDisc(batch_size, number_of_batches, filename):\n"," concated_sentence_embeddings = None\n"," count = 0\n"," batches = 0\n"," with open(filename, \"rb\") as f:\n"," loaded_embeddings = torch.empty([batch_size])\n"," while count < number_of_batches and loaded_embeddings.size()[0]==batch_size:\n"," loaded_embeddings = pickle.load(f)\n"," count += 1\n"," concated_sentence_embeddings = concatTensors(loaded_embeddings,concated_sentence_embeddings)\n"," print(f'Read file using {count} number of read+unpickle operations')\n"," print(concated_sentence_embeddings.size())\n"," return concated_sentence_embeddings\n","\n","\n","output_embeddings_file = 'data/preprocessed/DBpedia_shortened_abstracts_hu_embeddings.pkl'\n","\n","embeddings_files = [\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:17:17.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:28:46.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:40:54.pkl',\n"," 
'/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:56:26.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_09:31:47.pkl'\n","]\n","\n","all_embeddings = None\n","for idx,emb_file in enumerate(embeddings_files):\n"," print(f'Processing file {idx}')\n"," file_embeddings = loadFromDisc(1000, 100, emb_file)\n"," all_embeddings = concatTensors(file_embeddings,all_embeddings)\n","\n","print(all_embeddings.size())"]},{"cell_type":"code","execution_count":28,"metadata":{"executionInfo":{"elapsed":384,"status":"ok","timestamp":1637683739951,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"M_8RHpNnIU7o"},"outputs":[],"source":["all_embeddings_output_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01_embedded.pt'\n","#saveToDisc(all_embeddings, all_embeddings_output_file)\n","torch.save(all_embeddings,all_embeddings_output_file)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"LYCwyDpMjsXg"},"outputs":[],"source":[]}],"metadata":{"accelerator":"GPU","colab":{"authorship_tag":"ABX9TyN3TvKBRyS+wRVSLWNFgC+f","collapsed_sections":[],"mount_file_id":"1e_NcpgIuSh8rfI_Xf16ltcybK8TbgJWB","name":"QA_retrieval_huggingface_couser_2021_Nov.ipynb","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}
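Editor's note: for readers who do not want to dig through the raw notebook JSON above, the notebook precalculates sentence embeddings by mean pooling the MiniLM token embeddings and later ranks abstracts by cosine similarity. The snippet below is a condensed, illustrative version of the notebook's mean_pooling and calculateEmbeddings helpers (the checkpoint name, 128-token truncation and 384-dimensional output are taken from the notebook; the two Hungarian example sentences are made up for illustration).

# Condensed sketch of the notebook's embedding step (CPU-only, illustrative).
import torch
from transformers import AutoTokenizer, AutoModel

checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

def mean_pooling(model_output, attention_mask):
    # Average the token embeddings, ignoring padded positions.
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

def embed(sentences):
    tokens = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')
    with torch.no_grad():
        output = model(**tokens)
    return mean_pooling(output, tokens['attention_mask'])

# Illustrative usage with two made-up Hungarian sentences.
embeddings = embed(['Budapest Magyarország fővárosa.', 'A Duna Európa egyik leghosszabb folyója.'])
print(embeddings.shape)  # torch.Size([2, 384]) for this checkpoint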
src/app.py CHANGED
@@ -1,18 +1,23 @@
 import streamlit as st
 from transformers import AutoTokenizer, AutoModel
+import transformers
 import torch
 from sentence_transformers import util
 
-@st.cache
+# explicit no operation hash functions defined, because raw sentences, embedding, model and tokenizer are not going to change
+
+
+@st.cache(hash_funcs={list: lambda _: None})
 def load_raw_sentences(filename):
     with open(filename) as f:
         return f.readlines()
 
-@st.cache
+@st.cache(hash_funcs={torch.Tensor: lambda _: None})
 def load_embeddings(filename):
     with open(filename) as f:
         return torch.load(filename,map_location=torch.device('cpu') )
 
+
 #Mean Pooling - Take attention mask into account for correct averaging
 def mean_pooling(model_output, attention_mask):
     token_embeddings = model_output[0] #First element of model_output contains all token embeddings
@@ -27,7 +32,7 @@ def findTopKMostSimilar(query_embedding, embeddings, all_sentences, k):
     pairs = []
     for idx,score in enumerate(cosine_scores_list):
         if idx < len(all_sentences):
-            pairs.append({'score': score, 'text': all_sentences[idx]})
+            pairs.append({'score': '{:.4f}'.format(score), 'text': all_sentences[idx]})
     pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)
     return pairs[0:k]
 
@@ -38,25 +43,35 @@ def calculateEmbeddings(sentences,tokenizer,model):
     sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])
     return sentence_embeddings
 
+# explicit no operation hash function, because model and tokenizer are not going to change
+@st.cache(hash_funcs={transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None, transformers.models.bert.modeling_bert.BertModel: lambda _: None})
+def load_model_and_tokenizer():
+    multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
+    tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
+    model = AutoModel.from_pretrained(multilingual_checkpoint)
+    print(type(tokenizer))
+    print(type(model))
+    return model, tokenizer
 
-multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
-tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
-model = AutoModel.from_pretrained(multilingual_checkpoint)
 
-raw_text_file = 'data/processed/shortened_abstracts_hu_2021_09_01.txt'
+model,tokenizer = load_model_and_tokenizer();
+raw_text_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01.txt'
 all_sentences = load_raw_sentences(raw_text_file)
 
-embeddings_file = 'data/processed/shortened_abstracts_hu_2021_09_01_embedded.pt'
+embeddings_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01_embedded.pt'
 all_embeddings = load_embeddings(embeddings_file)
 
 
 st.header('Wikipedia absztrakt kereső')
-st.subheader('Search Wikipedia abstracts in Hungarian')
-st.caption('Input some search term and see the top-5 most similar wikipedia abstracts')
+st.subheader('Search Wikipedia abstracts in Hungarian!')
+
+st.caption('[HU] Adjon meg egy tetszőleges kifejezést és a rendszer visszaadja az 5 hozzá legjobban hasonlító Wikipedia absztraktot')
+st.caption('[EN] Input some search term and see the top-5 most similar wikipedia abstracts')
+
 
-input_query = st.text_area("Adjon meg egy tetszőleges kifejezést és a rendszer visszaadja az 5 hozzá legjobban hasonlító Wikipedia absztraktot", value='Hol élnek a bengali tigrisek?')
+text_area_input_query = st.text_area('[HU] Beviteli mező - [EN] Query input',value='Mi Japán fővárosa?')
 
-if input_query:
-    query_embedding = calculateEmbeddings([input_query],tokenizer,model)
+if text_area_input_query:
+    query_embedding = calculateEmbeddings([text_area_input_query],tokenizer,model)
     top_pairs = findTopKMostSimilar(query_embedding, all_embeddings, all_sentences, 5)
-    st.json(top_pairs)
+    st.json(top_pairs)
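Editor's note: the substance of this fixup is the caching. The sentences, the embedding tensor, and the model/tokenizer pair are loaded through functions wrapped in st.cache with no-op hash_funcs, so Streamlit reuses them across reruns instead of trying to hash a BertModel or a large tensor. Below is a minimal standalone illustration of the same pattern; the demo function and the dummy tensor are invented for the example, and st.cache with hash_funcs matches the Streamlit API of this 2021 commit (newer Streamlit releases replace it with st.cache_data / st.cache_resource).

# Minimal sketch of the no-op hash_funcs pattern used in src/app.py.
import streamlit as st
import torch

@st.cache(hash_funcs={torch.Tensor: lambda _: None})
def load_embeddings_demo():
    # Stand-in for torch.load(...): the cached tensor never changes,
    # so Streamlit is told not to hash it (lambda _: None).
    return torch.zeros(10, 384)

embeddings = load_embeddings_demo()  # computed on the first run, reused afterwards
st.write(embeddings.shape)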
src/data/dbpedia_dump_embeddings.py CHANGED
@@ -31,8 +31,8 @@ dt = datetime.now()
 datetime_formatted = dt.strftime('%Y-%m-%d_%H:%M:%S')
 batch_size = 1000
 
-input_text_file = 'data/processed/shortened_abstracts_hu_2021_09_01.txt'
-output_embeddings_file = f'data/processed/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'
+input_text_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01.txt'
+output_embeddings_file = f'data/preprocessed/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'
 
 multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
 tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
src/data/dbpedia_dump_wiki_text.py CHANGED
@@ -2,14 +2,14 @@ from rdflib import Graph
 
 # Downloaded from https://databus.dbpedia.org/dbpedia/text/short-abstracts
 raw_data_path = 'data/raw/short-abstracts_lang=hu.ttl'
-processed_data_path = 'data/processed/shortened_abstracts_hu_2021_09_01.txt'
+preprocessed_data_path = 'data/preprocessed/shortened_abstracts_hu_2021_09_01.txt'
 
 g = Graph()
 g.parse(raw_data_path, format='turtle')
 
 i = 0
 objects = []
-with open(processed_data_path, 'w') as f:
+with open(preprocessed_data_path, 'w') as f:
     print(len(g))
     for subject, predicate, object in g:
         objects.append(object.replace(' +/-','').replace('\n',' '))
src/exploration/serialize_test.py CHANGED
@@ -25,6 +25,6 @@ def loadFromDiskRaw(batch_number, filename='embeddings.pkl'):
         count += 1
     return stored_data
 
-output_embeddings_file = 'data/processed/DBpedia_shortened_abstracts_hu_embeddings.pkl'
+output_embeddings_file = 'data/preprocessed/DBpedia_shortened_abstracts_hu_embeddings.pkl'
 loadFromDiskRaw(3, output_embeddings_file)
 
src/features/semantic_retreiver.py CHANGED
@@ -16,7 +16,7 @@ def mean_pooling(model_output, attention_mask):
 dt = datetime.now()
 datetime_formatted = dt.strftime('%Y-%m-%d_%H:%M:%S')
 batch_size = 1000
-output_embeddings_file = f'data/processed/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'
+output_embeddings_file = f'data/preprocessed/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'
 def saveToDisc(embeddings):
     with open(output_embeddings_file, "ab") as f:
         pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
@@ -75,7 +75,7 @@ multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-
 tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
 model = AutoModel.from_pretrained(multilingual_checkpoint)
 
-raw_text_file = 'data/processed/shortened_abstracts_hu_2021_09_01.txt'
+raw_text_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01.txt'
 
 
 concated_sentence_embeddings = None
src/main_qa.py CHANGED
@@ -34,8 +34,8 @@ multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-
 tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
 model = AutoModel.from_pretrained(multilingual_checkpoint)
 
-raw_text_file = 'data/processed/shortened_abstracts_hu_2021_09_01.txt'
-embeddings_file = 'data/processed/shortened_abstracts_hu_2021_09_01_embedded.pt'
+raw_text_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01.txt'
+embeddings_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01_embedded.pt'
 
 all_sentences = load_raw_sentences(raw_text_file)
 all_embeddings = torch.load(embeddings_file,map_location=torch.device('cpu') )