|
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "3890292a-c99e-4367-955d-5883b93dba36",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0mCollecting flash-attn\n",
      "  Downloading flash_attn-2.5.9.post1.tar.gz (2.6 MB)\n",
      "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.6/2.6 MB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25hRequirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from flash-attn) (2.2.0)\n",
      "Collecting einops (from flash-attn)\n",
      "  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (2023.12.2)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->flash-attn) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->flash-attn) (1.3.0)\n",
      "Downloading einops-0.8.0-py3-none-any.whl (43 kB)\n",
      "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hBuilding wheels for collected packages: flash-attn\n",
      "  Building wheel for flash-attn (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h Created wheel for flash-attn: filename=flash_attn-2.5.9.post1-cp310-cp310-linux_x86_64.whl size=120821333 sha256=7bfd5ecaaf20577cd1255eaa90d9008a09050b3408ba6388bcbc5b6144f482d0\n",
      "  Stored in directory: /root/.cache/pip/wheels/cc/ad/f6/7ccf0238790d6346e9fe622923a76ec218e890d356b9a2754a\n",
      "Successfully built flash-attn\n",
      "Installing collected packages: einops, flash-attn\n",
      "Successfully installed einops-0.8.0 flash-attn-2.5.9.post1\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install -q peft transformers datasets huggingface_hub\n",
    "!pip install flash-attn --no-build-isolation"
   ]
  },
|
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n",
    "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from huggingface_hub import notebook_login, HfApi"
   ]
  },
|
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "baaa64cf8c0d415ba41abf52b03667b5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "notebook_login()"
   ]
  },
|
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "8a1cb1f9-b89d-4cac-a595-44e1e0ef85b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "api = HfApi()\n",
    "# upload_file() expects a local file (path_or_fileobj), the filename to create in the\n",
    "# repo (path_in_repo), and the target repo id. The two path values below are\n",
    "# placeholders; point them at the artifact you actually want to upload.\n",
    "api.upload_file(\n",
    "    path_or_fileobj='<local-file-to-upload>',  # placeholder local path\n",
    "    path_in_repo='<filename-in-repo>',         # placeholder destination filename\n",
    "    repo_id='Granther/prompt-tuned-phi3',\n",
    ")"
   ]
  },
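  {
   "cell_type": "markdown",
   "id": "upload-note-md",
   "metadata": {},
   "source": [
    "Note: `upload_file` assumes the target repo already exists (it can be created with `HfApi.create_repo` or on the Hub website). For a PEFT model, a simpler alternative is usually to call `push_to_hub('Granther/prompt-tuned-phi3')` on the prompt-tuned model after training, which uploads the adapter weights and config in one step."
   ]
  },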
|
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda'\n",
    "\n",
    "model_id = 'microsoft/Phi-3-mini-128k-instruct'\n",
    "\n",
    "peft_conf = PromptTuningConfig(\n",
    "    peft_type=PeftType.PROMPT_TUNING,  # which PEFT method to use\n",
    "    task_type=TaskType.CAUSAL_LM,  # the task being tuned: causal language modeling\n",
    "    prompt_tuning_init=PromptTuningInit.TEXT,  # initialize the soft prompt from prompt_tuning_init_text\n",
    "    num_virtual_tokens=8,  # number of trainable soft-prompt tokens prepended to every input\n",
    "    prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
    "    tokenizer_name_or_path=model_id\n",
    ")\n",
    "\n",
    "dataset_name = \"twitter_complaints\"\n",
    "checkpoint_name = f\"{dataset_name}_{model_id}_{peft_conf.peft_type}_{peft_conf.task_type}_v1.pt\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "\n",
    "text_col = 'Tweet text'\n",
    "lab_col = 'text_label'\n",
    "max_len = 64\n",
    "lr = 3e-2\n",
    "epochs = 50\n",
    "batch_size = 8"
   ]
  },
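  {
   "cell_type": "markdown",
   "id": "wrap-model-md",
   "metadata": {},
   "source": [
    "With the prompt-tuning config defined, the frozen base model can be wrapped into a trainable `PeftModel`. The cell below is a minimal sketch that uses only names already imported above; `torch_dtype=torch.bfloat16`, `attn_implementation='flash_attention_2'` (flash-attn was installed in the first cell) and `trust_remote_code=True` are assumptions here, not requirements of prompt tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "wrap-model-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: load the frozen base model and attach the trainable soft prompt\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    attn_implementation='flash_attention_2',\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "model = get_peft_model(model, peft_conf)\n",
    "model.print_trainable_parameters()  # only the virtual-token embeddings are trainable"
   ]
  },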
|
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6f677839-ef23-428a-bcfe-f596590804ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset('ought/raft', dataset_name, split='train')"
   ]
  },
|
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c0c05613-7941-4959-ada9-49ed1093bec4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Unlabeled', 'complaint', 'no complaint']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.features['Label'].names\n",
    "#>>> ['Unlabeled', 'complaint', 'no complaint']"
   ]
  },
|
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d9e958c687dd493880d18d4f1621dad9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map (num_proc=10):   0%|          | 0/50 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "'Unlabeled'"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Map the integer Label ids onto human-readable text labels\n",
    "classes = [k.replace('_', ' ') for k in dataset.features['Label'].names]\n",
    "dataset = dataset.map(\n",
    "    lambda x: {'text_label': [classes[label] for label in x['Label']]},\n",
    "    batched=True,\n",
    "    num_proc=10,\n",
    ")\n",
    "\n",
    "dataset[0]"
   ]
  },
|
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[1, 853, 29880, 24025, 32000]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "target_max_len = max([len(tokenizer(class_lab)['input_ids']) for class_lab in classes])\n",
    "target_max_len  # length of the longest tokenized label\n",
    "\n",
    "tokenizer(classes[0])['input_ids']\n",
    "# Token ids for the first class label\n",
    "# The attention mask (also returned by the tokenizer) is a binary tensor that tells the\n",
    "# attention layers which positions are real tokens and which are padding"
   ]
  },
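  {
   "cell_type": "markdown",
   "id": "preprocess-md",
   "metadata": {},
   "source": [
    "The remaining steps are left here as a hedged sketch of the standard PEFT causal-LM prompt-tuning recipe that the imports and hyperparameters above point to: render each example as `'Tweet text : <tweet> Label : <label>'`, mask the prompt portion of the labels with -100 so only the label tokens contribute to the loss, left-pad everything to `max_len`, and batch with `default_data_collator`. The helper name `preprocess_function` and the variables `processed` and `train_dataloader` are illustrative choices, not fixed API names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "preprocess-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: build supervised examples of the form 'Tweet text : <tweet> Label : <label>'\n",
    "def preprocess_function(examples):\n",
    "    n = len(examples[text_col])\n",
    "    inputs = [f'{text_col} : {x} Label : ' for x in examples[text_col]]\n",
    "    targets = [str(x) for x in examples[lab_col]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    labels = tokenizer(targets)\n",
    "    for i in range(n):\n",
    "        sample_ids = model_inputs['input_ids'][i]\n",
    "        label_ids = labels['input_ids'][i] + [tokenizer.eos_token_id]\n",
    "        # concatenate prompt and target; mask the prompt out of the loss with -100\n",
    "        model_inputs['input_ids'][i] = sample_ids + label_ids\n",
    "        labels['input_ids'][i] = [-100] * len(sample_ids) + label_ids\n",
    "        model_inputs['attention_mask'][i] = [1] * len(model_inputs['input_ids'][i])\n",
    "    for i in range(n):\n",
    "        # left-pad (and truncate) every example to max_len\n",
    "        pad = max_len - len(model_inputs['input_ids'][i])\n",
    "        model_inputs['input_ids'][i] = ([tokenizer.pad_token_id] * pad + model_inputs['input_ids'][i])[:max_len]\n",
    "        model_inputs['attention_mask'][i] = ([0] * pad + model_inputs['attention_mask'][i])[:max_len]\n",
    "        labels['input_ids'][i] = ([-100] * pad + labels['input_ids'][i])[:max_len]\n",
    "    model_inputs['labels'] = labels['input_ids']\n",
    "    return model_inputs\n",
    "\n",
    "processed = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    remove_columns=dataset.column_names,\n",
    ")\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    processed,\n",
    "    shuffle=True,\n",
    "    collate_fn=default_data_collator,\n",
    "    batch_size=batch_size,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "train-loop-md",
   "metadata": {},
   "source": [
    "A plain training loop over the prompt-tuned model using the `lr`, `epochs`, and `batch_size` values defined earlier, with `AdamW` and a linear warmup schedule; again a sketch rather than a tuned recipe. After training, the soft prompt can be saved with `model.save_pretrained(...)` or pushed to the Hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "train-loop-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: AdamW training loop; only the soft-prompt parameters receive gradients\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=len(train_dataloader) * epochs,\n",
    ")\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for batch in tqdm(train_dataloader):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        total_loss += loss.detach().float()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "    print(f'epoch {epoch}: mean loss {total_loss / len(train_dataloader):.4f}')"
   ]
  },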
|
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "459d4f69-1d85-42e8-acac-b2c7983c3a33",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
|
|