{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "3890292a-c99e-4367-955d-5883b93dba36", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mRequirement already satisfied: flash-attn in /opt/conda/lib/python3.10/site-packages (2.5.9.post1)\n", "Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from flash-attn) (2.2.0)\n", "Requirement already satisfied: einops in /opt/conda/lib/python3.10/site-packages (from flash-attn) (0.8.0)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.13.1)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (4.9.0)\n", "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (1.12)\n", "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1)\n", "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1.2)\n", "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (2023.12.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->flash-attn) (2.1.3)\n", "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->flash-attn) (1.3.0)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install -q peft transformers datasets huggingface_hub\n", "!pip install flash-attn --no-build-isolation" ] }, { "cell_type": "code", "execution_count": 3, "id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n", "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType, PeftConfig\n", "import torch\n", "from datasets import load_dataset\n", "import os\n", "from torch.utils.data import DataLoader\n", "from tqdm import tqdm\n", "from huggingface_hub import notebook_login\n", "from huggingface_hub import HfApi" ] }, { "cell_type": "code", "execution_count": 17, "id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7f03fcf3844743fcb41f8bfc9c6c9b70", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='