Granther committed on
Commit 912e66e
1 Parent(s): 1a005d8

Upload prompt_tune_phi3.ipynb with huggingface_hub

Files changed (1)
  1. prompt_tune_phi3.ipynb +246 -0
prompt_tune_phi3.ipynb ADDED
@@ -0,0 +1,246 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3890292a-c99e-4367-955d-5883b93dba36",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "!pip install -q peft transformers datasets huggingface_hub\n",
+    "!pip install flash-attn --no-build-isolation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n",
+    "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n",
+    "import torch\n",
+    "from datasets import load_dataset\n",
+    "import os\n",
+    "from torch.utils.data import DataLoader\n",
+    "from tqdm import tqdm\n",
+    "from huggingface_hub import notebook_login\n",
+    "from huggingface_hub import HfApi"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "notebook_login()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "8a1cb1f9-b89d-4cac-a595-44e1e0ef85b2",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "CommitInfo(commit_url='https://huggingface.co/Granther/prompt-tuned-phi3/commit/1a005d8478e96bc972732562f77846878c4ba7b3', commit_message='Upload prompt_tune_phi3.ipnb with huggingface_hub', commit_description='', oid='1a005d8478e96bc972732562f77846878c4ba7b3', pr_url=None, pr_revision=None, pr_num=None)"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "api = HfApi()\n",
+    "api.upload_file(path_or_fileobj='prompt_tune_phi3.ipynb',\n",
+    "                path_in_repo='prompt_tune_phi3.ipynb',\n",
+    "                repo_id='Granther/prompt-tuned-phi3',\n",
+    "                repo_type='model'\n",
+    "                )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "device = 'cuda'\n",
+    "\n",
+    "model_id = 'microsoft/Phi-3-mini-128k-instruct'\n",
+    "\n",
+    "peft_conf = PromptTuningConfig(\n",
+    "    peft_type=PeftType.PROMPT_TUNING,  # which PEFT method to apply\n",
+    "    task_type=TaskType.CAUSAL_LM,  # task type: causal language modeling\n",
+    "    prompt_tuning_init=PromptTuningInit.TEXT,  # Set to 'TEXT' to use prompt_tuning_init_text\n",
+    "    num_virtual_tokens=8,  # number of trainable virtual (soft prompt) tokens prepended to the input\n",
+    "    prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
+    "    tokenizer_name_or_path=model_id\n",
+    ")\n",
+    "\n",
+    "dataset_name = \"twitter_complaints\"\n",
+    "checkpoint_name = f\"{dataset_name}_{model_id}_{peft_conf.peft_type}_{peft_conf.task_type}_v1.pt\".replace(\n",
+    "    \"/\", \"_\"\n",
+    ")\n",
+    "\n",
+    "text_col = 'Tweet text'\n",
+    "lab_col = 'text_label'\n",
+    "max_len = 64\n",
+    "lr = 3e-2\n",
+    "epochs = 50\n",
+    "batch_size = 8"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "6f677839-ef23-428a-bcfe-f596590804ca",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = load_dataset('ought/raft', dataset_name, split='train')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "c0c05613-7941-4959-ada9-49ed1093bec4",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['Unlabeled', 'complaint', 'no complaint']"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "dataset.features['Label'].names\n",
+    "#>>> ['Unlabeled', 'complaint', 'no complaint']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "d9e958c687dd493880d18d4f1621dad9",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Map (num_proc=10): 0%| | 0/50 [00:00<?, ? examples/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'Unlabeled'"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Map the integer labels to human-readable class names in a new 'text_label' column\n",
+    "classes = [k.replace('_', ' ') for k in dataset.features['Label'].names]\n",
+    "dataset = dataset.map(\n",
+    "    lambda x: {'text_label': [classes[label] for label in x['Label']]},\n",
+    "    batched=True,\n",
+    "    num_proc=10,\n",
+    ")\n",
+    "\n",
+    "dataset[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "[1, 853, 29880, 24025, 32000]"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+    "\n",
+    "if tokenizer.pad_token_id is None:\n",
+    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
+    "\n",
+    "target_max_len = max([len(tokenizer(class_lab)['input_ids']) for class_lab in classes])\n",
+    "target_max_len  # max length for tokenized labels\n",
+    "\n",
+    "tokenizer(classes[0])['input_ids']\n",
+    "# IDs corresponding to the tokens in the sequence\n",
+    "# The attention mask is a binary tensor used in the transformer blocks to distinguish padding tokens from meaningful ones"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "459d4f69-1d85-42e8-acac-b2c7983c3a33",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
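
The committed notebook ends with an empty cell: it stops right after tokenizer setup, even though the imports already pull in default_data_collator, get_linear_schedule_with_warmup, DataLoader, tqdm, and get_peft_model. For orientation, here is a minimal sketch of how the remaining prompt-tuning steps could look, following the standard PEFT prompt-tuning recipe for ought/raft twitter_complaints. It is not part of the commit: it reuses the imports and variables defined in the cells above (dataset, tokenizer, peft_conf, model_id, text_col, lab_col, max_len, lr, epochs, batch_size, device), and the prompt template and left-padding scheme are assumptions rather than the author's code.

# Sketch only: preprocessing, PEFT wrapping, and training loop assumed from the
# standard PEFT prompt-tuning recipe, not taken from the committed notebook.

def preprocess(examples):
    # Build "Tweet text : <tweet> Label : <label>" sequences and mask the
    # prompt portion out of the loss with -100 labels (left padding).
    inputs = [f"{text_col} : {t} Label : " for t in examples[text_col]]
    targets = [str(t) for t in examples[lab_col]]
    tokenized_inputs = tokenizer(inputs)
    tokenized_targets = tokenizer(targets, add_special_tokens=False)

    input_ids, attention_mask, labels = [], [], []
    for inp, tgt in zip(tokenized_inputs["input_ids"], tokenized_targets["input_ids"]):
        ids = (inp + tgt + [tokenizer.eos_token_id])[:max_len]
        labs = ([-100] * len(inp) + tgt + [tokenizer.eos_token_id])[:max_len]
        pad = max_len - len(ids)
        input_ids.append([tokenizer.pad_token_id] * pad + ids)
        attention_mask.append([0] * pad + [1] * len(ids))
        labels.append([-100] * pad + labs)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

processed = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)
train_loader = DataLoader(processed, shuffle=True, batch_size=batch_size,
                          collate_fn=default_data_collator, pin_memory=True)

# Wrap the base model so that only the 8 virtual prompt tokens are trainable
model = AutoModelForCausalLM.from_pretrained(model_id)
model = get_peft_model(model, peft_conf)
model.print_trainable_parameters()
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=len(train_loader) * epochs
)

for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    print(f"epoch {epoch}: train loss {total_loss / len(train_loader):.4f}")

Once trained, the adapter could then be pushed to the same repository this commit targets, e.g. model.push_to_hub('Granther/prompt-tuned-phi3'), so the soft prompt weights live next to the notebook.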