{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "C2EgqEPDQ8v6" }, "source": [ "## Finetune Falcon-7b on a Google colab\n", "\n", "Welcome to this Google Colab notebook that shows how to fine-tune the recent Falcon-7b model on a single Google colab and turn it into a chatbot\n", "\n", "We will leverage PEFT library from Hugging Face ecosystem, as well as QLoRA for more memory efficient finetuning" ] }, { "cell_type": "markdown", "metadata": { "id": "i-tTvEF1RT3y" }, "source": [ "## Setup\n", "\n", "Run the cells below to setup and install the required libraries. For our experiment we will need `accelerate`, `peft`, `transformers`, `datasets` and TRL to leverage the recent [`SFTTrainer`](https://huggingface.co/docs/trl/main/en/sft_trainer). We will use `bitsandbytes` to [quantize the base model into 4bit](https://huggingface.co/blog/4bit-transformers-bitsandbytes). We will also install `einops` as it is a requirement to load Falcon models." ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 39332, "status": "ok", "timestamp": 1686730698253, "user": { "displayName": "Younes Belkada", "userId": "15414910276690549281" }, "user_tz": -120 }, "id": "mNnkgBq7Q3EU", "outputId": "d7b64bce-31fb-4884-8a95-462792b68acd", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" ] } ], "source": [ "!pip install -q -U trl transformers accelerate git+https://github.com/huggingface/peft.git\n", "!pip install -q datasets bitsandbytes einops wandb" ] }, { "cell_type": "markdown", "metadata": { "id": "Rnqmq7amRrU8" }, "source": [ "## Dataset\n", "\n", "For our experiment, we will use the Guanaco dataset, which is a clean subset of the OpenAssistant dataset adapted to train general purpose chatbots.\n", "\n", "The dataset can be found 
[here](https://huggingface.co/datasets/timdettmers/openassistant-guanaco). Note that the code cell below actually loads the `mskov/DaVinci_Completion` dataset; swap the dataset name back to `timdettmers/openassistant-guanaco` if you want to reproduce the original chatbot experiment." ] },
{ "cell_type": "code", "execution_count": 28, "metadata": { "tags": [] }, "outputs": [], "source": [ "!pip install ipywidgets\n" ] },
{ "cell_type": "code", "execution_count": 29, "metadata": { "tags": [] }, "outputs": [], "source": [ "!pip install datasets" ] },
{ "cell_type": "code", "execution_count": 30, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 232 }, "executionInfo": { "elapsed": 4258, "status": "ok", "timestamp": 1686730702504, "user": { "displayName": "Younes Belkada", "userId": "15414910276690549281" }, "user_tz": -120 }, "id": "0X3kHnskSWU4", "outputId": "c618b983-ac72-4a81-f2cd-307cda489f78", "tags": [] }, "outputs": [], "source": [ "import ipywidgets\n", "from datasets import load_dataset\n", "\n", "dataset_name = \"mskov/DaVinci_Completion\"\n", "dataset = load_dataset(dataset_name, split=\"train\")" ] }, { 
"cell_type": "markdown", "metadata": { "id": "rjOMoSbGSxx9" }, "source": [ "## Loading the model" ] }, { "cell_type": "markdown", "metadata": { "id": "AjB0WAqFSzlD" }, "source": [ "In this section we will load the [Falcon 7B model](https://huggingface.co/tiiuae/falcon-7b), quantize it in 4bit and attach LoRA adapters on it. Let's get started!" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", "Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (2.0.0)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch) (3.12.0)\n", "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.10/site-packages (from torch) (4.5.0)\n", "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch) (1.11.1)\n", "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch) (3.1)\n", "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch) (3.1.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch) (2.1.2)\n", "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.2.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" ] } ], "source": [ "!pip install torch" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "import torch\n", "\n", "print(torch.cuda.is_available())\n" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "3418452aa85a4f7187edbf27c12a96c6", "e957cfa9145944eb9edfd9a04a36d191", "0b23afe0167d4fc3ad104e502cba11ec", "5b72de866ff94d23ba50fb80830412e5", "d84e94f921034a9e98660ffea4b6b828", "f6b100f6547249dc93235fef8a977919", "8c0b0ebfa9c9467db35dd720c9dce6c4", "7f4aa988d7314fe59de34ef6efdacb39", "960633c0c7304480ad84cf224300c36f", "6a4a4d6c45a742a3b0f557b11f84c886", "c518b25983f3452d87c38684c0d46b1a", "57537b549b3a4a6a9b81ec4624cec204", "adfbdb99a779446193dd1f052451640f", "60534daa50f24b4fb2873ee56dbeef5f", "be7520e8b9c2450cae4d80acbcb4a06e", "8528b3f2042545cb8eced01edf974508", "c28d673b39d243bc937ccded9c57b69d", "b3ba9f9e56f44d15bb931ba02858c94e", "f3a82733a81e439bb2bc57a159d26c2c", "cad27935c7af456483f5b052d0ddbe66", "df56e4c6864a4e2f9819fb5862b52749", "a384f3d0ac02499898e8bf45589cedc7", "71d06f0240e4471cad9980c789b60063", "4b888ff43f8341449d619b8fbf500900", "b50395c90704422abf923fe376c20c7b", "4c9de03af92c449289094070cc71971b", "8f1c0b9a617b4da88921227cecde7e1b", "485a493d14344eb4910ba39cf5e0af10", "ec0cb58456e5455da4741d1df9d2b7af", 
"b7ce2e2bd35748e29bd77b9fa6fa5f1f", "b8e34f70d26649ca83d63aa878a8df67", "8785fc48f91f458ab26d16418bfd654d", "56b9667df3d44c00b391292c1ba82ee7", "fc43df8f987e4066835c49609a45a1e7", "be562bce884e4a58b36b4ab322ae67ce", "8f9113646dd04ddba0f2fa93e923524c", "a35f476f6b1447ed8e53736947fb2f13", "5c6f953f123b4126a933ce96f7d9490c", "25535cc4426e48a8ac4ecf90d5b67c9f", "8281c833347e42cda4d814c8aa2623bb", "f5a5428c564e4e30a8294d62d64c70e9", "0d9b3f8b40a74ca4b8ed5502f8bc9c85", "e837691cf19c46e8ad5d5fbb9f7513df", "d0b802488b5347dbbf945b047af5c4cf", "967b05f5629c473799806cb719178e02", "0438c076543f43d395c9ea25814597ce", "c9fd30031e704fce977703db5040c79a", "05ce0e5b6a7847ffab8a8dd36ab05e99", "1da70a71e2144a3fbb410bcee001e349", "2f2a69400014448993afb6b52a9d740c", "bfd38da2907f42a5b6f98f2943afb3a0", "ee9350b0f26648139e1cb5d3071750d4", "fe07baa3846d4d839b33200bfa89afaa", "3193ea424dd645a48277e6578b67d37b", "4df0504c5ff24a519d49835d3ec34b27", "da473b604ef14b6d967a010b6cf6bb3f", "4db8828ddb2443968d0101f54e8a0bdd", "42803243f0bc444a9ae4ec961be3f8ef", "f563794e2ed249e5964f38c7e892c40e", "a3ff1082bf6e484281d3f675f9a020f7", "d1efdaf61b464fd9a8f4c743a4396f6d", "3c6488e4e0164ef0a11447f3c9f3452b", "16608ca8206045ecac81977c415b393d", "2feaa840e74b4489a1d48f568a61c9bc", "487cd20587914c6ab11cd9a2fde4ef32", "642ad7d1c5ec4b97b53870e5b646098d", "2e953e73b6a149e1b19be8d8a18da0bf", "77b4f5fa2242419fa5a2878306058642", "ef21d6a710754b129d315dc85761a3d2", "3b62d55d0b5543fcaeaba22fa0e069b3", "fe15178e946d4b54bce9baefe3b4228e", "403f0b742ed34e24af077c28f8e5ce18", "dd68ae21e3b84ce99cd1aa154146cab1", "0ba4719ac0754c96accd56ed3376d327", "801b4e185add4a18a533dd3e30ab045c", "a41f8b79c7e84d53bdae11ee5fd924f1", "045af64992e34d52b4047cfc1eaf64fd", "a5aac7c5c2b0433cbcab74e9ad1a2eb3", "8e0da4c982ce450ea223e9d14fdd28a9", "73932fb9f2544e95aee7b261a649ef9e", "014dce75b5fc4c04aac4f51fe4469f18", "77ffc27a4eb847e0bff2ee23c7bff2d2", "a884c7d0c1754f69b28dbd5732270990", "a65d86a8c4404060baf6fd4d1da7322f", 
"4ee83919b1bf4213960a35c10340e5fc", "3bc0ab69c5ba4acdb91b21412464721d", "70f983365c7c41dfbf21635cdf03106f", "03b92f1092504ed2b4bbebf3958a260a", "f7e2680e102c487fa46b01ed825dcaff", "5ebf3e11cf504846a17a01fb6c7eb624", "a9d2bd50dba148ea9dd35b29b7315855", "f345c795c4624816ba2d03f3328d0003", "2bee2d465022450eb456fe069ea9f08e", "466c9a18605944e595db74bfac65bc2c", "85c56cbbbebe461e9f87f4b9551ffc73", "78a640ec6aa842a6a606e2cc8b8679ed", "3d72cd13fd9b46cd89cbf963630e1015", "8aaee88d8f9242ae947936ef12562b66", "b1c46e9bb9c34c5ead78ca554a804b92", "96586235cd194ef2b4b46230188f4800", "a4a334d17df34347995a39da233fa025", "99d9f791c9b54df9bbab42b7b684c690", "0def29d2a2a246458970fee883829164", "d60e918a4c0a4c4d810024499a8ef0ca", "60884069cf9c4935a03e66bd8eb49f34", "d9ba0cdd96af4415badaa272b59038cb", "b59263c94a404876a71445925dbbee92", "8b4028166c134642a551f54c85958c30", "c34ecf136caa4ee7947552b7c8835271", "f281ace75cad4897a02054b6464a3319", "f0561d9a4304473b9726ba5c3bb46552", "438712584e62414d842a67c70bccf915", "06215fa4377242a68f908111a5815315", "5c6f93da94444139942cad714d15e7d5", "2c9fb31b892b480a885d6f89da2decc3", "85dba2a783804b058f09157fdeba8e05", "d95f7c9c9280442b8a49c7c3598cd4db", "648fc1a69d0f4e5c8d4955561075f5a7", "22dccd166467407b8fc87a94745945ab", "f0e3c6182fa144a29fa3e808da9cc771", "fc712dfc919c407ea4e0221e7765da9e", "0ad3637e4d5e415587bd7d5a345ee67b", "5d7f4509c5884f84a251b809221bfa82", "8bfe1f3002e74a1495ca0037756c22a6", "9a58633d24824077ae13a7e761d4010d", "bd7d0817b8e1485bb3a5b4c2353aa032", "d68b37d76e824448b9f1bf8efff81878", "6fb1c6ff571447b899cdc66991dd414b", "8864b96d1938490687328efb2e721961", "2f0aba0c62c940a39f180bd0a140dcb8", "1b8b05bbd9e64dd2bb7973243c4b0f8f", "52d65646f12a415795559d6534514b71", "5bf0d0a5d4994492bbb5629bec717202", "27fb87c251c8499381f0ed1c9d53c8dc", "f6820c69cb644197a377f8cc125c059b", "56b4924f533042afb2c4bc1cdf844a93", "54abc387eeb943fb9a6d2facfbc962f8", "a492a2523a414f9585bd65d3823f973b", "d913239cc4d048c489d9dcdb498bf25e", 
"03f8f085622646b4a0e7feb32af6198a", "695d9c3d1bcc4c63a2240f689ab44d41", "afa2787cd84c44dd8e2389b4d3f3075e", "1193cb0363884a9d8ed1ce70c1606e55", "5ffa25e4ca2742bba5b6c1b77a0d9b66", "6eb4dced23b948259726b6dc32d2d852", "c940e62308d24de891380df738711bc8", "3954141106b746328e358364981d9176", "196e7d6c8fe443d7b5fdc07208f4127b", "e905245faacc421c927e7885184cf564", "2ee44cafce984d7bb5a2a6f7ae79004a", "b7ef97a499bf4dad8f3d752a6ce668a0", "7fab2eeea95a41bdace4f67a03a6fab4", "97471ca5d46d4063b050e04474e34060", "ee82ee77d15b4094aeb89ed009193d67", "971b97c5cf5649498777b2b456941233", "8ee69156ca06492daebeff348defc20f", "2730746569f74997a4ce8c4f8b7bdbd0", "f27a6664d1eb485d929e1524be8d03b9", "8d12bbc64ca748a68907390e3b14ba7b", "c2312873da444974871c15cf3213408d", "69cd22eeecfa4d3bac293b34117683a1", "160af337a5d8473ebb8b53d8b8f391a9", "14d135a834de4218b288eda1b46c9568", "5dc95867535a47b0be04f223ad795ae5", "02e2a5844c434beea2ad186a3d89b50b" ] }, "executionInfo": { "elapsed": 189743, "status": "ok", "timestamp": 1686730892243, "user": { "displayName": "Younes Belkada", "userId": "15414910276690549281" }, "user_tz": -120 }, "id": "ZwXZbQ2dSwzI", "outputId": "b3dd0dbb-b991-47d0-b0ef-b447d22fa747", "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7383a3550baf4c5f8bca0eb954e207a8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/8 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import torch\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n", "\n", "model_name = \"ybelkada/falcon-7b-sharded-bf16\"\n", "\n", "bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_compute_dtype=torch.float16,\n", ")\n", "\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_name,\n", " quantization_config=bnb_config,\n", " trust_remote_code=True\n", ")\n", 
"model.config.use_cache = False" ] }, { "cell_type": "markdown", "metadata": { "id": "xNqIYtQcUBSm" }, "source": [ "Let's also load the tokenizer below" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 113, "referenced_widgets": [ "fe1e4ce03e374711a749653e0168edb5", "4c04d9b6581b468697612fb2cbb4355e", "b8690b5bb29c4ca4a4c564b832261cbf", "6d9299498ae140b8a83856b85e97e269", "066d24010aca4fc394b9219affe7d788", "8689fc8a7ac648cbacb6cdb978114019", "9a8c8c5c77af4c359f6f6c6fce2eea85", "e19a50adcd304516bd79178aff0eed36", "1894eddc0bb14418942e77060dca4f0d", "0fee32efd3bc4b138907cbef79650732", "df11b17c98bf4549886a8f48a9df544b", "9b6ad5d4c2654963817ecbc0e1ce2f4d", "4f30ceb6f1e5489fa49832d6a8ed9618", "89c94e9f285542a1b6b45e928d423f4f", "ac7191cda5944ad3bdfcfbaaee4a3910", "f80d13105aed409d88b5dce57db3c6c8", "ef7dedba07754ec098cd66bbbe6cc3dd", "e7a11f7df2ec4c36be9ca3994c800e86", "41e3b849779343c08bb79decdc10d892", "c4a1effabe5e4099847313a5a4fb4569", "a37fca09f5ba409c87ca73df24ae98ed", "1f3bc9531f3045d59547dbbfb95ac1b0", "23cd5a17e94949a39d7aace3bf1bf4b8", "e551682052c64cc5ad66f878e69b7f8e", "1468686f771a4c3fa921cbcf6f015f88", "ff7e93de7212472282435079790d0953", "8f6cd746232f4069a7aa6668325d487d", "fdf12d3caaf147d382e5f305fdfb760b", "a6ba17bc18d84777bc6e8e155ebe93ed", "2c9ee56b64224604be07f0cfcb784d64", "5bb42ac8c10d4e8197214aa3876e3ce6", "8597cf5e0f9a4fe8a8672165b112e27b", "d2e27c2dfea44de39a3130a19aa6a439" ] }, "executionInfo": { "elapsed": 1888, "status": "ok", "timestamp": 1686730894128, "user": { "displayName": "Younes Belkada", "userId": "15414910276690549281" }, "user_tz": -120 }, "id": "XDS2yYmlUAD6", "outputId": "1f5c4180-1881-4484-f9e9-f289dd7c3bc3", "tags": [] }, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n", "tokenizer.pad_token = tokenizer.eos_token" ] }, { "cell_type": "markdown", "metadata": { "id": "NuAx3zBeUL1q" }, "source": 
[ "Below we will load the configuration file in order to create the LoRA model. According to QLoRA paper, it is important to consider all linear layers in the transformer block for maximum performance. Therefore we will add `dense`, `dense_h_to_4_h` and `dense_4h_to_h` layers in the target modules in addition to the mixed query key value layer." ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "dQdvjTYTT1vQ", "tags": [] }, "outputs": [], "source": [ "from peft import LoraConfig\n", "\n", "lora_alpha = 16\n", "lora_dropout = 0.1\n", "lora_r = 64\n", "\n", "peft_config = LoraConfig(\n", " lora_alpha=lora_alpha,\n", " lora_dropout=lora_dropout,\n", " r=lora_r,\n", " bias=\"none\",\n", " task_type=\"CAUSAL_LM\",\n", " target_modules=[\n", " \"query_key_value\",\n", " \"dense\",\n", " \"dense_h_to_4h\",\n", " \"dense_4h_to_h\",\n", " ]\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "dzsYHLwIZoLm" }, "source": [ "## Loading the trainer" ] }, { "cell_type": "markdown", "metadata": { "id": "aTBJVE4PaJwK" }, "source": [ "Here we will use the [`SFTTrainer` from TRL library](https://huggingface.co/docs/trl/main/en/sft_trainer) that gives a wrapper around transformers `Trainer` to easily fine-tune models on instruction based datasets using PEFT adapters. Let's first load the training arguments below." 
] }, { "cell_type": "code", "execution_count": 36, "metadata": { "id": "OCFTvGW6aspE", "tags": [] }, "outputs": [], "source": [ "from transformers import TrainingArguments\n", "\n", "output_dir = \"./results\"\n", "per_device_train_batch_size = 4\n", "gradient_accumulation_steps = 4\n", "optim = \"paged_adamw_32bit\"\n", "save_steps = 10\n", "logging_steps = 10\n", "learning_rate = 2e-4\n", "max_grad_norm = 0.3\n", "max_steps = 500\n", "warmup_ratio = 0.03\n", "lr_scheduler_type = \"constant\"\n", "\n", "training_arguments = TrainingArguments(\n", " output_dir=output_dir,\n", " per_device_train_batch_size=per_device_train_batch_size,\n", " gradient_accumulation_steps=gradient_accumulation_steps,\n", " optim=optim,\n", " save_steps=save_steps,\n", " logging_steps=logging_steps,\n", " learning_rate=learning_rate,\n", " fp16=True,\n", " max_grad_norm=max_grad_norm,\n", " max_steps=max_steps,\n", " warmup_ratio=warmup_ratio,\n", " group_by_length=True,\n", " lr_scheduler_type=lr_scheduler_type,\n", " save_strategy=\"steps\", # Set to \"steps\" to save at specified intervals\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "I3t6b2TkcJwy" }, "source": [ "Then finally pass everthing to the trainer" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 72, "referenced_widgets": [ "71e95c840eea48cb8b3be8673dc1c18e", "d3a14a57ab774d09931dd1ae29797699", "79ea119fadb74c489112deaa51ddaacf", "7163c36a4a0e4fbb804baa2280e2f41f", "34725b5045d14fda86f9535f938629f8", "055142cabdd94b07979dfce25225074b", "609329ae177a417aa3656621f06aa230", "44c55abca8104685bc3b6da5f0fa2cfb", "fb80ff0189f94c83af0c05d71c330534", "b363e81907984ef6adf04a0ef63396eb", "f8f3870d43e446a6b5f55c3b1e869e8c" ] }, "executionInfo": { "elapsed": 73748, "status": "ok", "timestamp": 1686730967873, "user": { "displayName": "Younes Belkada", "userId": "15414910276690549281" }, "user_tz": -120 }, "id": "TNeOBgZeTl2H", "outputId": 
"2cfdd9b7-921f-4cd2-db81-aa445b12b9ed", "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/peft/utils/other.py:104: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.\n", " warnings.warn(\n" ] } ], "source": [ "from trl import SFTTrainer\n", "\n", "max_seq_length = 512\n", "\n", "trainer = SFTTrainer(\n", " model=model,\n", " train_dataset=dataset,\n", " peft_config=peft_config,\n", " dataset_text_field=\"prompt\",\n", " max_seq_length=max_seq_length,\n", " tokenizer=tokenizer,\n", " args=training_arguments,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "GWplqqDjb3sS" }, "source": [ "We will also pre-process the model by upcasting the layer norms in float 32 for more stable training" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "id": "7OyIvEx7b1GT", "tags": [] }, "outputs": [], "source": [ "for name, module in trainer.model.named_modules():\n", " if \"norm\" in name:\n", " module = module.to(torch.float32)" ] }, { "cell_type": "markdown", "metadata": { "id": "1JApkSrCcL3O" }, "source": [ "## Train the model" ] }, { "cell_type": "markdown", "metadata": { "id": "JjvisllacNZM" }, "source": [ "Now let's train the model! Simply call `trainer.train()`" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 286 }, "id": "_kbS7nRxcMt7", "outputId": "37aafaba-55ae-4bfc-b02b-00cb17cdccde", "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "data": { "text/html": [ "\n", "
<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: left;\">\n", "      <th>Step</th>\n", "      <th>Training Loss</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr><td>10</td><td>2.862500</td></tr>\n", "    <tr><td>20</td><td>2.898700</td></tr>\n", "    <tr><td>30</td><td>2.755200</td></tr>\n", "    <tr><td>40</td><td>2.862700</td></tr>\n", "    <tr><td>50</td><td>2.660200</td></tr>\n", "    <tr><td>60</td><td>2.808900</td></tr>\n", "    <tr><td>70</td><td>2.710200</td></tr>\n", "    <tr><td>80</td><td>2.772000</td></tr>\n", "    <tr><td>90</td><td>2.695300</td></tr>\n", "    <tr><td>100</td><td>2.640100</td></tr>\n", "    <tr><td>110</td><td>2.648000</td></tr>\n", "    <tr><td>120</td><td>2.575600</td></tr>\n", "    <tr><td>130</td><td>2.600200</td></tr>\n", "    <tr><td>140</td><td>2.580100</td></tr>\n", "    <tr><td>150</td><td>2.667800</td></tr>\n", "    <tr><td>160</td><td>2.575800</td></tr>\n", "    <tr><td>170</td><td>2.686100</td></tr>\n", "    <tr><td>180</td><td>2.416900</td></tr>\n", "    <tr><td>190</td><td>2.365700</td></tr>\n", "    <tr><td>200</td><td>2.315300</td></tr>\n", "    <tr><td>210</td><td>2.360100</td></tr>\n", "    <tr><td>220</td><td>2.349800</td></tr>\n", "    <tr><td>230</td><td>2.411400</td></tr>\n", "    <tr><td>240</td><td>2.370800</td></tr>\n", "    <tr><td>250</td><td>2.417600</td></tr>\n", "    <tr><td>260</td><td>2.372800</td></tr>\n", "    <tr><td>270</td><td>2.105800</td></tr>\n", "    <tr><td>280</td><td>2.037500</td></tr>\n", "    <tr><td>290</td><td>2.082800</td></tr>\n", "    <tr><td>300</td><td>2.081000</td></tr>\n", "    <tr><td>310</td><td>2.100400</td></tr>\n", "    <tr><td>320</td><td>2.147200</td></tr>\n", "    <tr><td>330</td><td>2.143700</td></tr>\n", "    <tr><td>340</td><td>2.149000</td></tr>\n", "    <tr><td>350</td><td>1.975800</td></tr>\n", "    <tr><td>360</td><td>1.791000</td></tr>\n", "    <tr><td>370</td><td>1.736600</td></tr>\n", "    <tr><td>380</td><td>1.806600</td></tr>\n", "    <tr><td>390</td><td>1.782100</td></tr>\n", "    <tr><td>400</td><td>1.866100</td></tr>\n", "    <tr><td>410</td><td>1.809600</td></tr>\n", "    <tr><td>420</td><td>1.910500</td></tr>\n", "    <tr><td>430</td><td>1.862400</td></tr>\n", "    <tr><td>440</td><td>1.560100</td></tr>\n", "    <tr><td>450</td><td>1.396700</td></tr>\n", "    <tr><td>460</td><td>1.487500</td></tr>\n", "    <tr><td>470</td><td>1.456300</td></tr>\n", "    <tr><td>480</td><td>1.554500</td></tr>\n", "    <tr><td>490</td><td>1.509300</td></tr>\n", "    <tr><td>500</td><td>1.523400</td></tr>\n", "  </tbody>\n", "</table>\n"
],
"text/plain": [
"