{ "cells": [ { "cell_type": "code", "execution_count": 17, "id": "3890292a-c99e-4367-955d-5883b93dba36", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mCollecting flash-attn\n", " Downloading flash_attn-2.5.9.post1.tar.gz (2.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.6/2.6 MB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from flash-attn) (2.2.0)\n", "Collecting einops (from flash-attn)\n", " Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.13.1)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (4.9.0)\n", "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (1.12)\n", "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1)\n", "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1.2)\n", "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (2023.12.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->flash-attn) (2.1.3)\n", "Requirement already satisfied: mpmath>=0.19 in 
/opt/conda/lib/python3.10/site-packages (from sympy->torch->flash-attn) (1.3.0)\n", "Downloading einops-0.8.0-py3-none-any.whl (43 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hBuilding wheels for collected packages: flash-attn\n", " Building wheel for flash-attn (setup.py) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for flash-attn: filename=flash_attn-2.5.9.post1-cp310-cp310-linux_x86_64.whl size=120821333 sha256=7bfd5ecaaf20577cd1255eaa90d9008a09050b3408ba6388bcbc5b6144f482d0\n", " Stored in directory: /root/.cache/pip/wheels/cc/ad/f6/7ccf0238790d6346e9fe622923a76ec218e890d356b9a2754a\n", "Successfully built flash-attn\n", "Installing collected packages: einops, flash-attn\n", "Successfully installed einops-0.8.0 flash-attn-2.5.9.post1\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "# Use %pip (not !pip) so the packages are installed into the running kernel's environment.\n", "%pip install -q peft transformers datasets huggingface_hub\n", "# flash-attn is pinned to the release recorded in this cell's output; --no-build-isolation builds it against the already-installed torch.\n", "%pip install flash-attn==2.5.9.post1 --no-build-isolation" ] }, { "cell_type": "code", "execution_count": 20, "id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n", "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n", "import torch\n", "from datasets import load_dataset\n", "import os\n", "from torch.utils.data import DataLoader\n", "from tqdm import tqdm\n", "from huggingface_hub import notebook_login\n", "from huggingface_hub import HfApi" ] }, { "cell_type": "code", "execution_count": 19, "id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "baaa64cf8c0d415ba41abf52b03667b5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='