Granther committed on
Commit 20c8824
1 Parent(s): 9f2b8b1

Upload prompt_tune_phi3.ipnb with huggingface_hub

Files changed (1)
  1. prompt_tune_phi3.ipnb +281 -0
prompt_tune_phi3.ipnb ADDED
@@ -0,0 +1,281 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "3890292a-c99e-4367-955d-5883b93dba36",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0mCollecting flash-attn\n",
+ " Downloading flash_attn-2.5.9.post1.tar.gz (2.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.6/2.6 MB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n",
+ "\u001b[?25hRequirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from flash-attn) (2.2.0)\n",
+ "Collecting einops (from flash-attn)\n",
+ " Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)\n",
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.13.1)\n",
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (4.9.0)\n",
+ "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (1.12)\n",
+ "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1)\n",
+ "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1.2)\n",
+ "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (2023.12.2)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->flash-attn) (2.1.3)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->flash-attn) (1.3.0)\n",
+ "Downloading einops-0.8.0-py3-none-any.whl (43 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hBuilding wheels for collected packages: flash-attn\n",
+ " Building wheel for flash-attn (setup.py) ... \u001b[?25ldone\n",
+ "\u001b[?25h Created wheel for flash-attn: filename=flash_attn-2.5.9.post1-cp310-cp310-linux_x86_64.whl size=120821333 sha256=7bfd5ecaaf20577cd1255eaa90d9008a09050b3408ba6388bcbc5b6144f482d0\n",
+ " Stored in directory: /root/.cache/pip/wheels/cc/ad/f6/7ccf0238790d6346e9fe622923a76ec218e890d356b9a2754a\n",
+ "Successfully built flash-attn\n",
+ "Installing collected packages: einops, flash-attn\n",
+ "Successfully installed einops-0.8.0 flash-attn-2.5.9.post1\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install -q peft transformers datasets huggingface_hub\n",
+ "!pip install flash-attn --no-build-isolation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n",
+ "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n",
+ "import torch\n",
+ "from datasets import load_dataset\n",
+ "import os\n",
+ "from torch.utils.data import DataLoader\n",
+ "from tqdm import tqdm\n",
+ "from huggingface_hub import notebook_login\n",
+ "from huggingface_hub import HfApi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "baaa64cf8c0d415ba41abf52b03667b5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "notebook_login()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "8a1cb1f9-b89d-4cac-a595-44e1e0ef85b2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "api = HfApi()\n",
+ "# Upload this notebook to the model repo (file and repo names assumed from the commit message)\n",
+ "api.upload_file(path_or_fileobj='prompt_tune_phi3.ipnb', path_in_repo='prompt_tune_phi3.ipnb', repo_id='Granther/prompt-tuned-phi3')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = 'cuda'\n",
+ "\n",
+ "model_id = 'microsoft/Phi-3-mini-128k-instruct'\n",
+ "\n",
+ "peft_conf = PromptTuningConfig(\n",
+ " peft_type=PeftType.PROMPT_TUNING, # which PEFT method to apply\n",
+ " task_type=TaskType.CAUSAL_LM, # task type: causal language modeling\n",
+ " prompt_tuning_init=PromptTuningInit.TEXT, # Set to 'TEXT' to use prompt_tuning_init_text\n",
+ " num_virtual_tokens=8, # number of trainable virtual (soft prompt) tokens prepended to the input\n",
+ " prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
+ " tokenizer_name_or_path=model_id\n",
+ ")\n",
+ "\n",
+ "dataset_name = \"twitter_complaints\"\n",
+ "checkpoint_name = f\"{dataset_name}_{model_id}_{peft_conf.peft_type}_{peft_conf.task_type}_v1.pt\".replace(\n",
+ " \"/\", \"_\"\n",
+ ")\n",
+ "\n",
+ "text_col = 'Tweet text'\n",
+ "lab_col = 'text_label'\n",
+ "max_len = 64\n",
+ "lr = 3e-2\n",
+ "epochs = 50\n",
+ "batch_size = 8"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "6f677839-ef23-428a-bcfe-f596590804ca",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = load_dataset('ought/raft', dataset_name, split='train')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "c0c05613-7941-4959-ada9-49ed1093bec4",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['Unlabeled', 'complaint', 'no complaint']"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset.features['Label'].names\n",
+ "#>>> ['Unlabeled', 'complaint', 'no complaint']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d9e958c687dd493880d18d4f1621dad9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map (num_proc=10): 0%| | 0/50 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'Unlabeled'"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Map integer Label ids to their text names in a new 'text_label' column\n",
+ "classes = [k.replace('_', ' ') for k in dataset.features['Label'].names]\n",
+ "dataset = dataset.map(\n",
+ " lambda x: {'text_label': [classes[label] for label in x['Label']]},\n",
+ " batched=True,\n",
+ " num_proc=10,\n",
+ ")\n",
+ "\n",
+ "dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[1, 853, 29880, 24025, 32000]"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+ "\n",
+ "if tokenizer.pad_token_id is None:\n",
+ " tokenizer.pad_token_id = tokenizer.eos_token_id\n",
+ "\n",
+ "target_max_len = max([len(tokenizer(class_lab)['input_ids']) for class_lab in classes])\n",
+ "target_max_len # max length for tokenized labels\n",
+ "\n",
+ "tokenizer(classes[0])['input_ids'] \n",
+ "# Ids corresponding to the tokens in the sequence\n",
+ "# Attention mask is a binary tensor used in the transformer block to differentiate between padding tokens and meaningful ones"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "459d4f69-1d85-42e8-acac-b2c7983c3a33",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
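
The uploaded notebook stops after tokenizer setup, though its imports and peft_conf already point at the next step. A minimal sketch of that step under the names defined above (model_id, peft_conf); trust_remote_code is an assumption for loading Phi-3's custom model code:

from transformers import AutoModelForCausalLM
from peft import get_peft_model

# Load the frozen base model and attach the soft prompt defined by peft_conf
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model = get_peft_model(model, peft_conf)
model.print_trainable_parameters()  # only the prompt embedding (8 x hidden_size) is trainable

Prompt tuning keeps every base-model weight frozen, so the only parameters updated during training are the num_virtual_tokens embedding vectors prepended to each input.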