Commit eb6656d committed by endre sukosd
Parent: 31f8e76
Streamlit app fixup
Files changed:
- README.md +0 -7
- notebooks/QA_retrieval_precalculate_embeddings.ipynb +1 -1
- src/app.py +29 -14
- src/data/dbpedia_dump_embeddings.py +2 -2
- src/data/dbpedia_dump_wiki_text.py +2 -2
- src/exploration/serialize_test.py +1 -1
- src/features/semantic_retreiver.py +2 -2
- src/main_qa.py +2 -2
README.md
CHANGED
@@ -1,18 +1,11 @@
 ---
 title: SemanticSearch HU
-
 emoji: 💻
-
 colorFrom: green
-
 colorTo: white
-
 sdk: streamlit
-
 app_file: src/app.py
-
 pinned: false
-
 ---
 
 # Huggingface Course Project - 2021 November
notebooks/QA_retrieval_precalculate_embeddings.ipynb
CHANGED
@@ -1 +1 @@
-
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"QA_retrieval_huggingface_couser_2021_Nov.ipynb","provenance":[],"collapsed_sections":[],"mount_file_id":"1e_NcpgIuSh8rfI_Xf16ltcybK8TbgJWB","authorship_tag":"ABX9TyN3TvKBRyS+wRVSLWNFgC+f"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"GI4Sz98ItJW7"},"source":["# TPU\n","# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.10-cp37-cp37m-linux_x86_64.whl"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"97-OsdFhlD20","executionInfo":{"status":"ok","timestamp":1637680969592,"user_tz":-60,"elapsed":3348,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}},"outputId":"c47a98a7-f016-4a4f-827b-edc9229c5eca"},"source":["!pip install transformers sentence_transformers"],"execution_count":9,"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.12.5)\n","Requirement already satisfied: sentence_transformers in /usr/local/lib/python3.7/dist-packages (2.1.0)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n","Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.10.3)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.1.2)\n","Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers) (0.0.46)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.4.0)\n","Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.8.2)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.3)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (3.10.0.2)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.6)\n","Requirement already satisfied: nltk in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (3.2.5)\n","Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.0.1)\n","Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.4.1)\n","Requirement already satisfied: sentencepiece in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.1.96)\n","Requirement already satisfied: torchvision 
in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.11.1+cu111)\n","Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.10.0+cu111)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.6.0)\n","Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from nltk->sentence_transformers) (1.15.0)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.10.8)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n","Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n","Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.1.0)\n","Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->sentence_transformers) (3.0.0)\n","Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->sentence_transformers) (7.1.2)\n"]}]},{"cell_type":"code","metadata":{"id":"3-jkyQkdkdPQ","executionInfo":{"status":"ok","timestamp":1637680970023,"user_tz":-60,"elapsed":3,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["from transformers import AutoTokenizer, AutoModel\n","import torch\n","import pickle\n","from sentence_transformers import util\n","from datetime import datetime"],"execution_count":10,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"kA2h5mH8m-n8","executionInfo":{"status":"ok","timestamp":1637654036646,"user_tz":-60,"elapsed":26589,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}},"outputId":"88fcd97f-276c-4f70-de60-d1c5c9810443"},"source":["from google.colab import drive\n","drive.mount('/content/drive')\n","#drive.mount('/content/drive', force_remount=True)"],"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"markdown","metadata":{"id":"b8SkQGWuB1z7"},"source":["# Load pretrained \n","\n","- multilingual sentence transformers from checkpoint\n","- tokenizer from checkpoint"]},{"cell_type":"code","metadata":{"id":"1R83LLVAk98K","executionInfo":{"status":"ok","timestamp":1637655426545,"user_tz":-60,"elapsed":6237,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'\n","tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)\n","model = 
AutoModel.from_pretrained(multilingual_checkpoint)"],"execution_count":3,"outputs":[]},{"cell_type":"code","metadata":{"id":"wcdik3tQpkyi"},"source":["# GPU\n","device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n","model.to(device)\n","print(device)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-YzAkemLsrC9"},"source":["# TPU\n","# unfortunately incompatible wheel package for pytorch-xla 1.10 version\n","#import torch_xla.core.xla_model as xm\n","#device = xm.xla_device()\n","#print(device)\n","#pip list | grep torch"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"dfeEQJOglxdw","executionInfo":{"status":"ok","timestamp":1637682096594,"user_tz":-60,"elapsed":362,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["#Mean Pooling - Take attention mask into account for correct averaging\n","def mean_pooling(model_output, attention_mask):\n"," token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n"," input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n"," sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)\n"," sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n"," return sum_embeddings / sum_mask\n","\n","def calculateEmbeddings(sentences,tokenizer,model,device=\"cpu\"):\n"," tokenized_sentences = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')\n"," tokenized_sentences.to(device)\n"," with torch.no_grad():\n"," model_output = model(**tokenized_sentences)\n"," sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])\n"," del tokenized_sentences\n"," torch.cuda.empty_cache()\n"," return sentence_embeddings\n","\n","def findTopKMostSimilar(query_embedding, embeddings, k):\n"," cosine_scores = util.pytorch_cos_sim(query_embedding, embeddings)\n"," cosine_scores_list = cosine_scores.squeeze().tolist()\n"," pairs = []\n"," for idx,score in enumerate(cosine_scores_list):\n"," pairs.append({'index': idx, 'score': score})\n"," pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)\n"," return pairs[0:k]\n","\n","def saveToDisc(embeddings, output_filename):\n"," with open(output_filename, \"ab\") as f:\n"," pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)"],"execution_count":23,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MddjkKfMCH81"},"source":["# Create sentence embeddings\n","\n","\n","* Load sentences from raw text file\n","* Precalculate in batches of 1000, to avoid running out of memory\n","* Save to disc/files incrementally, to be able to reuse later (in total 5 files of 100.000 embedding each)\n","\n"]},{"cell_type":"code","metadata":{"id":"yfOsCAVImIAl"},"source":["batch_size = 1000\n","\n","raw_text_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01.txt'\n","datetime_formatted = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')\n","output_embeddings_file_batched = f'/content/drive/MyDrive/huggingface/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'\n","output_embeddings_file = f'/content/drive/MyDrive/huggingface/embeddings_at_{datetime_formatted}.pkl'\n","\n","print(datetime.now())\n","concated_sentence_embeddings = None\n","all_sentences = []\n","line = 'init'\n","total_read = 0\n","total_read_limit = 500000\n","skip_index 
= 400000\n","with open(raw_text_file) as f:\n"," while line and total_read < total_read_limit:\n"," count = 0\n"," sentence_batch = []\n"," while line and count < batch_size:\n"," line = f.readline()\n"," sentence_batch.append(line)\n"," count += 1\n"," \n"," all_sentences.extend(sentence_batch)\n"," \n"," if total_read >= skip_index:\n"," sentence_embeddings = calculateEmbeddings(sentence_batch,tokenizer,model,device)\n"," if concated_sentence_embeddings == None:\n"," concated_sentence_embeddings = sentence_embeddings\n"," else:\n"," concated_sentence_embeddings = torch.cat([concated_sentence_embeddings, sentence_embeddings], dim=0)\n"," print(concated_sentence_embeddings.size())\n"," saveToDisc(sentence_embeddings,output_embeddings_file_batched)\n"," total_read += count\n","print(datetime.now())"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1rGQc9GRCuNy"},"source":["# Test: Query embeddings"]},{"cell_type":"code","metadata":{"id":"FT7CwpM0Bwhi"},"source":["query_embedding = calculateEmbeddings(['Melyik a legnépesebb város a világon?'],tokenizer,model,device)\n","top_pairs = findTopKMostSimilar(query_embedding, concated_sentence_embeddings, 5)\n","\n","for pair in top_pairs:\n"," i = pair['index']\n"," score = pair['score']\n"," print(\"{} \\t\\t Score: {:.4f}\".format(all_sentences[skip_index+i], score))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"6Hdu_5FiDYJr"},"source":["# Test: Load pre-calculated embeddings\n","\n","* Load embedding from files and stitch them together\n","* Save into one file\n"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"gkWt0Uj_Ddsp","executionInfo":{"status":"ok","timestamp":1637682006152,"user_tz":-60,"elapsed":1722,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}},"outputId":"1921456e-1fd6-4218-9ebb-cbe503f402b1"},"source":["def concatTensors(new_tensor, acc_tensor='None'):\n"," if acc_tensor == None:\n"," acc_tensor = new_tensor\n"," else:\n"," acc_tensor = torch.cat([acc_tensor, new_tensor], dim=0)\n"," return acc_tensor\n","\n","def loadFromDisc(batch_size, number_of_batches, filename):\n"," concated_sentence_embeddings = None\n"," count = 0\n"," batches = 0\n"," with open(filename, \"rb\") as f:\n"," loaded_embeddings = torch.empty([batch_size])\n"," while count < number_of_batches and loaded_embeddings.size()[0]==batch_size:\n"," loaded_embeddings = pickle.load(f)\n"," count += 1\n"," concated_sentence_embeddings = concatTensors(loaded_embeddings,concated_sentence_embeddings)\n"," print(f'Read file using {count} number of read+unpickle operations')\n"," print(concated_sentence_embeddings.size())\n"," return concated_sentence_embeddings\n","\n","\n","output_embeddings_file = 'data/processed/DBpedia_shortened_abstracts_hu_embeddings.pkl'\n","\n","embeddings_files = [\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:17:17.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:28:46.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:40:54.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:56:26.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_09:31:47.pkl'\n","]\n","\n","all_embeddings = None\n","for idx,emb_file in enumerate(embeddings_files):\n"," 
print(f'Processing file {idx}')\n"," file_embeddings = loadFromDisc(1000, 100, emb_file)\n"," all_embeddings = concatTensors(file_embeddings,all_embeddings)\n","\n","print(all_embeddings.size())"],"execution_count":20,"outputs":[{"output_type":"stream","name":"stdout","text":["Processing file 0\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 1\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 2\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 3\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 4\n","Read file using 67 number of read+unpickle operations\n","torch.Size([66529, 384])\n","torch.Size([466529, 384])\n"]}]},{"cell_type":"code","metadata":{"id":"M_8RHpNnIU7o","executionInfo":{"status":"ok","timestamp":1637683739951,"user_tz":-60,"elapsed":384,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"}}},"source":["all_embeddings_output_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01_embedded.pt'\n","#saveToDisc(all_embeddings, all_embeddings_output_file)\n","torch.save(all_embeddings,all_embeddings_output_file)"],"execution_count":28,"outputs":[]},{"cell_type":"code","metadata":{"id":"LYCwyDpMjsXg"},"source":[""],"execution_count":null,"outputs":[]}]}
+
{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"GI4Sz98ItJW7"},"outputs":[],"source":["# TPU\n","# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.10-cp37-cp37m-linux_x86_64.whl"]},{"cell_type":"code","execution_count":9,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3348,"status":"ok","timestamp":1637680969592,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"97-OsdFhlD20","outputId":"c47a98a7-f016-4a4f-827b-edc9229c5eca"},"outputs":[{"name":"stdout","output_type":"stream","text":["Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.12.5)\n","Requirement already satisfied: sentence_transformers in /usr/local/lib/python3.7/dist-packages (2.1.0)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n","Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.10.3)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.1.2)\n","Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers) (0.0.46)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.19.5)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.4.0)\n","Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.8.2)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.62.3)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (3.10.0.2)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.6)\n","Requirement already satisfied: nltk in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (3.2.5)\n","Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.0.1)\n","Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.4.1)\n","Requirement already satisfied: sentencepiece in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.1.96)\n","Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (0.11.1+cu111)\n","Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from sentence_transformers) (1.10.0+cu111)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.6.0)\n","Requirement already satisfied: six in 
/usr/local/lib/python3.7/dist-packages (from nltk->sentence_transformers) (1.15.0)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.10.8)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n","Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n","Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.1.0)\n","Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->sentence_transformers) (3.0.0)\n","Requirement already satisfied: pillow!=8.3.0,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->sentence_transformers) (7.1.2)\n"]}],"source":["!pip install transformers sentence_transformers"]},{"cell_type":"code","execution_count":10,"metadata":{"executionInfo":{"elapsed":3,"status":"ok","timestamp":1637680970023,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"3-jkyQkdkdPQ"},"outputs":[],"source":["from transformers import AutoTokenizer, AutoModel\n","import torch\n","import pickle\n","from sentence_transformers import util\n","from datetime import datetime"]},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":26589,"status":"ok","timestamp":1637654036646,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"kA2h5mH8m-n8","outputId":"88fcd97f-276c-4f70-de60-d1c5c9810443"},"outputs":[{"name":"stdout","output_type":"stream","text":["Mounted at /content/drive\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')\n","#drive.mount('/content/drive', force_remount=True)"]},{"cell_type":"markdown","metadata":{"id":"b8SkQGWuB1z7"},"source":["# Load pretrained \n","\n","- multilingual sentence transformers from checkpoint\n","- tokenizer from checkpoint"]},{"cell_type":"code","execution_count":3,"metadata":{"executionInfo":{"elapsed":6237,"status":"ok","timestamp":1637655426545,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"1R83LLVAk98K"},"outputs":[],"source":["multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'\n","tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)\n","model = AutoModel.from_pretrained(multilingual_checkpoint)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"wcdik3tQpkyi"},"outputs":[],"source":["# GPU\n","device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n","model.to(device)\n","print(device)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"-YzAkemLsrC9"},"outputs":[],"source":["# TPU\n","# unfortunately 
incompatible wheel package for pytorch-xla 1.10 version\n","#import torch_xla.core.xla_model as xm\n","#device = xm.xla_device()\n","#print(device)\n","#pip list | grep torch"]},{"cell_type":"code","execution_count":23,"metadata":{"executionInfo":{"elapsed":362,"status":"ok","timestamp":1637682096594,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"dfeEQJOglxdw"},"outputs":[],"source":["#Mean Pooling - Take attention mask into account for correct averaging\n","def mean_pooling(model_output, attention_mask):\n"," token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n"," input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n"," sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)\n"," sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n"," return sum_embeddings / sum_mask\n","\n","def calculateEmbeddings(sentences,tokenizer,model,device=\"cpu\"):\n"," tokenized_sentences = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')\n"," tokenized_sentences.to(device)\n"," with torch.no_grad():\n"," model_output = model(**tokenized_sentences)\n"," sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])\n"," del tokenized_sentences\n"," torch.cuda.empty_cache()\n"," return sentence_embeddings\n","\n","def findTopKMostSimilar(query_embedding, embeddings, k):\n"," cosine_scores = util.pytorch_cos_sim(query_embedding, embeddings)\n"," cosine_scores_list = cosine_scores.squeeze().tolist()\n"," pairs = []\n"," for idx,score in enumerate(cosine_scores_list):\n"," pairs.append({'index': idx, 'score': score})\n"," pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)\n"," return pairs[0:k]\n","\n","def saveToDisc(embeddings, output_filename):\n"," with open(output_filename, \"ab\") as f:\n"," pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)"]},{"cell_type":"markdown","metadata":{"id":"MddjkKfMCH81"},"source":["# Create sentence embeddings\n","\n","\n","* Load sentences from raw text file\n","* Precalculate in batches of 1000, to avoid running out of memory\n","* Save to disc/files incrementally, to be able to reuse later (in total 5 files of 100.000 embedding each)\n","\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"yfOsCAVImIAl"},"outputs":[],"source":["batch_size = 1000\n","\n","raw_text_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01.txt'\n","datetime_formatted = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')\n","output_embeddings_file_batched = f'/content/drive/MyDrive/huggingface/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'\n","output_embeddings_file = f'/content/drive/MyDrive/huggingface/embeddings_at_{datetime_formatted}.pkl'\n","\n","print(datetime.now())\n","concated_sentence_embeddings = None\n","all_sentences = []\n","line = 'init'\n","total_read = 0\n","total_read_limit = 500000\n","skip_index = 400000\n","with open(raw_text_file) as f:\n"," while line and total_read < total_read_limit:\n"," count = 0\n"," sentence_batch = []\n"," while line and count < batch_size:\n"," line = f.readline()\n"," sentence_batch.append(line)\n"," count += 1\n"," \n"," all_sentences.extend(sentence_batch)\n"," \n"," if total_read >= skip_index:\n"," sentence_embeddings = 
calculateEmbeddings(sentence_batch,tokenizer,model,device)\n"," if concated_sentence_embeddings == None:\n"," concated_sentence_embeddings = sentence_embeddings\n"," else:\n"," concated_sentence_embeddings = torch.cat([concated_sentence_embeddings, sentence_embeddings], dim=0)\n"," print(concated_sentence_embeddings.size())\n"," saveToDisc(sentence_embeddings,output_embeddings_file_batched)\n"," total_read += count\n","print(datetime.now())"]},{"cell_type":"markdown","metadata":{"id":"1rGQc9GRCuNy"},"source":["# Test: Query embeddings"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"FT7CwpM0Bwhi"},"outputs":[],"source":["query_embedding = calculateEmbeddings(['Melyik a legnépesebb város a világon?'],tokenizer,model,device)\n","top_pairs = findTopKMostSimilar(query_embedding, concated_sentence_embeddings, 5)\n","\n","for pair in top_pairs:\n"," i = pair['index']\n"," score = pair['score']\n"," print(\"{} \\t\\t Score: {:.4f}\".format(all_sentences[skip_index+i], score))"]},{"cell_type":"markdown","metadata":{"id":"6Hdu_5FiDYJr"},"source":["# Test: Load pre-calculated embeddings\n","\n","* Load embedding from files and stitch them together\n","* Save into one file\n"]},{"cell_type":"code","execution_count":20,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":1722,"status":"ok","timestamp":1637682006152,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"gkWt0Uj_Ddsp","outputId":"1921456e-1fd6-4218-9ebb-cbe503f402b1"},"outputs":[{"name":"stdout","output_type":"stream","text":["Processing file 0\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 1\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 2\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 3\n","Read file using 100 number of read+unpickle operations\n","torch.Size([100000, 384])\n","Processing file 4\n","Read file using 67 number of read+unpickle operations\n","torch.Size([66529, 384])\n","torch.Size([466529, 384])\n"]}],"source":["def concatTensors(new_tensor, acc_tensor='None'):\n"," if acc_tensor == None:\n"," acc_tensor = new_tensor\n"," else:\n"," acc_tensor = torch.cat([acc_tensor, new_tensor], dim=0)\n"," return acc_tensor\n","\n","def loadFromDisc(batch_size, number_of_batches, filename):\n"," concated_sentence_embeddings = None\n"," count = 0\n"," batches = 0\n"," with open(filename, \"rb\") as f:\n"," loaded_embeddings = torch.empty([batch_size])\n"," while count < number_of_batches and loaded_embeddings.size()[0]==batch_size:\n"," loaded_embeddings = pickle.load(f)\n"," count += 1\n"," concated_sentence_embeddings = concatTensors(loaded_embeddings,concated_sentence_embeddings)\n"," print(f'Read file using {count} number of read+unpickle operations')\n"," print(concated_sentence_embeddings.size())\n"," return concated_sentence_embeddings\n","\n","\n","output_embeddings_file = 'data/preprocessed/DBpedia_shortened_abstracts_hu_embeddings.pkl'\n","\n","embeddings_files = [\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:17:17.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:28:46.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:40:54.pkl',\n"," 
'/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_08:56:26.pkl',\n"," '/content/drive/MyDrive/huggingface/embeddings_1000_batches_at_2021-11-23_09:31:47.pkl'\n","]\n","\n","all_embeddings = None\n","for idx,emb_file in enumerate(embeddings_files):\n"," print(f'Processing file {idx}')\n"," file_embeddings = loadFromDisc(1000, 100, emb_file)\n"," all_embeddings = concatTensors(file_embeddings,all_embeddings)\n","\n","print(all_embeddings.size())"]},{"cell_type":"code","execution_count":28,"metadata":{"executionInfo":{"elapsed":384,"status":"ok","timestamp":1637683739951,"user":{"displayName":"Sukosd Endre","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GioD4JjUyxYNK5t_w13NsB1TIlQ1P_x6Xj99-re5w=s64","userId":"02963673169135048018"},"user_tz":-60},"id":"M_8RHpNnIU7o"},"outputs":[],"source":["all_embeddings_output_file = '/content/drive/MyDrive/huggingface/shortened_abstracts_hu_2021_09_01_embedded.pt'\n","#saveToDisc(all_embeddings, all_embeddings_output_file)\n","torch.save(all_embeddings,all_embeddings_output_file)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"LYCwyDpMjsXg"},"outputs":[],"source":[]}],"metadata":{"accelerator":"GPU","colab":{"authorship_tag":"ABX9TyN3TvKBRyS+wRVSLWNFgC+f","collapsed_sections":[],"mount_file_id":"1e_NcpgIuSh8rfI_Xf16ltcybK8TbgJWB","name":"QA_retrieval_huggingface_couser_2021_Nov.ipynb","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}
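Note: the notebook's core routine mean-pools the encoder's token embeddings and processes the corpus in batches so GPU memory stays bounded. A minimal, self-contained sketch of that pattern (checkpoint name and max_length taken from the notebook; the sample sentence and printed shape are illustrative):

# Minimal sketch of the notebook's embedding step: tokenize a batch,
# run the encoder, then mean-pool token vectors using the attention mask.
import torch
from transformers import AutoTokenizer, AutoModel

checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

def mean_pooling(model_output, attention_mask):
    # Average token embeddings, ignoring padding positions.
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

def embed_batch(sentences):
    tokens = tokenizer(sentences, padding=True, truncation=True,
                       max_length=128, return_tensors='pt')
    with torch.no_grad():
        output = model(**tokens)
    return mean_pooling(output, tokens['attention_mask'])

print(embed_batch(['Budapest Magyarország fővárosa.']).shape)  # torch.Size([1, 384]) for this MiniLM checkpoint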
src/app.py
CHANGED
@@ -1,18 +1,23 @@
 import streamlit as st
 from transformers import AutoTokenizer, AutoModel
+import transformers
 import torch
 from sentence_transformers import util
 
-
+# explicit no operation hash functions defined, because raw sentences, embedding, model and tokenizer are not going to change
+
+
+@st.cache(hash_funcs={list: lambda _: None})
 def load_raw_sentences(filename):
     with open(filename) as f:
         return f.readlines()
 
-@st.cache
+@st.cache(hash_funcs={torch.Tensor: lambda _: None})
 def load_embeddings(filename):
     with open(filename) as f:
         return torch.load(filename,map_location=torch.device('cpu') )
 
+
 #Mean Pooling - Take attention mask into account for correct averaging
 def mean_pooling(model_output, attention_mask):
     token_embeddings = model_output[0] #First element of model_output contains all token embeddings
@@ -27,7 +32,7 @@ def findTopKMostSimilar(query_embedding, embeddings, all_sentences, k):
     pairs = []
     for idx,score in enumerate(cosine_scores_list):
         if idx < len(all_sentences):
-            pairs.append({'score': score, 'text': all_sentences[idx]})
+            pairs.append({'score': '{:.4f}'.format(score), 'text': all_sentences[idx]})
     pairs = sorted(pairs, key=lambda x: x['score'], reverse=True)
     return pairs[0:k]
 
@@ -38,25 +43,35 @@ def calculateEmbeddings(sentences,tokenizer,model):
     sentence_embeddings = mean_pooling(model_output, tokenized_sentences['attention_mask'])
     return sentence_embeddings
 
+# explicit no operation hash function, because model and tokenizer are not going to change
+@st.cache(hash_funcs={transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None, transformers.models.bert.modeling_bert.BertModel: lambda _: None})
+def load_model_and_tokenizer():
+    multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
+    tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
+    model = AutoModel.from_pretrained(multilingual_checkpoint)
+    print(type(tokenizer))
+    print(type(model))
+    return model, tokenizer
 
-multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
-tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
-model = AutoModel.from_pretrained(multilingual_checkpoint)
 
+model,tokenizer = load_model_and_tokenizer();
+raw_text_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01.txt'
 all_sentences = load_raw_sentences(raw_text_file)
 
-embeddings_file = 'data/
+embeddings_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01_embedded.pt'
 all_embeddings = load_embeddings(embeddings_file)
 
 
 st.header('Wikipedia absztrakt kereső')
-st.subheader('Search Wikipedia abstracts in Hungarian')
+st.subheader('Search Wikipedia abstracts in Hungarian!')
+
+st.caption('[HU] Adjon meg egy tetszőleges kifejezést és a rendszer visszaadja az 5 hozzá legjobban hasonlító Wikipedia absztraktot')
+st.caption('[EN] Input some search term and see the top-5 most similar wikipedia abstracts')
 
-if
-query_embedding = calculateEmbeddings([
+text_area_input_query = st.text_area('[HU] Beviteli mező - [EN] Query input',value='Mi Japán fővárosa?')
+
+if text_area_input_query:
+    query_embedding = calculateEmbeddings([text_area_input_query],tokenizer,model)
     top_pairs = findTopKMostSimilar(query_embedding, all_embeddings, all_sentences, 5)
-st.json(top_pairs)
+    st.json(top_pairs)
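Note: the main change above is that the model and tokenizer are now built inside a cached function, with no-op hash functions so Streamlit's legacy st.cache does not try to hash the heavy objects involved. A minimal sketch of that caching pattern, with a hypothetical HeavyModel standing in for the real model (newer Streamlit replaces st.cache with st.cache_resource):

# Sketch of the legacy st.cache + hash_funcs pattern used in the diff.
# HeavyModel is a hypothetical stand-in for the real model object.
import streamlit as st

class HeavyModel:
    def __init__(self):
        self.weights = [0.0] * 1_000_000  # pretend these are model weights

# A no-op hash function tells Streamlit to skip hashing objects of this type,
# so the cached return value is reused without an expensive (or failing) hash.
@st.cache(hash_funcs={HeavyModel: lambda _: None})
def load_model():
    return HeavyModel()  # runs once; later calls return the cached object

model = load_model()
st.write('model loaded')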
src/data/dbpedia_dump_embeddings.py
CHANGED
@@ -31,8 +31,8 @@ dt = datetime.now()
 datetime_formatted = dt.strftime('%Y-%m-%d_%H:%M:%S')
 batch_size = 1000
 
-input_text_file = 'data/
-output_embeddings_file = f'data/
+input_text_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01.txt'
+output_embeddings_file = f'data/preprocessed/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'
 
 multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
 tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
src/data/dbpedia_dump_wiki_text.py
CHANGED
@@ -2,14 +2,14 @@ from rdflib import Graph
 
 # Downloaded from https://databus.dbpedia.org/dbpedia/text/short-abstracts
 raw_data_path = 'data/raw/short-abstracts_lang=hu.ttl'
-
+preprocessed_data_path = 'data/preprocessed/shortened_abstracts_hu_2021_09_01.txt'
 
 g = Graph()
 g.parse(raw_data_path, format='turtle')
 
 i = 0
 objects = []
-with open(
+with open(preprocessed_data_path, 'w') as f:
     print(len(g))
     for subject, predicate, object in g:
         objects.append(object.replace(' +/-','').replace('\n',' '))
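Note: with the output path now defined, the script's flow is: parse the DBpedia short-abstracts Turtle dump with rdflib, then write one cleaned abstract per line. A condensed sketch of that flow (paths taken from the diff; the counter and any error handling are omitted):

# Sketch: dump abstract literals from the Turtle file to a plain-text file,
# one abstract per line, mirroring dbpedia_dump_wiki_text.py.
from rdflib import Graph

raw_data_path = 'data/raw/short-abstracts_lang=hu.ttl'
preprocessed_data_path = 'data/preprocessed/shortened_abstracts_hu_2021_09_01.txt'

g = Graph()
g.parse(raw_data_path, format='turtle')  # loads every (subject, predicate, object) triple

with open(preprocessed_data_path, 'w') as f:
    for subject, predicate, obj in g:
        # The object literal holds the abstract text; strip artifacts and newlines.
        f.write(str(obj).replace(' +/-', '').replace('\n', ' ') + '\n')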
src/exploration/serialize_test.py
CHANGED
@@ -25,6 +25,6 @@ def loadFromDiskRaw(batch_number, filename='embeddings.pkl'):
         count += 1
     return stored_data
 
-output_embeddings_file = 'data/
+output_embeddings_file = 'data/preprocessed/DBpedia_shortened_abstracts_hu_embeddings.pkl'
 loadFromDiskRaw(3, output_embeddings_file)
 
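Note: serialize_test.py exercises the on-disk format used throughout the repo: each embedding batch is pickled and appended to a single file opened in 'ab' mode, and reading it back means calling pickle.load repeatedly on the same handle until EOF. A small sketch of that round trip (file name and tensor sizes are illustrative):

# Sketch: append several pickled batches to one file, then read them all back.
import pickle
import torch

path = 'embeddings_batches.pkl'  # illustrative file name

def save_batch(embeddings, path):
    with open(path, 'ab') as f:  # append, so one file accumulates many batches
        pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)

def load_all_batches(path):
    batches = []
    with open(path, 'rb') as f:
        while True:
            try:
                batches.append(pickle.load(f))  # one pickled tensor per call
            except EOFError:
                break
    return torch.cat(batches, dim=0)

save_batch(torch.randn(3, 384), path)
save_batch(torch.randn(2, 384), path)
print(load_all_batches(path).shape)  # torch.Size([5, 384])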
src/features/semantic_retreiver.py
CHANGED
@@ -16,7 +16,7 @@ def mean_pooling(model_output, attention_mask):
 dt = datetime.now()
 datetime_formatted = dt.strftime('%Y-%m-%d_%H:%M:%S')
 batch_size = 1000
-output_embeddings_file = f'data/
+output_embeddings_file = f'data/preprocessed/embeddings_{batch_size}_batches_at_{datetime_formatted}.pkl'
 def saveToDisc(embeddings):
     with open(output_embeddings_file, "ab") as f:
         pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
@@ -75,7 +75,7 @@ multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-
 tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
 model = AutoModel.from_pretrained(multilingual_checkpoint)
 
-raw_text_file = 'data/
+raw_text_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01.txt'
 
 
 concated_sentence_embeddings = None
src/main_qa.py
CHANGED
@@ -34,8 +34,8 @@ multilingual_checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-
 tokenizer = AutoTokenizer.from_pretrained(multilingual_checkpoint)
 model = AutoModel.from_pretrained(multilingual_checkpoint)
 
-raw_text_file = 'data/
-embeddings_file = 'data/
+raw_text_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01.txt'
+embeddings_file = 'data/preprocessed/shortened_abstracts_hu_2021_09_01_embedded.pt'
 
 all_sentences = load_raw_sentences(raw_text_file)
 all_embeddings = torch.load(embeddings_file,map_location=torch.device('cpu') )
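Note: with both paths now pointing at data/preprocessed, main_qa.py can load the corpus embeddings and answer a query by cosine similarity. A minimal sketch of that retrieval step (toy tensors and sentences stand in for the real files; the helper name is illustrative):

# Sketch: score a query embedding against precomputed corpus embeddings and
# return the top-k matching sentences, mirroring the retrieval in main_qa.py.
import torch
from sentence_transformers import util

def find_top_k(query_embedding, corpus_embeddings, sentences, k=5):
    scores = util.pytorch_cos_sim(query_embedding, corpus_embeddings).squeeze(0)
    top = torch.topk(scores, k=min(k, scores.numel()))
    return [{'score': round(float(s), 4), 'text': sentences[int(i)]}
            for s, i in zip(top.values, top.indices)]

# Toy stand-ins for the loaded corpus and an already-embedded query.
corpus_embeddings = torch.randn(4, 384)
sentences = ['mondat 0', 'mondat 1', 'mondat 2', 'mondat 3']
query_embedding = torch.randn(1, 384)
print(find_top_k(query_embedding, corpus_embeddings, sentences, k=2))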