{ "cells": [ { "cell_type": "code", "execution_count": 3, "metadata": { "id": "zq-6s7LbPnKH" }, "outputs": [], "source": [ "import nest_asyncio\n", "\n", "nest_asyncio.apply()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ulZIBA1ZoSsV", "outputId": "afcdef9b-9586-4650-f828-3f794bd185d6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/292.8 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m292.8/292.8 kB\u001b[0m \u001b[31m9.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h" ] } ], "source": [ "!pip install -qU langchain_openai langchain_huggingface langchain_core==0.2.38 langchain langchain_community langchain-text-splitters pypdf" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3GFD7B-tOCrx", "outputId": "5e701205-1d35-48f2-8217-8943a7e94921" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/981.5 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m981.5/981.5 kB\u001b[0m \u001b[31m32.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m21.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m472.8/472.8 kB\u001b[0m \u001b[31m18.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.5/1.5 MB\u001b[0m \u001b[31m49.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m27.0/27.0 MB\u001b[0m \u001b[31m36.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m159.9/159.9 kB\u001b[0m \u001b[31m12.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m431.4/431.4 kB\u001b[0m \u001b[31m23.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m274.7/274.7 kB\u001b[0m \u001b[31m15.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m46.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.3/45.3 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m82.7/82.7 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.5/54.5 kB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Building wheel for langdetect (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" ] } ], "source": [ "!pip install -qU faiss-cpu unstructured==0.15.7 python-pptx==1.0.2 nltk==3.9.1" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wA_mlurVqtrp", "outputId": "a6a166b9-5a5e-4840-d620-3243489c3fd7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Enter Your OpenAI API Key: ··········\n" ] } ], "source": [ "import os\n", "import getpass\n", "\n", "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter Your OpenAI API Key: \")" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "id": "qzAdFsClBJ6v" }, "outputs": [], "source": [ "# Relative paths: both PDFs must be present in the working directory.\n", "pdf_paths = [\n", "    \"AI_Risk_Management_Framework.pdf\",\n", "    \"Blueprint-for-an-AI-Bill-of-Rights.pdf\",\n", "]" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "id": "DRi4ZSJLBY9T" }, "outputs": [], "source": [ "from langchain_community.document_loaders import PyPDFLoader\n", "\n", "# Concatenate the Documents loaded from every PDF, in pdf_paths order.\n", "pdf_documents = []\n", "for pdf_path in pdf_paths:\n", "    # Not named `loader`: that name is rebound later to the training data loader.\n", "    pdf_loader = PyPDFLoader(pdf_path)\n", "    pdf_documents.extend(pdf_loader.load())" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "id": "NsPrOOqXOsNX" }, "outputs": [], "source": [ "from langchain_text_splitters import RecursiveCharacterTextSplitter\n", "\n", "# length_function=len measures chunk_size and chunk_overlap in characters.\n", "text_splitter = RecursiveCharacterTextSplitter(\n", "    chunk_size=750,\n", "    chunk_overlap=20,\n", "    length_function=len,\n", ")" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": { "id": "OMYPX6N6Os8M" }, "outputs": [], "source": [ "training_documents = text_splitter.split_documents(pdf_documents)" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PAozuMoNOvnp", "outputId": "57d60a5a-e6d6-494c-8237-546db0a27210" }, "outputs": [ { "data": { "text/plain": [ "640" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(training_documents)" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": { "id": "AwyIForybIpo" }, "outputs": [], "source": [ "import uuid\n", "\n", "# Tag every chunk with a unique string id; the datasets built below key the\n", "# corpus and each question's relevant context by this id.\n", "id_set = set()\n", "\n", "for document in training_documents:\n", "    doc_id = str(uuid.uuid4())\n", "    # Retry on a collision, keeping the id a str so it stays JSON-serializable\n", "    # and comparable with every other id (the old retry path stored a UUID object).\n", "    while doc_id in id_set:\n", "        doc_id = str(uuid.uuid4())\n", "    id_set.add(doc_id)\n", "    document.metadata[\"id\"] = doc_id" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "id": "MTS4GTSEcnG4" }, "outputs": [], "source": [ "# Contiguous, unshuffled slices of the chunk list; chunks from index 400 on are unused.\n", "training_split_documents = training_documents[:300]\n", "val_split_documents = training_documents[300:350]\n", "test_split_documents = training_documents[350:400]" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": { "id": "_EWfmIscMrvg" }, "outputs": [], "source": [ "from langchain_openai import ChatOpenAI\n", "\n", "qa_chat_model = ChatOpenAI(\n", "    model=\"gpt-4o-mini\",\n", "    temperature=0\n", ")" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": { "id": "diEWcw00NMSj" }, "outputs": [], "source": [ "from langchain_core.prompts import ChatPromptTemplate\n", "\n", "qa_prompt = \"\"\"\\\n", "Given the following context, you must generate questions based on only the provided context.\n", "\n", "You are to generate {n_questions} questions which should be provided in the following format:\n", "\n", "1. QUESTION #1\n", "2. QUESTION #2\n", "...\n", "\n", "Context:\n", "{context}\n", "\"\"\"\n", "\n", "qa_prompt_template = ChatPromptTemplate.from_template(qa_prompt)" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "id": "ggl9SSjiNbpG" }, "outputs": [], "source": [ "question_generation_chain = qa_prompt_template | qa_chat_model" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": { "id": "pNmybW-zZyI_" }, "outputs": [], "source": [ "response = question_generation_chain.invoke({\"context\": \"regulation\", \"n_questions\":2})" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qmzfIaeeaL5E", "outputId": "09853ab7-ee69-4407-dfc7-efc0d3830c0d" }, "outputs": [ { "data": { "text/plain": [ "AIMessage(content='1. What is the purpose of regulation in various industries? \\n2. How do regulations impact businesses and consumers?', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 23, 'prompt_tokens': 57, 'total_tokens': 80, 'completion_tokens_details': {'reasoning_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_71fab63729', 'finish_reason': 'stop', 'logprobs': None}, id='run-42261cd9-dae6-4062-ba45-7c01a26ba144-0', usage_metadata={'input_tokens': 57, 'output_tokens': 23, 'total_tokens': 80})" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "response" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": { "id": "0jd6Cyp7bdyf" }, "outputs": [], "source": [ "import re\n", "import tqdm\n", "\n", "# A numbered list item such as \"1. What is ...?\" or \"2) How ...?\". Group 1 is the\n", "# question text (starting at its first non-space character), so periods inside the\n", "# question itself (\"U.S.\", \"e.g.\") are kept.\n", "NUMBERED_QUESTION_PATTERN = re.compile(r\"^\\s*\\d+[.)]\\s*(\\S.*?)\\s*$\")\n", "\n", "\n", "def create_questions(documents, n_questions):\n", "    \"\"\"Generate synthetic retrieval questions for each document.\n", "\n", "    Every document's page_content is sent through question_generation_chain, and each\n", "    numbered line of the reply becomes one question. Lines that are not numbered items\n", "    (blank lines, preambles) are skipped instead of turning into empty questions.\n", "\n", "    Returns:\n", "        questions: dict mapping a fresh uuid4 string to the question text.\n", "        relevant_docs: dict mapping the same uuid4 string to a one-element list\n", "            holding the source document's metadata[\"id\"].\n", "    \"\"\"\n", "    questions = {}\n", "    relevant_docs = {}\n", "    for document in tqdm.tqdm(documents):\n", "        questions_generated = question_generation_chain.invoke(\n", "            {\"context\": document.page_content, \"n_questions\": n_questions}\n", "        )\n", "        for line in questions_generated.content.split(\"\\n\"):\n", "            match = NUMBERED_QUESTION_PATTERN.match(line)\n", "            if match is None:\n", "                continue\n", "            question_id = str(uuid.uuid4())\n", "            questions[question_id] = match.group(1)\n", "            relevant_docs[question_id] = [document.metadata[\"id\"]]\n", "    return questions, relevant_docs" ] },
{ "cell_type": "markdown", "metadata": { "id": "_FSTG0bb7w73" }, "source": [ "We'll use the function to generate training, validation, and test data with `n_questions=2` for each." ] },
{ "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "85Dq6KRqEs0F", "outputId": "5bd88461-82e9-4845-833b-b4b497672ba9" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 300/300 [05:40<00:00, 1.13s/it]\n" ] } ], "source": [ "training_questions, training_relevant_contexts = create_questions(training_split_documents, n_questions=2)" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "eIZm4CqGVzBx", "outputId": "2b30c397-21a5-422a-aebe-0f0696abc0dd" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:44<00:00, 1.11it/s]\n" ] } ], "source": [ "val_questions, val_relevant_contexts = create_questions(val_split_documents, n_questions=2)" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "o6qUHg9sV2_y", "outputId": "c6e4560f-c718-4e1a-ebae-2ef8f27c781e" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 50/50 [00:49<00:00, 1.01it/s]\n" ] } ], "source": [ "test_questions, test_relevant_contexts = create_questions(test_split_documents, n_questions=2)" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": { "id": "iF6IFFq9VsNu" }, "outputs": [], "source": [ "import json\n", "\n", "training_corpus = {train_item.metadata[\"id\"] : train_item.page_content for train_item in training_split_documents}\n", "\n", "train_dataset = {\n", "    \"questions\" : training_questions,\n", "    \"relevant_contexts\" : training_relevant_contexts,\n", "    \"corpus\" : training_corpus\n", "}\n", "\n", "# NOTE: despite the .jsonl extension this writes ONE JSON object (not JSON Lines);\n", "# read it back with json.load.\n", "with open(\"training_dataset.jsonl\", \"w\") as f:\n", "    json.dump(train_dataset, f)" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": { "id": "PqF9WaueV-V8" }, "outputs": [], "source": [ "val_corpus = {val_item.metadata[\"id\"] : val_item.page_content for val_item in val_split_documents}\n", "\n", "val_dataset = {\n", "    \"questions\" : val_questions,\n", "    \"relevant_contexts\" : val_relevant_contexts,\n", "    \"corpus\" : val_corpus\n", "}\n", "\n", "# Single JSON object despite the .jsonl extension (see the training cell).\n", "with open(\"val_dataset.jsonl\", \"w\") as f:\n", "    json.dump(val_dataset, f)" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": { "id": "0DSQ7WMnWAu6" }, "outputs": [], "source": [ "test_corpus = {test_item.metadata[\"id\"] : test_item.page_content for test_item in test_split_documents}\n", "\n", "test_dataset = {\n", "    \"questions\" : test_questions,\n", "    \"relevant_contexts\" : test_relevant_contexts,\n", "    \"corpus\" : test_corpus\n", "}\n", "\n", "# Single JSON object despite the .jsonl extension (see the training cell).\n", "with open(\"test_dataset.jsonl\", \"w\") as f:\n", "    json.dump(test_dataset, f)" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AXzVHP3v1Cno", "outputId": "8c3db6f0-023d-4a49-a231-eb3417b273a6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found existing installation: pyarrow 14.0.2\n", "Uninstalling pyarrow-14.0.2:\n", " Successfully uninstalled pyarrow-14.0.2\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m38.0/38.0 MB\u001b[0m \u001b[31m27.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m542.1/542.1 kB\u001b[0m \u001b[31m35.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 
kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m172.0/172.0 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m11.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m16.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "gcsfs 2024.6.1 requires fsspec==2024.6.1, but you have fsspec 2024.3.1 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0m" ] } ], "source": [ "!pip uninstall -y pyarrow\n", "!pip install -qU sentence_transformers datasets pyarrow==14.0.1" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 528, "referenced_widgets": [ "5ee51a7f95ff4eb0bb53def4627707f8", "9caa59b526134f098ad81e4d9a4c2d72", "3fb7a72345824c92a255b6c8117133cd", "36a1aba9ee4942c08751243d2fa0941f", "a7ca014629b54781affdccfee1fcafe3", "f5fd0c1004e94947975bf0ae9527b9b2", "2af315c0c4e94e08abecffdab7113181", "b21a6d1e58c741ca99bdb6a506ca1096", "dcba4a2f141243b895eaf2fa42992822", "e11bc84360cd428ba140cb23524e235a", "6c384aa04ff0414d8692eed19651c69e", "43c2e9e387c5465eaa6c69d657121a2c", "3a0a648056bf4c0db8933596d19eb4e3", "c8cc98f990954a838549ae7ddd016882", "af643ef6d57448eea3ea5341962a340c", "5b193c84b3944628b02235d351f75732", "c504196e646443f78679bc2da14fb62a", "fda21989532c4ead8828bc959a0795cb", "c900f4ff02264e7daa93dad48df3a138", "09e4750ad3f349c683083b335c1cc064", "29c9016598284bd59036c07450374db0", "ed8f2c44dbed43c794009f474fa72276", 
"ca535b2a4a654518a0a761b90fbf852a", "e3f432ea38854214b1ae6a1e5643a4c3", "979b861f9f314600b01955caf9614b38", "aa478586bb0549fb9a1e4ecaca0121e1", "829171391b9747f2bfc08e6bdb570af1", "3a7c4cca21914e31a916677d0a44fe63", "835cb3542ce14cda969f3245fd5d83dd", "033ad54b4a424670942aca473617b3df", "07c749942a82404e9747ddd286e8dc82", "57fe7191fef84bf7aac0c6d5e241a371", "2ed30ea3f634445286f1aa468c32d04c", "9b224bcb3144401d8579e32a719d21b3", "1ea0c07c270e4f29be18165ab0693b6a", "fe67fa83c5cb423a8d3418fc6bb09ccc", "fc25bdce853449249de1e005d7eb2d93", "7c1b828e49dd4634bf9ae7e681b1ac9c", "859a79c69c904f8fa4186970f3ae8936", "9da1509a9f2048e7bdd5136d00b68b02", "5fb3fafdd23343bda02b972a94a781f3", "9209beaf41cf46169f885322c87d5fde", "0139907e775a4d13a0c513cca059b367", "187b82d71e5843a888c5b65d360fc761", "dd79fa96509f4dbda89ce93af0f6aa19", "aff59512ba9d48d39eb5dd6f4f199f43", "e6ed0da85f404035bd45d757c94936de", "257d3d9c241f45f9aac144b2843b66e9", "8427f07a10f148daa4b60fb0fa62dad6", "d099e1a0d33c4adca923ee63c39943c5", "4c7d34d9689e482c831ba4f876f46d8f", "1812e139b5da4a1594407c5a5d220511", "0e5dfdb66ef541c8b888c40a13274132", "addacffe1286488794e27ace3992eda0", "ca68d65a8eb74f46b16aef73f7bbd70b", "e3c5d1a3c2c54f628bbf8882a0e836b3", "99cb6e3325cf4fd6a3c3b79b8df5e02d", "4b0aa6eea9fe4d5290c5bcf4865fae1c", "38bdb9a05e3d4bd88433c894a4565ca8", "95b597483dcc477d823b7e3eeb33abc8", "d6a2a4404bae4bb28d5e9d3e23fe0dc0", "2c0f55a1114d47eeb5ae87b1eca538a3", "e7b12db2684145edafd883defd01e3c5", "c94f3197aee54606a370918fb0dccb20", "b35b095d663e4fb29df536e3bee810ce", "dd415fd7ff074c269a9bbe0f20b15a44", "20e214d0cb654e76a0ce00bffc5a0fa8", "677bcb29dbdc411f9f8ce40a7ef48f7a", "393457c3ea90487d9cc4ba14aaa2ee3d", "d0e25a0bc2c94ea2b9b6e3c51d67e606", "bd58d37efc3d489c86e36c61d241e830", "f2cd67bcd8314299a9ce97a81bc5139e", "088998bf77fc4b53a1d44288096edc11", "60c95e8ab575426db1d1d83937f41205", "abf9e221b3794c3091d41654ddc83960", "05ebf6d03fcd4e1ca03b0b6ae8234f0b", "e11ecb04aa9e4cbf8f3f003abad4ae74", 
"3aa767354363441185173530a8c4b7c8", "cbee7826a9ae4039a4034aae820c04af", "e409f596b690457e9b136b2d22353e9a", "c3dc4c746ec9409d92edd98dfa5402d0", "72e991a930f4432794b051bf2ab95d4c", "72e4a0419b494b0fb4fbc51a280e6551", "937b004c7199467f8db52278f4723626", "66b2a58efcea460db56000789366a337", "a9af1ddab95a46e7a0d11c2a6c62cc52", "aa83650d963742c18d3ab1ec9f23ea6c", "23a67623f01d4ef385568904c5309ff2", "46eb742db2614250a29f4af09e41dddf", "de795472b45b49bbabc8df63a3f45434", "e0b5a6562d30417e88ef50c620cc9a87", "4334f324ee2540aea54fe7dbed4bad3d", "0de0be5b9e6f44bc842c78c4f66fef61", "121bfd2ce75e4195a40cebf92760ceeb", "db93056822fb49dfae71514f111e0e8e", "3bb761b0815d43a5a329075f76ce0953", "a552093981d54dd78c89dec7f1d32420", "af11755be3ad4847a04c816b64108057", "9c1c65cdb9aa4c5db8cbe9c98413d6b5", "bc5e3faa57f8459a9d1ba1c9428b8b83", "d88c31340381409b93e0bfb71f8396dd", "90ad0cee99c446f1ae8bf162dcfd3e36", "50ded92bc15f4f05b4df3015cf88c9fd", "de5e4152298b4af3b242756a94c74e2b", "2f8bacbaee784eaf8e43364f24c4fbb1", "f7fabf189fda485089321e6ee18a4946", "88d7e3b8a7f44c879c4a43d38d65ecdf", "6d5ce8f1122a44b1a77a7dfa8b1db379", "9a56095a012645ac927e0b5f8cb844a8", "db309f161fc54a558c9b9d8e9f126289", "b6a6dd121f61444bbc157f6ae3b7cf43", "dbddc36f17684eedbd5b51d64fba7f36", "9dd66a2d59a443509e2fb3139e4615b8", "a5fba065706549779b9094720acd8f07", "7cca295918474653a381765d15466091", "2e2701834356453692e625e597625c20", "b41f335ad5ef4c7c96f3ee598b5e69b4", "4c52b2b6ad9e47c3a3dfb52a19f435a8", "96baf773844e40cb8448685791b56b5c", "c6d83ab139db433892fce27459221a7e", "2719ff89262f42e69cc4d6d1b664542f" ] }, "id": "G-PGsQB7Xo6V", "outputId": "539bd9ea-e92f-41da-95b8-5b67251745cc" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/sentence_transformers/cross_encoder/CrossEncoder.py:13: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. 
in jupyter console)\n", " from tqdm.autonotebook import tqdm, trange\n", "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5ee51a7f95ff4eb0bb53def4627707f8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "modules.json: 0%| | 0.00/349 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "43c2e9e387c5465eaa6c69d657121a2c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config_sentence_transformers.json: 0%| | 0.00/252 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ca535b2a4a654518a0a761b90fbf852a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "README.md: 0%| | 0.00/84.6k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9b224bcb3144401d8579e32a719d21b3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "sentence_bert_config.json: 0%| | 0.00/107 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dd79fa96509f4dbda89ce93af0f6aa19", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/738 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "e3c5d1a3c2c54f628bbf8882a0e836b3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/436M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "20e214d0cb654e76a0ce00bffc5a0fa8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/1.38k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3aa767354363441185173530a8c4b7c8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "vocab.txt: 0%| | 0.00/232k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "46eb742db2614250a29f4af09e41dddf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer.json: 0%| | 0.00/712k [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bc5e3faa57f8459a9d1ba1c9428b8b83", "version_major": 2, "version_minor": 0 }, "text/plain": [ "special_tokens_map.json: 0%| | 0.00/695 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b6a6dd121f61444bbc157f6ae3b7cf43", "version_major": 2, "version_minor": 0 }, "text/plain": [ "1_Pooling/config.json: 0%| | 0.00/296 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sentence_transformers import SentenceTransformer\n", "\n", "model_id = \"Snowflake/snowflake-arctic-embed-m\"\n", "model = SentenceTransformer(model_id)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "id": "B-WbpuUWYFJr" }, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "from torch.utils.data import Dataset\n", "from 
sentence_transformers import InputExample" ] },
{ "cell_type": "code", "execution_count": 30, "metadata": { "id": "8Lokhy6KYHAv" }, "outputs": [], "source": [ "BATCH_SIZE = 20" ] },
{ "cell_type": "code", "execution_count": 31, "metadata": { "id": "JJk37zQsYJ4P" }, "outputs": [], "source": [ "corpus = train_dataset['corpus']\n", "queries = train_dataset['questions']\n", "relevant_docs = train_dataset['relevant_contexts']\n", "\n", "# One (question, source chunk) positive pair per generated question.\n", "examples = []\n", "for query_id, query in queries.items():\n", "    doc_id = relevant_docs[query_id][0]\n", "    text = corpus[doc_id]\n", "    example = InputExample(texts=[query, text])\n", "    examples.append(example)" ] },
{ "cell_type": "code", "execution_count": 32, "metadata": { "id": "tiizmeIqZ_-w" }, "outputs": [], "source": [ "from sentence_transformers.datasets import NoDuplicatesDataLoader\n", "\n", "# MultipleNegativesRankingLoss scores each question against every passage in its batch\n", "# and treats all but its own as negatives. create_questions emits several questions per\n", "# chunk back to back, so a plain unshuffled DataLoader batches identical passages\n", "# together and the loss penalises the correct passage as a \"negative\".\n", "# NoDuplicatesDataLoader never places the same text twice in one batch.\n", "loader = NoDuplicatesDataLoader(examples, batch_size=BATCH_SIZE)" ] },
{ "cell_type": "code", "execution_count": 33, "metadata": { "id": "Uga4nnBqlVeh" }, "outputs": [], "source": [ "from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss\n", "\n", "matryoshka_dimensions = [768, 512, 256, 128, 64]\n", "inner_train_loss = MultipleNegativesRankingLoss(model)\n", "train_loss = MatryoshkaLoss(\n", "    model, inner_train_loss, matryoshka_dims=matryoshka_dimensions\n", ")" ] },
{ "cell_type": "code", "execution_count": 34, "metadata": { "id": "f0hAFwUyaHQG" }, "outputs": [], "source": [ "from sentence_transformers.evaluation import InformationRetrievalEvaluator\n", "\n", "corpus = val_dataset['corpus']\n", "queries = val_dataset['questions']\n", "relevant_docs = val_dataset['relevant_contexts']\n", "\n", "evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)" ] },
{ "cell_type": "code", "execution_count": 35, "metadata": { "id": "svZG0pBHiQr6" }, "outputs": [], "source": [ "EPOCHS = 5" ] },
{ "cell_type": "code", "execution_count": 36, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 332, "referenced_widgets": [ "25a7c23f2cb24e81868320b9852addca", 
"40d63e923ef04db3852b8453e441ac97", "414b51ac65e2481ca8c19f3d0f329aff", "5ddde4bebc9e47a18d151cde9129cc0f", "b6e92bd140634e4c963ca31b41dd1bd4", "5f34f8e4934440178853338c6da5c8cb", "726f85d7a9da4600969ee9ab13acfd89", "2859beb2a74f4904b25d5cbf803961c2", "165791aa719144e68b5af93e9bb59918", "b547c2615c62449b938fa8b338c12d6f", "38e483e60dde4b6396f3350e2bad72b1" ] }, "id": "aDhUHZY-iR09", "outputId": "a47c6bef-cbe9-481c-f97c-af059f4a9b72" }, "outputs": [ { "data": { "text/html": [ "\n", "
Step | \n", "Training Loss | \n", "Validation Loss | \n", "Cosine Accuracy@1 | \n", "Cosine Accuracy@3 | \n", "Cosine Accuracy@5 | \n", "Cosine Accuracy@10 | \n", "Cosine Precision@1 | \n", "Cosine Precision@3 | \n", "Cosine Precision@5 | \n", "Cosine Precision@10 | \n", "Cosine Recall@1 | \n", "Cosine Recall@3 | \n", "Cosine Recall@5 | \n", "Cosine Recall@10 | \n", "Cosine Ndcg@10 | \n", "Cosine Mrr@10 | \n", "Cosine Map@100 | \n", "Dot Accuracy@1 | \n", "Dot Accuracy@3 | \n", "Dot Accuracy@5 | \n", "Dot Accuracy@10 | \n", "Dot Precision@1 | \n", "Dot Precision@3 | \n", "Dot Precision@5 | \n", "Dot Precision@10 | \n", "Dot Recall@1 | \n", "Dot Recall@3 | \n", "Dot Recall@5 | \n", "Dot Recall@10 | \n", "Dot Ndcg@10 | \n", "Dot Mrr@10 | \n", "Dot Map@100 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
30 | \n", "No log | \n", "No log | \n", "0.780000 | \n", "0.960000 | \n", "0.980000 | \n", "1.000000 | \n", "0.780000 | \n", "0.320000 | \n", "0.196000 | \n", "0.100000 | \n", "0.780000 | \n", "0.960000 | \n", "0.980000 | \n", "1.000000 | \n", "0.904303 | \n", "0.872179 | \n", "0.872179 | \n", "0.780000 | \n", "0.960000 | \n", "0.980000 | \n", "1.000000 | \n", "0.780000 | \n", "0.320000 | \n", "0.196000 | \n", "0.100000 | \n", "0.780000 | \n", "0.960000 | \n", "0.980000 | \n", "1.000000 | \n", "0.904303 | \n", "0.872179 | \n", "0.872179 | \n", "
50 | \n", "No log | \n", "No log | \n", "0.780000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.780000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.780000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.912129 | \n", "0.881667 | \n", "0.881667 | \n", "0.780000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.780000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.780000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.912129 | \n", "0.881667 | \n", "0.881667 | \n", "
60 | \n", "No log | \n", "No log | \n", "0.790000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.790000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.790000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.915820 | \n", "0.886667 | \n", "0.886667 | \n", "0.790000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.790000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.790000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.915820 | \n", "0.886667 | \n", "0.886667 | \n", "
90 | \n", "No log | \n", "No log | \n", "0.790000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.790000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.790000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.915820 | \n", "0.886667 | \n", "0.886667 | \n", "0.790000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.790000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.790000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.915820 | \n", "0.886667 | \n", "0.886667 | \n", "
100 | \n", "No log | \n", "No log | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.800000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.919511 | \n", "0.891667 | \n", "0.891667 | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.800000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.919511 | \n", "0.891667 | \n", "0.891667 | \n", "
120 | \n", "No log | \n", "No log | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.800000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.919511 | \n", "0.891667 | \n", "0.891667 | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.800000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.919511 | \n", "0.891667 | \n", "0.891667 | \n", "
150 | \n", "No log | \n", "No log | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.800000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.918202 | \n", "0.890000 | \n", "0.890000 | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.800000 | \n", "0.330000 | \n", "0.198000 | \n", "0.100000 | \n", "0.800000 | \n", "0.990000 | \n", "0.990000 | \n", "1.000000 | \n", "0.918202 | \n", "0.890000 | \n", "0.890000 | \n", "
"
],
"text/plain": [
"
Copy a token from your Hugging Face\ntokens page and paste it below.
Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file.