ryefoxlime commited on
Commit
3301578
1 Parent(s): b37902d

Fine Tuned v0.0.1

Browse files
.gitignore CHANGED
@@ -3,6 +3,8 @@
3
  FER/Images/
4
  TADBot.code-workspace
5
  FER/models/checkpoints
6
- **\*/__pycache__/\*
7
- **\*/.ipynb_checkpoints/\*
8
- **\*/.cache/\*
 
 
 
3
  FER/Images/
4
  TADBot.code-workspace
5
  FER/models/checkpoints
6
+ FER/__pycache__
7
+ FER/models/__pycache__
8
+ Gemma2_2B/.cache
9
+ Gemma2_2B/__pycache__
10
+ Gemma2_2B/results
Gemma2_2B/finetune.ipynb ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from huggingface_hub import login\n",
10
+ "from dotenv import load_dotenv\n",
11
+ "import os\n",
12
+ "load_dotenv()\n",
13
+ "\n",
14
+ "# Login to Hugging Face Hub\n",
15
+ "login(token=os.getenv(\"HUGGINGFACE_TOKEN\"))"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 10,
21
+ "metadata": {},
22
+ "outputs": [
23
+ {
24
+ "data": {
25
+ "application/vnd.jupyter.widget-view+json": {
26
+ "model_id": "a39e6120cbea4462999cfa5f887a8015",
27
+ "version_major": 2,
28
+ "version_minor": 0
29
+ },
30
+ "text/plain": [
31
+ "README.md: 0%| | 0.00/288 [00:00<?, ?B/s]"
32
+ ]
33
+ },
34
+ "metadata": {},
35
+ "output_type": "display_data"
36
+ },
37
+ {
38
+ "name": "stderr",
39
+ "output_type": "stream",
40
+ "text": [
41
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\huggingface_hub\\file_download.py:139: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Nitin Kausik Remella\\.cache\\huggingface\\hub\\datasets--ai-bites--databricks-mini. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
42
+ "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
43
+ " warnings.warn(message)\n"
44
+ ]
45
+ },
46
+ {
47
+ "data": {
48
+ "application/vnd.jupyter.widget-view+json": {
49
+ "model_id": "de15e48751c34c36b5d02c2449380d06",
50
+ "version_major": 2,
51
+ "version_minor": 0
52
+ },
53
+ "text/plain": [
54
+ "dolly-mini-train.jsonl: 0%| | 0.00/5.24M [00:00<?, ?B/s]"
55
+ ]
56
+ },
57
+ "metadata": {},
58
+ "output_type": "display_data"
59
+ },
60
+ {
61
+ "data": {
62
+ "application/vnd.jupyter.widget-view+json": {
63
+ "model_id": "d4094fd4af084a77a5bc3904b5db4197",
64
+ "version_major": 2,
65
+ "version_minor": 0
66
+ },
67
+ "text/plain": [
68
+ "Generating train split: 0%| | 0/10544 [00:00<?, ? examples/s]"
69
+ ]
70
+ },
71
+ "metadata": {},
72
+ "output_type": "display_data"
73
+ },
74
+ {
75
+ "data": {
76
+ "text/plain": [
77
+ "Dataset({\n",
78
+ " features: ['text'],\n",
79
+ " num_rows: 1000\n",
80
+ "})"
81
+ ]
82
+ },
83
+ "execution_count": 10,
84
+ "metadata": {},
85
+ "output_type": "execute_result"
86
+ }
87
+ ],
88
+ "source": [
89
+ "from datasets import load_dataset\n",
90
+ "dataset_name = \"ai-bites/databricks-mini\"\n",
91
+ "dataset = load_dataset(dataset_name, split=\"train[0:1000]\", cache_dir=\".cache/\")\n",
92
+ "\n",
93
+ "dataset"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": 11,
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "import torch\n",
103
+ "from transformers import (\n",
104
+ " AutoModelForCausalLM,\n",
105
+ " AutoTokenizer,\n",
106
+ " BitsAndBytesConfig,\n",
107
+ " HfArgumentParser,\n",
108
+ " TrainingArguments,\n",
109
+ " logging,\n",
110
+ ")\n",
111
+ "from peft import LoraConfig, PeftModel\n",
112
+ "from trl import SFTTrainer"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 30,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "import yaml\n",
122
+ "with open(\"hyperparams.yaml\", 'r') as file:\n",
123
+ " hyperparams = yaml.load(file, Loader=yaml.FullLoader)"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 31,
129
+ "metadata": {},
130
+ "outputs": [],
131
+ "source": [
132
+ "compute_dtype = getattr(torch, hyperparams['bnb_4bit_compute_dtype'])\n",
133
+ "\n",
134
+ "bnb_config = BitsAndBytesConfig(\n",
135
+ " load_in_4bit=hyperparams['use_4bit'], # Activates 4-bit precision loading\n",
136
+ " bnb_4bit_quant_type=hyperparams['bnb_4bit_quant_type'], # nf4\n",
137
+ " bnb_4bit_compute_dtype=compute_dtype, # float16\n",
138
+ " bnb_4bit_use_double_quant=hyperparams['use_nested_quant'], # False\n",
139
+ ")"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 32,
145
+ "metadata": {},
146
+ "outputs": [
147
+ {
148
+ "name": "stdout",
149
+ "output_type": "stream",
150
+ "text": [
151
+ "Setting BF16 to True\n"
152
+ ]
153
+ }
154
+ ],
155
+ "source": [
156
+ "# Check GPU compatibility with bfloat16\n",
157
+ "if compute_dtype == torch.float16 and hyperparams['use_4bit']:\n",
158
+ " major, _ = torch.cuda.get_device_capability()\n",
159
+ " if major >= 8:\n",
160
+ " print(\"Setting BF16 to True\")\n",
161
+ " hyperparams['bf16'] = True\n",
162
+ " else:\n",
163
+ " hyperparams['bf16'] = False"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": 33,
169
+ "metadata": {},
170
+ "outputs": [
171
+ {
172
+ "data": {
173
+ "application/vnd.jupyter.widget-view+json": {
174
+ "model_id": "9ab84ef6c43249de9726940a78f2717f",
175
+ "version_major": 2,
176
+ "version_minor": 0
177
+ },
178
+ "text/plain": [
179
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
180
+ ]
181
+ },
182
+ "metadata": {},
183
+ "output_type": "display_data"
184
+ }
185
+ ],
186
+ "source": [
187
+ "model = AutoModelForCausalLM.from_pretrained(\n",
188
+ " hyperparams['model_name'],\n",
189
+ " token=os.getenv(\"HUGGINGFACE_TOKEN\"),\n",
190
+ " quantization_config=bnb_config,\n",
191
+ " device_map=hyperparams['device_map'],\n",
192
+ " cache_dir=\".cache/\",\n",
193
+ ")\n",
194
+ "model.config.use_cache = False\n",
195
+ "model.config.pretraining_tp = 1\n",
196
+ "\n",
197
+ "tokenizer = AutoTokenizer.from_pretrained(hyperparams['model_name'], token=os.getenv(\"HUGGINGFACE_TOKEN\"), trust_remote_code=True, cache_dir=\".cache/\")\n",
198
+ "tokenizer.pad_token = tokenizer.eos_token\n",
199
+ "tokenizer.padding_side = \"right\" # Fix weird overflow issue with fp16 training"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": 34,
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "# Load LoRA configuration\n",
209
+ "peft_config = LoraConfig(\n",
210
+ " lora_alpha=hyperparams['lora_alpha'],\n",
211
+ " lora_dropout=hyperparams['lora_dropout'],\n",
212
+ " r=hyperparams['lora_r'],\n",
213
+ " bias=\"none\",\n",
214
+ " task_type=\"CAUSAL_LM\",\n",
215
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\"gate_proj\", \"up_proj\"]\n",
216
+ ")"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": 39,
222
+ "metadata": {},
223
+ "outputs": [
224
+ {
225
+ "data": {
226
+ "text/plain": [
227
+ "TrainingArguments(\n",
228
+ "_n_gpu=1,\n",
229
+ "accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},\n",
230
+ "adafactor=False,\n",
231
+ "adam_beta1=0.9,\n",
232
+ "adam_beta2=0.999,\n",
233
+ "adam_epsilon=1e-08,\n",
234
+ "auto_find_batch_size=False,\n",
235
+ "average_tokens_across_devices=False,\n",
236
+ "batch_eval_metrics=False,\n",
237
+ "bf16=True,\n",
238
+ "bf16_full_eval=False,\n",
239
+ "data_seed=None,\n",
240
+ "dataloader_drop_last=False,\n",
241
+ "dataloader_num_workers=0,\n",
242
+ "dataloader_persistent_workers=False,\n",
243
+ "dataloader_pin_memory=True,\n",
244
+ "dataloader_prefetch_factor=None,\n",
245
+ "ddp_backend=None,\n",
246
+ "ddp_broadcast_buffers=None,\n",
247
+ "ddp_bucket_cap_mb=None,\n",
248
+ "ddp_find_unused_parameters=None,\n",
249
+ "ddp_timeout=1800,\n",
250
+ "debug=[],\n",
251
+ "deepspeed=None,\n",
252
+ "disable_tqdm=False,\n",
253
+ "dispatch_batches=None,\n",
254
+ "do_eval=False,\n",
255
+ "do_predict=False,\n",
256
+ "do_train=False,\n",
257
+ "eval_accumulation_steps=None,\n",
258
+ "eval_delay=0,\n",
259
+ "eval_do_concat_batches=True,\n",
260
+ "eval_on_start=False,\n",
261
+ "eval_steps=None,\n",
262
+ "eval_strategy=IntervalStrategy.NO,\n",
263
+ "eval_use_gather_object=False,\n",
264
+ "evaluation_strategy=None,\n",
265
+ "fp16=False,\n",
266
+ "fp16_backend=auto,\n",
267
+ "fp16_full_eval=False,\n",
268
+ "fp16_opt_level=O1,\n",
269
+ "fsdp=[],\n",
270
+ "fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False},\n",
271
+ "fsdp_min_num_params=0,\n",
272
+ "fsdp_transformer_layer_cls_to_wrap=None,\n",
273
+ "full_determinism=False,\n",
274
+ "gradient_accumulation_steps=1,\n",
275
+ "gradient_checkpointing=False,\n",
276
+ "gradient_checkpointing_kwargs=None,\n",
277
+ "greater_is_better=None,\n",
278
+ "group_by_length=True,\n",
279
+ "half_precision_backend=auto,\n",
280
+ "hub_always_push=False,\n",
281
+ "hub_model_id=None,\n",
282
+ "hub_private_repo=False,\n",
283
+ "hub_strategy=HubStrategy.EVERY_SAVE,\n",
284
+ "hub_token=<HUB_TOKEN>,\n",
285
+ "ignore_data_skip=False,\n",
286
+ "include_for_metrics=[],\n",
287
+ "include_inputs_for_metrics=False,\n",
288
+ "include_num_input_tokens_seen=False,\n",
289
+ "include_tokens_per_second=False,\n",
290
+ "jit_mode_eval=False,\n",
291
+ "label_names=None,\n",
292
+ "label_smoothing_factor=0.0,\n",
293
+ "learning_rate=0.0002,\n",
294
+ "length_column_name=length,\n",
295
+ "load_best_model_at_end=False,\n",
296
+ "local_rank=0,\n",
297
+ "log_level=passive,\n",
298
+ "log_level_replica=warning,\n",
299
+ "log_on_each_node=True,\n",
300
+ "logging_dir=./results\\runs\\Nov15_13-14-10_FutureGadgetLab,\n",
301
+ "logging_first_step=False,\n",
302
+ "logging_nan_inf_filter=True,\n",
303
+ "logging_steps=25,\n",
304
+ "logging_strategy=IntervalStrategy.STEPS,\n",
305
+ "lr_scheduler_kwargs={},\n",
306
+ "lr_scheduler_type=SchedulerType.CONSTANT,\n",
307
+ "max_grad_norm=0.3,\n",
308
+ "max_steps=-1,\n",
309
+ "metric_for_best_model=None,\n",
310
+ "mp_parameters=,\n",
311
+ "neftune_noise_alpha=None,\n",
312
+ "no_cuda=False,\n",
313
+ "num_train_epochs=1,\n",
314
+ "optim=OptimizerNames.PAGED_ADAMW,\n",
315
+ "optim_args=None,\n",
316
+ "optim_target_modules=None,\n",
317
+ "output_dir=./results,\n",
318
+ "overwrite_output_dir=False,\n",
319
+ "past_index=-1,\n",
320
+ "per_device_eval_batch_size=8,\n",
321
+ "per_device_train_batch_size=2,\n",
322
+ "prediction_loss_only=False,\n",
323
+ "push_to_hub=False,\n",
324
+ "push_to_hub_model_id=None,\n",
325
+ "push_to_hub_organization=None,\n",
326
+ "push_to_hub_token=<PUSH_TO_HUB_TOKEN>,\n",
327
+ "ray_scope=last,\n",
328
+ "remove_unused_columns=True,\n",
329
+ "report_to=['tensorboard'],\n",
330
+ "restore_callback_states_from_checkpoint=False,\n",
331
+ "resume_from_checkpoint=None,\n",
332
+ "run_name=./results,\n",
333
+ "save_on_each_node=False,\n",
334
+ "save_only_model=False,\n",
335
+ "save_safetensors=True,\n",
336
+ "save_steps=25,\n",
337
+ "save_strategy=IntervalStrategy.STEPS,\n",
338
+ "save_total_limit=None,\n",
339
+ "seed=42,\n",
340
+ "skip_memory_metrics=True,\n",
341
+ "split_batches=None,\n",
342
+ "tf32=None,\n",
343
+ "torch_compile=False,\n",
344
+ "torch_compile_backend=None,\n",
345
+ "torch_compile_mode=None,\n",
346
+ "torch_empty_cache_steps=None,\n",
347
+ "torchdynamo=None,\n",
348
+ "tpu_metrics_debug=False,\n",
349
+ "tpu_num_cores=None,\n",
350
+ "use_cpu=False,\n",
351
+ "use_ipex=False,\n",
352
+ "use_legacy_prediction_loop=False,\n",
353
+ "use_liger_kernel=False,\n",
354
+ "use_mps_device=False,\n",
355
+ "warmup_ratio=0.03,\n",
356
+ "warmup_steps=0,\n",
357
+ "weight_decay=0.001,\n",
358
+ ")"
359
+ ]
360
+ },
361
+ "execution_count": 39,
362
+ "metadata": {},
363
+ "output_type": "execute_result"
364
+ }
365
+ ],
366
+ "source": [
367
+ "# Set training parameters\n",
368
+ "training_arguments = TrainingArguments(\n",
369
+ " output_dir=hyperparams['output_dir'],\n",
370
+ " num_train_epochs=hyperparams['num_train_epochs'],\n",
371
+ " per_device_train_batch_size=hyperparams['per_device_train_batch_size'],\n",
372
+ " gradient_accumulation_steps=hyperparams['gradient_accumulation_steps'],\n",
373
+ " optim=hyperparams['optimizer'],\n",
374
+ " save_steps=hyperparams['save_steps'],\n",
375
+ " logging_steps=hyperparams['logging_steps'],\n",
376
+ " learning_rate=float(hyperparams['learning_rate']),\n",
377
+ " weight_decay=hyperparams['weight_decay'],\n",
378
+ " fp16=hyperparams['fp16'],\n",
379
+ " bf16=hyperparams['bf16'],\n",
380
+ " max_grad_norm=hyperparams['max_grad_norm'],\n",
381
+ " max_steps=hyperparams['max_steps'],\n",
382
+ " warmup_ratio=hyperparams['warmup_ratio'],\n",
383
+ " group_by_length=hyperparams['group_by_length'],\n",
384
+ " lr_scheduler_type=hyperparams['lr_scheduler_type'],\n",
385
+ " report_to=\"tensorboard\",\n",
386
+ ")\n",
387
+ "training_arguments"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": 40,
393
+ "metadata": {},
394
+ "outputs": [
395
+ {
396
+ "name": "stderr",
397
+ "output_type": "stream",
398
+ "text": [
399
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\huggingface_hub\\utils\\_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field, max_seq_length, packing. Will not be supported from version '0.13.0'.\n",
400
+ "\n",
401
+ "Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.\n",
402
+ " warnings.warn(message, FutureWarning)\n",
403
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\trl\\trainer\\sft_trainer.py:212: UserWarning: You passed a `packing` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n",
404
+ " warnings.warn(\n",
405
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\trl\\trainer\\sft_trainer.py:300: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n",
406
+ " warnings.warn(\n",
407
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\trl\\trainer\\sft_trainer.py:328: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.\n",
408
+ " warnings.warn(\n"
409
+ ]
410
+ }
411
+ ],
412
+ "source": [
413
+ "trainer = SFTTrainer(\n",
414
+ " model=model,\n",
415
+ " train_dataset=dataset,\n",
416
+ " peft_config=peft_config,\n",
417
+ " dataset_text_field=\"text\",\n",
418
+ " # formatting_func=format_prompts_fn,\n",
419
+ " max_seq_length=hyperparams['max_seq_length'],\n",
420
+ " tokenizer=tokenizer,\n",
421
+ " args=training_arguments,\n",
422
+ " packing=hyperparams['packing'],\n",
423
+ ")"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "execution_count": null,
429
+ "metadata": {},
430
+ "outputs": [
431
+ {
432
+ "data": {
433
+ "application/vnd.jupyter.widget-view+json": {
434
+ "model_id": "0033f5bb31a7416facfd8a3fd3bd5ad1",
435
+ "version_major": 2,
436
+ "version_minor": 0
437
+ },
438
+ "text/plain": [
439
+ " 0%| | 0/1340 [00:00<?, ?it/s]"
440
+ ]
441
+ },
442
+ "metadata": {},
443
+ "output_type": "display_data"
444
+ },
445
+ {
446
+ "name": "stdout",
447
+ "output_type": "stream",
448
+ "text": [
449
+ "{'loss': 3.8879, 'grad_norm': 18.030195236206055, 'learning_rate': 0.0002, 'epoch': 0.02}\n",
450
+ "{'loss': 2.9569, 'grad_norm': 9.667036056518555, 'learning_rate': 0.0002, 'epoch': 0.04}\n",
451
+ "{'loss': 2.6361, 'grad_norm': 9.089476585388184, 'learning_rate': 0.0002, 'epoch': 0.06}\n",
452
+ "{'loss': 2.9523, 'grad_norm': 6.053662300109863, 'learning_rate': 0.0002, 'epoch': 0.07}\n",
453
+ "{'loss': 2.8543, 'grad_norm': 7.764152526855469, 'learning_rate': 0.0002, 'epoch': 0.09}\n",
454
+ "{'loss': 2.8802, 'grad_norm': 6.539248466491699, 'learning_rate': 0.0002, 'epoch': 0.11}\n",
455
+ "{'loss': 2.7047, 'grad_norm': 5.485109329223633, 'learning_rate': 0.0002, 'epoch': 0.13}\n",
456
+ "{'loss': 2.6576, 'grad_norm': 9.22624397277832, 'learning_rate': 0.0002, 'epoch': 0.15}\n",
457
+ "{'loss': 2.7756, 'grad_norm': 6.477100372314453, 'learning_rate': 0.0002, 'epoch': 0.17}\n",
458
+ "{'loss': 2.7012, 'grad_norm': 5.891603946685791, 'learning_rate': 0.0002, 'epoch': 0.19}\n",
459
+ "{'loss': 2.5026, 'grad_norm': 5.75968599319458, 'learning_rate': 0.0002, 'epoch': 0.21}\n",
460
+ "{'loss': 2.8085, 'grad_norm': 7.938610076904297, 'learning_rate': 0.0002, 'epoch': 0.22}\n",
461
+ "{'loss': 2.5286, 'grad_norm': 5.600504398345947, 'learning_rate': 0.0002, 'epoch': 0.24}\n",
462
+ "{'loss': 2.5495, 'grad_norm': 6.746212005615234, 'learning_rate': 0.0002, 'epoch': 0.26}\n",
463
+ "{'loss': 2.7405, 'grad_norm': 3.8923749923706055, 'learning_rate': 0.0002, 'epoch': 0.28}\n",
464
+ "{'loss': 2.5657, 'grad_norm': 5.949460506439209, 'learning_rate': 0.0002, 'epoch': 0.3}\n",
465
+ "{'loss': 2.6052, 'grad_norm': 5.733223915100098, 'learning_rate': 0.0002, 'epoch': 0.32}\n",
466
+ "{'loss': 2.673, 'grad_norm': 6.0587310791015625, 'learning_rate': 0.0002, 'epoch': 0.34}\n",
467
+ "{'loss': 2.4631, 'grad_norm': 4.734077453613281, 'learning_rate': 0.0002, 'epoch': 0.35}\n",
468
+ "{'loss': 2.7288, 'grad_norm': 6.7847700119018555, 'learning_rate': 0.0002, 'epoch': 0.37}\n",
469
+ "{'loss': 2.7797, 'grad_norm': 5.118943214416504, 'learning_rate': 0.0002, 'epoch': 0.39}\n",
470
+ "{'loss': 2.8644, 'grad_norm': 5.4167304039001465, 'learning_rate': 0.0002, 'epoch': 0.41}\n",
471
+ "{'loss': 2.5779, 'grad_norm': 6.73247766494751, 'learning_rate': 0.0002, 'epoch': 0.43}\n",
472
+ "{'loss': 2.6459, 'grad_norm': 4.644010066986084, 'learning_rate': 0.0002, 'epoch': 0.45}\n",
473
+ "{'loss': 2.5321, 'grad_norm': 6.347738265991211, 'learning_rate': 0.0002, 'epoch': 0.47}\n",
474
+ "{'loss': 2.6865, 'grad_norm': 5.185911655426025, 'learning_rate': 0.0002, 'epoch': 0.49}\n",
475
+ "{'loss': 2.4668, 'grad_norm': 5.355742454528809, 'learning_rate': 0.0002, 'epoch': 0.5}\n",
476
+ "{'loss': 2.8465, 'grad_norm': 5.4434380531311035, 'learning_rate': 0.0002, 'epoch': 0.52}\n",
477
+ "{'loss': 2.7376, 'grad_norm': 4.8459882736206055, 'learning_rate': 0.0002, 'epoch': 0.54}\n",
478
+ "{'loss': 2.5205, 'grad_norm': 5.886116981506348, 'learning_rate': 0.0002, 'epoch': 0.56}\n",
479
+ "{'loss': 2.7473, 'grad_norm': 4.946981906890869, 'learning_rate': 0.0002, 'epoch': 0.58}\n",
480
+ "{'loss': 2.6824, 'grad_norm': 6.349016189575195, 'learning_rate': 0.0002, 'epoch': 0.6}\n",
481
+ "{'loss': 2.6485, 'grad_norm': 5.024953365325928, 'learning_rate': 0.0002, 'epoch': 0.62}\n",
482
+ "{'loss': 2.7172, 'grad_norm': 5.583380222320557, 'learning_rate': 0.0002, 'epoch': 0.63}\n",
483
+ "{'loss': 2.5879, 'grad_norm': 6.582890033721924, 'learning_rate': 0.0002, 'epoch': 0.65}\n"
484
+ ]
485
+ }
486
+ ],
487
+ "source": [
488
+ "trainer.train()\n",
489
+ "trainer.model.save_pretrained(hyperparams['new_model_name'])"
490
+ ]
491
+ }
492
+ ],
493
+ "metadata": {
494
+ "kernelspec": {
495
+ "display_name": ".venv",
496
+ "language": "python",
497
+ "name": "python3"
498
+ },
499
+ "language_info": {
500
+ "codemirror_mode": {
501
+ "name": "ipython",
502
+ "version": 3
503
+ },
504
+ "file_extension": ".py",
505
+ "mimetype": "text/x-python",
506
+ "name": "python",
507
+ "nbconvert_exporter": "python",
508
+ "pygments_lexer": "ipython3",
509
+ "version": "3.12.7"
510
+ }
511
+ },
512
+ "nbformat": 4,
513
+ "nbformat_minor": 2
514
+ }
Gemma2_2B/finetune.py DELETED
File without changes
Gemma2_2B/hyperparams.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_name: "google/gemma-2-2b-it"
2
+ new_model_name: "gemma-2-2b-ft"
3
+
4
+ lora_r: 4
5
+ lora_alpha: 16
6
+ lora_dropout: 0.1
7
+
8
+ use_4bit: True
9
+ bnb_4bit_compute_dtype: "float16"
10
+ bnb_4bit_quant_type: "nf4"
11
+ use_nested_quant: False
12
+
13
+ output_dir: "./results"
14
+ num_train_epochs: 1
15
+ fp16: False
16
+ bf16: False
17
+ per_device_train_batch_size: 2
18
+ per_device_eval_batch_size: 2
19
+ gradient_accumulation_steps: 1
20
+ gradient_checkpointing: True
21
+ max_grad_norm: 0.3
22
+ learning_rate: 2e-4
23
+ weight_decay: 0.001
24
+ optimizer: "paged_adamw_32bit"
25
+ lr_scheduler_type: "constant"
26
+ max_steps: -1
27
+ warmup_ratio: 0.03
28
+ group_by_length: True
29
+ save_steps: 25
30
+ logging_steps: 25
31
+
32
+ max_seq_length: 40
33
+ packing: True
34
+ device_map: "auto"
Gemma2_2B/inference.ipynb ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from huggingface_hub import login\n",
10
+ "from dotenv import load_dotenv\n",
11
+ "import os\n",
12
+ "load_dotenv()\n",
13
+ "\n",
14
+ "# Login to Hugging Face Hub\n",
15
+ "login(token=os.getenv(\"HUGGINGFACE_TOKEN\"))"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": null,
21
+ "metadata": {},
22
+ "outputs": [
23
+ {
24
+ "data": {
25
+ "application/vnd.jupyter.widget-view+json": {
26
+ "model_id": "d00ec085003e409d906784abc1f89dc1",
27
+ "version_major": 2,
28
+ "version_minor": 0
29
+ },
30
+ "text/plain": [
31
+ "config.json: 0%| | 0.00/838 [00:00<?, ?B/s]"
32
+ ]
33
+ },
34
+ "metadata": {},
35
+ "output_type": "display_data"
36
+ },
37
+ {
38
+ "name": "stderr",
39
+ "output_type": "stream",
40
+ "text": [
41
+ "f:\\TADBot\\.venv\\Lib\\site-packages\\huggingface_hub\\file_download.py:139: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in F:\\TADBot\\Gemma2_2B\\.cache\\models--google--gemma-2-2b-it. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
42
+ "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
43
+ " warnings.warn(message)\n"
44
+ ]
45
+ },
46
+ {
47
+ "data": {
48
+ "application/vnd.jupyter.widget-view+json": {
49
+ "model_id": "bdee67c51d7547a48e45f17db7fb3734",
50
+ "version_major": 2,
51
+ "version_minor": 0
52
+ },
53
+ "text/plain": [
54
+ "model.safetensors.index.json: 0%| | 0.00/24.2k [00:00<?, ?B/s]"
55
+ ]
56
+ },
57
+ "metadata": {},
58
+ "output_type": "display_data"
59
+ },
60
+ {
61
+ "data": {
62
+ "application/vnd.jupyter.widget-view+json": {
63
+ "model_id": "ad86eff32cc1447486e69c5f5f90e4a4",
64
+ "version_major": 2,
65
+ "version_minor": 0
66
+ },
67
+ "text/plain": [
68
+ "Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
69
+ ]
70
+ },
71
+ "metadata": {},
72
+ "output_type": "display_data"
73
+ },
74
+ {
75
+ "data": {
76
+ "application/vnd.jupyter.widget-view+json": {
77
+ "model_id": "78cab016a2d54731a94ef45e85d65ddd",
78
+ "version_major": 2,
79
+ "version_minor": 0
80
+ },
81
+ "text/plain": [
82
+ "model-00001-of-00002.safetensors: 0%| | 0.00/4.99G [00:00<?, ?B/s]"
83
+ ]
84
+ },
85
+ "metadata": {},
86
+ "output_type": "display_data"
87
+ },
88
+ {
89
+ "data": {
90
+ "application/vnd.jupyter.widget-view+json": {
91
+ "model_id": "52b50ff81d0d481ab475878606935162",
92
+ "version_major": 2,
93
+ "version_minor": 0
94
+ },
95
+ "text/plain": [
96
+ "model-00002-of-00002.safetensors: 0%| | 0.00/241M [00:00<?, ?B/s]"
97
+ ]
98
+ },
99
+ "metadata": {},
100
+ "output_type": "display_data"
101
+ },
102
+ {
103
+ "data": {
104
+ "application/vnd.jupyter.widget-view+json": {
105
+ "model_id": "7dfed61b7e0a4338aee7ad14df4d85ca",
106
+ "version_major": 2,
107
+ "version_minor": 0
108
+ },
109
+ "text/plain": [
110
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
111
+ ]
112
+ },
113
+ "metadata": {},
114
+ "output_type": "display_data"
115
+ },
116
+ {
117
+ "data": {
118
+ "application/vnd.jupyter.widget-view+json": {
119
+ "model_id": "9ac1e6a0b72a44d3a8a648bce2138c3d",
120
+ "version_major": 2,
121
+ "version_minor": 0
122
+ },
123
+ "text/plain": [
124
+ "generation_config.json: 0%| | 0.00/187 [00:00<?, ?B/s]"
125
+ ]
126
+ },
127
+ "metadata": {},
128
+ "output_type": "display_data"
129
+ },
130
+ {
131
+ "data": {
132
+ "application/vnd.jupyter.widget-view+json": {
133
+ "model_id": "f0129c204a454f22968aebe59b75ea1a",
134
+ "version_major": 2,
135
+ "version_minor": 0
136
+ },
137
+ "text/plain": [
138
+ "tokenizer_config.json: 0%| | 0.00/47.0k [00:00<?, ?B/s]"
139
+ ]
140
+ },
141
+ "metadata": {},
142
+ "output_type": "display_data"
143
+ },
144
+ {
145
+ "data": {
146
+ "application/vnd.jupyter.widget-view+json": {
147
+ "model_id": "ca55b303b11347cbbf5970327d2d8a82",
148
+ "version_major": 2,
149
+ "version_minor": 0
150
+ },
151
+ "text/plain": [
152
+ "tokenizer.model: 0%| | 0.00/4.24M [00:00<?, ?B/s]"
153
+ ]
154
+ },
155
+ "metadata": {},
156
+ "output_type": "display_data"
157
+ },
158
+ {
159
+ "data": {
160
+ "application/vnd.jupyter.widget-view+json": {
161
+ "model_id": "33601521ca8544e7a98c88506257dd20",
162
+ "version_major": 2,
163
+ "version_minor": 0
164
+ },
165
+ "text/plain": [
166
+ "tokenizer.json: 0%| | 0.00/17.5M [00:00<?, ?B/s]"
167
+ ]
168
+ },
169
+ "metadata": {},
170
+ "output_type": "display_data"
171
+ },
172
+ {
173
+ "data": {
174
+ "application/vnd.jupyter.widget-view+json": {
175
+ "model_id": "f353232bbf6b4da3ac62e02fa7f58990",
176
+ "version_major": 2,
177
+ "version_minor": 0
178
+ },
179
+ "text/plain": [
180
+ "special_tokens_map.json: 0%| | 0.00/636 [00:00<?, ?B/s]"
181
+ ]
182
+ },
183
+ "metadata": {},
184
+ "output_type": "display_data"
185
+ }
186
+ ],
187
+ "source": [
188
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
189
+ "model_name = \"google/gemma-2-2b-it\"\n",
190
+ "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\", cache_dir=\".cache/\")\n",
191
+ "tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=\".cache/\")"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": 6,
197
+ "metadata": {},
198
+ "outputs": [
199
+ {
200
+ "name": "stdout",
201
+ "output_type": "stream",
202
+ "text": [
203
+ "Gemma2ForCausalLM(\n",
204
+ " (model): Gemma2Model(\n",
205
+ " (embed_tokens): Embedding(256000, 2304, padding_idx=0)\n",
206
+ " (layers): ModuleList(\n",
207
+ " (0-25): 26 x Gemma2DecoderLayer(\n",
208
+ " (self_attn): Gemma2Attention(\n",
209
+ " (q_proj): Linear(in_features=2304, out_features=2048, bias=False)\n",
210
+ " (k_proj): Linear(in_features=2304, out_features=1024, bias=False)\n",
211
+ " (v_proj): Linear(in_features=2304, out_features=1024, bias=False)\n",
212
+ " (o_proj): Linear(in_features=2048, out_features=2304, bias=False)\n",
213
+ " (rotary_emb): Gemma2RotaryEmbedding()\n",
214
+ " )\n",
215
+ " (mlp): Gemma2MLP(\n",
216
+ " (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)\n",
217
+ " (up_proj): Linear(in_features=2304, out_features=9216, bias=False)\n",
218
+ " (down_proj): Linear(in_features=9216, out_features=2304, bias=False)\n",
219
+ " (act_fn): PytorchGELUTanh()\n",
220
+ " )\n",
221
+ " (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
222
+ " (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
223
+ " (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
224
+ " (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
225
+ " )\n",
226
+ " )\n",
227
+ " (norm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
228
+ " )\n",
229
+ " (lm_head): Linear(in_features=2304, out_features=256000, bias=False)\n",
230
+ ")\n"
231
+ ]
232
+ }
233
+ ],
234
+ "source": [
235
+ "print(model)"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 9,
241
+ "metadata": {},
242
+ "outputs": [
243
+ {
244
+ "name": "stdout",
245
+ "output_type": "stream",
246
+ "text": [
247
+ "<bos>What should I do on a trip to Europe?\n",
248
+ "\n",
249
+ "That's a great question! To give you the best advice, I need a little more information. Tell me about:\n",
250
+ "\n",
251
+ "**1. Your Interests:** \n",
252
+ " * What kind of things do you enjoy doing? (History, art, food, nightlife, nature, adventure, relaxation, etc.)\n",
253
+ " * Are there any specific places or activities you've always wanted to experience?\n",
254
+ "\n",
255
+ "**2. Your Travel Style:**\n",
256
+ " * Do you prefer to travel on your own, with a partner, or with a group?\n",
257
+ " * Do you like to plan everything in advance or be more spontaneous?\n",
258
+ " * What's your budget like?\n",
259
+ "\n",
260
+ "**3. Your Trip Details:**\n",
261
+ " * How long will you be traveling for?\n",
262
+ " * What time of year are you planning to go?\n",
263
+ " * Do you have any specific destinations in mind?\n",
264
+ "\n",
265
+ "Once I have this information, I can give you personalized recommendations for your European adventure! \n",
266
+ "<end_of_turn>\n",
267
+ "CPU times: total: 7.23 s\n",
268
+ "Wall time: 7.56 s\n"
269
+ ]
270
+ }
271
+ ],
272
+ "source": [
273
+ "%%time\n",
274
+ "input_text = \"What should I do on a trip to Europe?\"\n",
275
+ "\n",
276
+ "input_ids = tokenizer(input_text, return_tensors=\"pt\").to(\"cuda\")\n",
277
+ "outputs = model.generate(**input_ids, max_length=2048)\n",
278
+ "print(tokenizer.decode(outputs[0]))"
279
+ ]
280
+ }
281
+ ],
282
+ "metadata": {
283
+ "kernelspec": {
284
+ "display_name": ".venv",
285
+ "language": "python",
286
+ "name": "python3"
287
+ },
288
+ "language_info": {
289
+ "codemirror_mode": {
290
+ "name": "ipython",
291
+ "version": 3
292
+ },
293
+ "file_extension": ".py",
294
+ "mimetype": "text/x-python",
295
+ "name": "python",
296
+ "nbconvert_exporter": "python",
297
+ "pygments_lexer": "ipython3",
298
+ "version": "3.12.7"
299
+ }
300
+ },
301
+ "nbformat": 4,
302
+ "nbformat_minor": 2
303
+ }
Gemma2_2B/inference.py DELETED
File without changes
pyproject.toml CHANGED
@@ -25,6 +25,7 @@ dependencies = [
25
  "python-dotenv>=1.0.1",
26
  "ipykernel>=6.29.5",
27
  "ipywidgets>=8.1.5",
 
28
  ]
29
 
30
  [tool.uv.sources]
 
25
  "python-dotenv>=1.0.1",
26
  "ipykernel>=6.29.5",
27
  "ipywidgets>=8.1.5",
28
+ "pyyaml>=6.0.2",
29
  ]
30
 
31
  [tool.uv.sources]
uv.lock CHANGED
@@ -800,6 +800,7 @@ dependencies = [
800
  { name = "numpy", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
801
  { name = "peft", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
802
  { name = "python-dotenv", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
 
803
  { name = "ruff", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
804
  { name = "tensorboard", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
805
  { name = "thop", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
@@ -836,6 +837,7 @@ requires-dist = [
836
  { name = "numpy", specifier = ">=1.26.4" },
837
  { name = "peft", specifier = ">=0.13.2" },
838
  { name = "python-dotenv", specifier = ">=1.0.1" },
 
839
  { name = "ruff", specifier = ">=0.7.3" },
840
  { name = "tensorboard", specifier = ">=2.18.0" },
841
  { name = "thop", specifier = ">=0.1.1.post2209072238" },
 
800
  { name = "numpy", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
801
  { name = "peft", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
802
  { name = "python-dotenv", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
803
+ { name = "pyyaml", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
804
  { name = "ruff", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
805
  { name = "tensorboard", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
806
  { name = "thop", marker = "(platform_machine != 'aarch64' and python_full_version >= '3.12') or (platform_system != 'Linux' and python_full_version >= '3.12') or platform_system == 'Darwin' or (platform_machine == 'aarch64' and platform_system == 'Linux')" },
 
837
  { name = "numpy", specifier = ">=1.26.4" },
838
  { name = "peft", specifier = ">=0.13.2" },
839
  { name = "python-dotenv", specifier = ">=1.0.1" },
840
+ { name = "pyyaml", specifier = ">=6.0.2" },
841
  { name = "ruff", specifier = ">=0.7.3" },
842
  { name = "tensorboard", specifier = ">=2.18.0" },
843
  { name = "thop", specifier = ">=0.1.1.post2209072238" },