{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "3890292a-c99e-4367-955d-5883b93dba36", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mRequirement already satisfied: flash-attn in /opt/conda/lib/python3.10/site-packages (2.5.9.post1)\n", "Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from flash-attn) (2.2.0)\n", "Requirement already satisfied: einops in /opt/conda/lib/python3.10/site-packages (from flash-attn) (0.8.0)\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.13.1)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (4.9.0)\n", "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (1.12)\n", "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1)\n", "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1.2)\n", "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (2023.12.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->flash-attn) (2.1.3)\n", "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->flash-attn) (1.3.0)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "# NOTE(review): pin versions for reproducibility; the recorded run resolved flash-attn 2.5.9.post1 with torch 2.2.0.\n",
    "%pip install -q peft transformers datasets huggingface_hub\n",
    "%pip install flash-attn --no-build-isolation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n",
    "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from huggingface_hub import notebook_login\n",
    "from huggingface_hub import HfApi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da",
   "metadata": {},
   "outputs": [],
   "source": [
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "8a1cb1f9-b89d-4cac-a595-44e1e0ef85b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/Granther/prompt-tuned-phi3/commit/7ea57da9a4eccf3794c58bb4317df1c97a0fe2c8', commit_message='Upload prompt_tune_phi3.ipynb with huggingface_hub', commit_description='', oid='7ea57da9a4eccf3794c58bb4317df1c97a0fe2c8', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Uploads this notebook file to the Hub model repo (requires the login above); each run creates a new commit.\n",
    "api = HfApi()\n",
    "api.upload_file(path_or_fileobj='prompt_tune_phi3.ipynb',\n",
    "                path_in_repo='prompt_tune_phi3.ipynb',\n",
    "                repo_id='Granther/prompt-tuned-phi3',\n",
    "                repo_type='model'\n",
    "                )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'\n",
    "\n",
    "model_id = 'microsoft/Phi-3-mini-128k-instruct'\n",
    "\n",
    "peft_conf = PromptTuningConfig(\n",
    "    peft_type=PeftType.PROMPT_TUNING,  # PromptTuningConfig sets this itself; kept for explicitness\n",
    "    task_type=TaskType.CAUSAL_LM,  # decoder-only language-model task\n",
    "    prompt_tuning_init=PromptTuningInit.TEXT,  # initialise the soft prompt from prompt_tuning_init_text\n",
    "    num_virtual_tokens=8,  # length of the learnable soft prompt, in tokens (independent of the layer count)\n",
    "    # PEFT truncates (or repeats) the tokenized init text to num_virtual_tokens tokens;\n",
    "    # this sentence is probably longer than 8 tokens -- TODO confirm the cut-off is intended.\n",
    "    prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
    "    tokenizer_name_or_path=model_id  # tokenizer used to tokenize the init text\n",
    ")\n",
    "\n",
    "dataset_name = \"twitter_complaints\"\n",
    "checkpoint_name = f\"{dataset_name}_{model_id}_{peft_conf.peft_type}_{peft_conf.task_type}_v1.pt\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "\n",
    "text_col = 'Tweet text'\n",
    "label_col = 'text_label'\n",
    "max_len = 64  # target sequence length (in tokens) for preprocessed examples\n",
    "lr = 3e-2\n",
    "epochs = 50\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6f677839-ef23-428a-bcfe-f596590804ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset('ought/raft', dataset_name, split='train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c0c05613-7941-4959-ada9-49ed1093bec4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Unlabeled', 'complaint', 'no complaint']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.features['Label'].names\n",
    "#>>> ['Unlabeled', 'complaint', 'no complaint']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'Tweet text': '@HMRCcustomers No this is my first job',\n",
       " 'ID': 0,\n",
       " 'Label': 2,\n",
       " 'text_label': 'no complaint'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Add a 'text_label' column holding the class name for each integer 'Label'\n",
    "classes = [k.replace('_', ' ') for k in dataset.features['Label'].names]\n",
    "dataset = dataset.map(\n",
    "    lambda x: {'text_label': [classes[label] for label in x['Label']]},\n",
    "    batched=True,\n",
    "    num_proc=10,\n",
    ")\n",
    "\n",
    "dataset[0]"
   ]
}, {
   "cell_type": "code",
   "execution_count": 9,
   "id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[1, 853, 29880, 24025]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "# Reuse EOS as the padding token when the tokenizer does not define one.\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "# Longest tokenized class name, in tokens (including any special tokens the tokenizer adds).\n",
    "target_max_len = max(len(tokenizer(class_lab)['input_ids']) for class_lab in classes)\n",
    "\n",
    "# Token ids for the first class name. The recorded output starts with id 1,\n",
    "# presumably a BOS token prepended by the tokenizer -- verify with tokenizer.bos_token_id.\n",
    "tokenizer(classes[0])['input_ids']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1a15150-4bd9-45a2-ba43-d0bbbd16e60d",
   "metadata": {},
   "source": [
    "### Preprocess function\n",
    "- Format each tweet as `Tweet text : <tweet> Label : ` and tokenize it; tokenize the text label separately\n",
    "- Concatenate prompt and label ids; set the prompt positions of `labels` to -100 so the loss only covers the label\n",
    "- Left-pad every example to `max_len` with `tokenizer.pad_token_id` (attention mask 0, label -100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "03f05467-dce3-4e42-ab3b-c39ba620e164",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preproc(example):\n",
    "    \"\"\"Tokenize a batch of tweets and their text labels for causal-LM prompt tuning.\n",
    "\n",
    "    `example` is a batch dict as passed by `dataset.map(..., batched=True)`: column name -> list of values.\n",
    "    Reads the globals `tokenizer`, `text_col`, `label_col` and `max_len` defined in earlier cells.\n",
    "\n",
    "    Returns the tokenizer output with `input_ids`, `attention_mask` and `labels`, each a list of\n",
    "    `max_len`-long int lists (left-padded; prompt and padding positions in `labels` are -100).\n",
    "    \"\"\"\n",
    "    n_examples = len(example[text_col])\n",
    "    inputs = [f\"{text_col} : {x} Label : \" for x in example[text_col]]\n",
    "    targets = [str(x) for x in example[label_col]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    # No special tokens on the targets: otherwise a BOS id would sit between prompt and label.\n",
    "    labels = tokenizer(targets, add_special_tokens=False)\n",
    "\n",
    "    for i in range(n_examples):\n",
    "        prompt_ids = model_inputs[\"input_ids\"][i]\n",
    "        # End the label with EOS so the model learns to stop after generating it.\n",
    "        label_ids = labels[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
    "        input_ids = prompt_ids + label_ids\n",
    "        # -100 positions are ignored by the loss, so only the label tokens are trained on.\n",
    "        target_ids = [-100] * len(prompt_ids) + label_ids\n",
    "        attention_mask = [1] * len(input_ids)\n",
    "\n",
    "        # Left-pad to max_len, then keep the LAST max_len tokens so a long tweet never truncates the label.\n",
    "        pad_len = max(max_len - len(input_ids), 0)\n",
    "        model_inputs[\"input_ids\"][i] = ([tokenizer.pad_token_id] * pad_len + input_ids)[-max_len:]\n",
    "        model_inputs[\"attention_mask\"][i] = ([0] * pad_len + attention_mask)[-max_len:]\n",
    "        labels[\"input_ids\"][i] = ([-100] * pad_len + target_ids)[-max_len:]\n",
    "\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "72ddca5f-7bce-4342-9414-9dd9d41d9dec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cb9f37c876c548fbbcd07a7b889e1764",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Preprocessing dataset (num_proc=10): 0%| | 0/50 [00:00