prompt-tuned-phi3 / prompt_tune_phi3.ipnb
Granther's picture
Upload prompt_tune_phi3.ipnb with huggingface_hub
20c8824 verified
raw
history blame
9.87 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": 17,
"id": "3890292a-c99e-4367-955d-5883b93dba36",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0mCollecting flash-attn\n",
" Downloading flash_attn-2.5.9.post1.tar.gz (2.6 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.6/2.6 MB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
"\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25hRequirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from flash-attn) (2.2.0)\n",
"Collecting einops (from flash-attn)\n",
" Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)\n",
"Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.13.1)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (4.9.0)\n",
"Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (1.12)\n",
"Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1)\n",
"Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1.2)\n",
"Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (2023.12.2)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->flash-attn) (2.1.3)\n",
"Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->flash-attn) (1.3.0)\n",
"Downloading einops-0.8.0-py3-none-any.whl (43 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hBuilding wheels for collected packages: flash-attn\n",
" Building wheel for flash-attn (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25h Created wheel for flash-attn: filename=flash_attn-2.5.9.post1-cp310-cp310-linux_x86_64.whl size=120821333 sha256=7bfd5ecaaf20577cd1255eaa90d9008a09050b3408ba6388bcbc5b6144f482d0\n",
" Stored in directory: /root/.cache/pip/wheels/cc/ad/f6/7ccf0238790d6346e9fe622923a76ec218e890d356b9a2754a\n",
"Successfully built flash-attn\n",
"Installing collected packages: einops, flash-attn\n",
"Successfully installed einops-0.8.0 flash-attn-2.5.9.post1\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
],
"source": [
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"# flash-attn is pinned to the version recorded in this cell's saved output.\n",
"%pip install -q peft transformers datasets huggingface_hub\n",
"%pip install flash-attn==2.5.9.post1 --no-build-isolation"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n",
"from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n",
"import torch\n",
"from datasets import load_dataset\n",
"import os\n",
"from torch.utils.data import DataLoader\n",
"from tqdm import tqdm\n",
"from huggingface_hub import notebook_login\n",
"from huggingface_hub import HfApi"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "baaa64cf8c0d415ba41abf52b03667b5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Interactive Hugging Face Hub login widget; authenticates later Hub calls such as HfApi().upload_file.\n",
"notebook_login()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "8a1cb1f9-b89d-4cac-a595-44e1e0ef85b2",
"metadata": {},
"outputs": [],
"source": [
"# Push this notebook to the Hub model repo.\n",
"# NOTE(review): assumes the notebook is saved locally as 'prompt_tune_phi3.ipnb' (the name it has in the repo) - adjust if not.\n",
"NOTEBOOK_PATH = 'prompt_tune_phi3.ipnb'\n",
"REPO_ID = 'Granther/prompt-tuned-phi3'\n",
"\n",
"api = HfApi()\n",
"api.upload_file(\n",
"    path_or_fileobj=NOTEBOOK_PATH,  # local file to upload (previously given the repo id by mistake)\n",
"    path_in_repo=NOTEBOOK_PATH,     # destination path inside the repo\n",
"    repo_id=REPO_ID,                # target repository (required, was missing)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
"metadata": {},
"outputs": [],
"source": [
"device = 'cuda'\n",
"\n",
"model_id = 'microsoft/Phi-3-mini-128k-instruct'\n",
"\n",
"# Prompt-tuning config: a learned soft prompt of virtual tokens is prepended to every input.\n",
"peft_conf = PromptTuningConfig(\n",
"    peft_type=PeftType.PROMPT_TUNING,          # kind of PEFT method\n",
"    task_type=TaskType.CAUSAL_LM,              # wrap the model for causal language modelling\n",
"    prompt_tuning_init=PromptTuningInit.TEXT,  # initialise the soft prompt from prompt_tuning_init_text\n",
"    num_virtual_tokens=8,                      # length of the learned soft prompt, in tokens\n",
"    # NOTE(review): with TEXT init, PEFT truncates the init text to num_virtual_tokens tokens;\n",
"    # check how many tokens this text has under the Phi-3 tokenizer so it is not cut off.\n",
"    prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
"    tokenizer_name_or_path=model_id,           # tokenizer used to encode the init text\n",
")\n",
"\n",
"dataset_name = \"twitter_complaints\"\n",
"# File-system-safe checkpoint name: every '/' (e.g. in the model id) becomes '_'.\n",
"checkpoint_name = f\"{dataset_name}_{model_id}_{peft_conf.peft_type}_{peft_conf.task_type}_v1.pt\".replace(\n",
"    \"/\", \"_\"\n",
")\n",
"\n",
"text_col = 'Tweet text'  # column holding the tweet text\n",
"lab_col = 'text_label'   # string label column created by the label-mapping cell\n",
"max_len = 64             # presumably the max tokenized input length - confirm against the preprocessing code\n",
"lr = 3e-2                # learning rate\n",
"epochs = 50\n",
"batch_size = 8"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6f677839-ef23-428a-bcfe-f596590804ca",
"metadata": {},
"outputs": [],
"source": [
"# Training split of the RAFT subset named by dataset_name.\n",
"dataset = load_dataset(path='ought/raft', name=dataset_name, split='train')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c0c05613-7941-4959-ada9-49ed1093bec4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Unlabeled', 'complaint', 'no complaint']"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset.features['Label'].names\n",
"# Class names for the integer 'Label' column: names[i] is the name of label id i -> ['Unlabeled', 'complaint', 'no complaint']"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d9e958c687dd493880d18d4f1621dad9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map (num_proc=10): 0%| | 0/50 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'Unlabeled'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Add a string 'text_label' column mapping each integer 'Label' to its class name\n",
"# (underscores in class names replaced by spaces).\n",
"classes = [name.replace('_', ' ') for name in dataset.features['Label'].names]\n",
"dataset = dataset.map(\n",
"    lambda batch: {'text_label': [classes[label] for label in batch['Label']]},\n",
"    batched=True,\n",
"    # No num_proc: the split has only 50 rows, so spawning worker processes costs more than the mapping itself.\n",
")\n",
"\n",
"dataset[0]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"data": {
"text/plain": [
"[1, 853, 29880, 24025, 32000]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"\n",
"# Fall back to EOS as the padding token when the tokenizer defines none.\n",
"if tokenizer.pad_token_id is None:\n",
"    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
"\n",
"# Longest tokenized class name, in tokens (special tokens added by the tokenizer included).\n",
"target_max_len = max(len(tokenizer(class_lab)['input_ids']) for class_lab in classes)\n",
"display(target_max_len)  # a bare name mid-cell is not rendered; only the last expression is\n",
"\n",
"# Token ids of the first class name, useful for inspecting which special tokens get added.\n",
"tokenizer(classes[0])['input_ids']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "459d4f69-1d85-42e8-acac-b2c7983c3a33",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}