Matthew Hollings committed on
Commit 9de53c6 · 1 Parent(s): 5497d17

Fine-tune a GPT model and load it into the interface.

.gitignore CHANGED
@@ -1,3 +1,5 @@
 __pycache__
 flagged/
-gutenberg-dammit-files-v002.zip
+gutenberg-dammit-files-v002.zip
+tmp_trainer
+*.gz
README.md CHANGED
@@ -10,23 +10,29 @@ pinned: false
 ---
 
 - 1. fine-tune a large language model (LLM) using the text corpus of a specific poet
-- 1.1 if possible the entire poem should be used for generating the next line not
-just the last line
-- 2. build a web interface for a user to prompt and then respond
-- 2.1 the poem should persist on machine reload
-- 2.2 it should be possible to remove the last line and rerun
-- 2.3 retry to get a new response from the model
+
+- select a certain rhyme from the gutenberg corpus and fine-tune on this
+- try fine-tuning on a few lines of a poem that Eva has started
 
 run in a docker container and transfer to another machine
 
+Is it better to have a sequence-to-sequence transformer trained on successive lines of the poetry corpus?
+
+merve/poetry only has 573 rows.
+
+TODO: upload the Gutenberg poetry corpus to Hugging Face; ask its author first.
+
 ## Research
 
 <https://github.com/aparrish/gutenberg-dammit/>
-TODO:
-automatically activate conda env on cd in directory
 implement language generation with a basic transformer
-get the website running to display responses in a user-friendly way
-Docker image?
 
 <https://github.com/aparrish/gutenberg-poetry-corpus>
 Gutenberg Poetry Autocomplete, a search engine-like interface for writing poems mined from Project Gutenberg. (A poem written using this interface was recently published in the Indianapolis Review!)
+
+https://ymeadows.com/en-articles/fine-tuning-transformer-based-language-models
+https://thegradient.pub/prompting/
+https://towardsdatascience.com/fine-tuning-for-domain-adaptation-in-nlp-c47def356fd6
+https://ruder.io/recent-advances-lm-fine-tuning/
+
+https://streamlit.io/
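A quick sketch related to the README notes above (not yet part of the repo): check how small merve/poetry really is, and build naive successive-line pairs from the Gutenberg poetry corpus to probe the sequence-to-sequence question. It assumes gutenberg-poetry-v001.ndjson.gz has already been downloaded as in fine-tune-llm.ipynb.

```python
# Sketch only: dataset size plus naive (previous line -> next line) pairs.
import gzip
import json

from datasets import load_dataset

poetry = load_dataset("merve/poetry")
print(poetry["train"].num_rows)  # 573 rows, probably too few on its own

# Successive lines from the Gutenberg poetry corpus; a real version should only
# pair lines that share the same 'gid' so pairs never cross book boundaries.
lines = [json.loads(l)["s"] for l in gzip.open("gutenberg-poetry-v001.ndjson.gz")]
pairs = list(zip(lines, lines[1:]))
print(pairs[0])
```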
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
 from transformers import pipeline
 
 # Set up the generative model transformer pipeline
-generator = pipeline("text-generation", model="gpt2")
+generator = pipeline("text-generation", model="tmp_trainer")
 
 # A sequence of lines, both those typed in and those generated so far;
 # when save is clicked the txt file is downloaded
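Since app.py now points the pipeline at the locally saved tmp_trainer model, something like the sketch below could prompt with the whole poem so far (per point 1.1 in the README) rather than only the last line. The respond() helper and its generation settings are illustrative, not what app.py currently does.

```python
from transformers import pipeline

generator = pipeline("text-generation", model="tmp_trainer")

def respond(poem_so_far: str) -> str:
    # Prompt with the entire poem so far and keep only the newly generated text.
    out = generator(poem_so_far, max_new_tokens=30, do_sample=True, num_return_sequences=1)
    continuation = out[0]["generated_text"][len(poem_so_far):]
    # Use the first non-empty generated line as the next line of the poem.
    new_lines = [l for l in continuation.splitlines() if l.strip()]
    return new_lines[0].strip() if new_lines else continuation.strip()
```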
fine-tune-llm.ipynb CHANGED
@@ -379,30 +379,32 @@
379
  },
380
  {
381
  "cell_type": "code",
382
- "execution_count": 17,
383
  "metadata": {},
384
  "outputs": [
385
  {
386
  "data": {
387
  "text/plain": [
388
- "{'Author': ['Jules Verne'],\n",
389
- " 'Author Birth': [1828],\n",
390
- " 'Author Death': [1905],\n",
391
- " 'Author Given': ['Jules'],\n",
392
- " 'Author Surname': ['Verne'],\n",
393
  " 'Copyright Status': ['Not copyrighted in the United States.'],\n",
394
  " 'Language': ['English'],\n",
395
- " 'LoC Class': ['PQ: Language and Literatures: Romance literatures: French, Italian, Spanish, Portuguese'],\n",
396
- " 'Num': '103',\n",
397
- " 'Subject': ['Adventure stories', 'Voyages around the world -- Fiction'],\n",
398
- " 'Title': ['Around the World in 80 Days'],\n",
399
  " 'charset': 'us-ascii',\n",
400
- " 'gd-num-padded': '00103',\n",
401
- " 'gd-path': '001/00103.txt',\n",
402
- " 'href': '/1/0/103/103.zip'}"
403
  ]
404
  },
405
- "execution_count": 17,
406
  "metadata": {},
407
  "output_type": "execute_result"
408
  }
@@ -410,7 +412,7 @@
410
  "source": [
411
  "from gutenbergdammit.ziputils import loadmetadata\n",
412
  "metadata = loadmetadata(\"gutenberg-dammit-files-v002.zip\")\n",
413
- "metadata[100]\n",
414
  "# ['Essays in the Art of Writing']"
415
  ]
416
  },
@@ -557,6 +559,118 @@
557
  "tf.config.list_physical_devices('CPU')"
558
  ]
559
  },
560
  {
561
  "cell_type": "code",
562
  "execution_count": null,
 
379
  },
380
  {
381
  "cell_type": "code",
382
+ "execution_count": 23,
383
  "metadata": {},
384
  "outputs": [
385
  {
386
  "data": {
387
  "text/plain": [
388
+ "{'Author': ['Franklin Delano Roosevelt'],\n",
389
+ " 'Author Birth': [1882],\n",
390
+ " 'Author Death': [1945],\n",
391
+ " 'Author Given': ['Franklin Delano'],\n",
392
+ " 'Author Surname': ['Roosevelt'],\n",
393
  " 'Copyright Status': ['Not copyrighted in the United States.'],\n",
394
  " 'Language': ['English'],\n",
395
+ " 'LoC Class': ['E740: History: America: Twentieth century'],\n",
396
+ " 'Num': '104',\n",
397
+ " 'Subject': ['New Deal, 1933-1939',\n",
398
+ " 'Presidents -- United States -- Inaugural addresses',\n",
399
+ " 'United States -- Politics and government -- 1933-1945'],\n",
400
+ " 'Title': [\"Franklin Delano Roosevelt's First Inaugural Address\"],\n",
401
  " 'charset': 'us-ascii',\n",
402
+ " 'gd-num-padded': '00104',\n",
403
+ " 'gd-path': '001/00104.txt',\n",
404
+ " 'href': '/1/0/104/104.zip'}"
405
  ]
406
  },
407
+ "execution_count": 23,
408
  "metadata": {},
409
  "output_type": "execute_result"
410
  }
 
412
  "source": [
413
  "from gutenbergdammit.ziputils import loadmetadata\n",
414
  "metadata = loadmetadata(\"gutenberg-dammit-files-v002.zip\")\n",
415
+ "metadata[101]\n",
416
  "# ['Essays in the Art of Writing']"
417
  ]
418
  },
 
559
  "tf.config.list_physical_devices('CPU')"
560
  ]
561
  },
562
+ {
563
+ "cell_type": "markdown",
564
+ "metadata": {},
565
+ "source": [
566
+ "# Source data"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "markdown",
571
+ "metadata": {},
572
+ "source": [
573
+ "curl -O http://static.decontextualize.com/gutenberg-poetry-v001.ndjson.gz"
574
+ ]
575
+ },
576
+ {
577
+ "cell_type": "code",
578
+ "execution_count": 25,
579
+ "metadata": {},
580
+ "outputs": [],
581
+ "source": [
582
+ "import gzip, json\n",
583
+ "all_lines = []\n",
584
+ "for line in gzip.open(\"gutenberg-poetry-v001.ndjson.gz\"):\n",
585
+ " all_lines.append(json.loads(line.strip()))"
586
+ ]
587
+ },
588
+ {
589
+ "cell_type": "code",
590
+ "execution_count": 34,
591
+ "metadata": {},
592
+ "outputs": [
593
+ {
594
+ "name": "stdout",
595
+ "output_type": "stream",
596
+ "text": [
597
+ "[{'s': 'The Song of Hiawatha is based on the legends and stories of', 'gid': '19'}, {'s': 'many North American Indian tribes, but especially those of the', 'gid': '19'}, {'s': 'Ojibway Indians of northern Michigan, Wisconsin, and Minnesota.', 'gid': '19'}, {'s': 'They were collected by Henry Rowe Schoolcraft, the reknowned', 'gid': '19'}, {'s': 'Schoolcraft married Jane, O-bah-bahm-wawa-ge-zhe-go-qua (The', 'gid': '19'}, {'s': 'fur trader, and O-shau-gus-coday-way-qua (The Woman of the Green', 'gid': '19'}, {'s': 'Prairie), who was a daughter of Waub-o-jeeg (The White Fisher),', 'gid': '19'}, {'s': 'who was Chief of the Ojibway tribe at La Pointe, Wisconsin.', 'gid': '19'}, {'s': 'Jane and her mother are credited with having researched,', 'gid': '19'}, {'s': 'authenticated, and compiled much of the material Schoolcraft', 'gid': '19'}]\n"
598
+ ]
599
+ }
600
+ ],
601
+ "source": [
602
+ "import random\n",
603
+ "random.sample(all_lines, 8)\n",
604
+ "\n",
605
+ "print(all_lines[0:10])\n",
606
+ "\n"
607
+ ]
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "execution_count": 33,
612
+ "metadata": {},
613
+ "outputs": [
614
+ {
615
+ "data": {
616
+ "text/plain": [
617
+ "{'Author': ['Henry Rider Haggard'],\n",
618
+ " 'Author Birth': [1856],\n",
619
+ " 'Author Death': [1925],\n",
620
+ " 'Author Given': ['Henry Rider'],\n",
621
+ " 'Author Surname': ['Haggard'],\n",
622
+ " 'Copyright Status': ['Not copyrighted in the United States.'],\n",
623
+ " 'Language': ['English'],\n",
624
+ " 'LoC Class': ['PR: Language and Literatures: English literature'],\n",
625
+ " 'Num': '2721',\n",
626
+ " 'Subject': ['Iceland -- Fiction'],\n",
627
+ " 'Title': ['Eric Brighteyes'],\n",
628
+ " 'charset': 'iso-8859-1',\n",
629
+ " 'gd-num-padded': '02721',\n",
630
+ " 'gd-path': '027/02721.txt',\n",
631
+ " 'href': '/2/7/2/2721/2721_8.zip'}"
632
+ ]
633
+ },
634
+ "execution_count": 33,
635
+ "metadata": {},
636
+ "output_type": "execute_result"
637
+ }
638
+ ],
639
+ "source": [
640
+ "from gutenbergdammit.ziputils import loadmetadata\n",
641
+ "metadata = loadmetadata(\"gutenberg-dammit-files-v002.zip\")\n",
642
+ "metadata[2620]"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "code",
647
+ "execution_count": 37,
648
+ "metadata": {},
649
+ "outputs": [
650
+ {
651
+ "data": {
652
+ "text/plain": [
653
+ "['The Song of Hiawatha is based on the legends and stories of',\n",
654
+ " 'many North American Indian tribes, but especially those of the',\n",
655
+ " 'Ojibway Indians of northern Michigan, Wisconsin, and Minnesota.',\n",
656
+ " 'They were collected by Henry Rowe Schoolcraft, the reknowned',\n",
657
+ " 'Schoolcraft married Jane, O-bah-bahm-wawa-ge-zhe-go-qua (The',\n",
658
+ " 'fur trader, and O-shau-gus-coday-way-qua (The Woman of the Green',\n",
659
+ " 'Prairie), who was a daughter of Waub-o-jeeg (The White Fisher),',\n",
660
+ " 'who was Chief of the Ojibway tribe at La Pointe, Wisconsin.',\n",
661
+ " 'Jane and her mother are credited with having researched,',\n",
662
+ " 'authenticated, and compiled much of the material Schoolcraft']"
663
+ ]
664
+ },
665
+ "execution_count": 37,
666
+ "metadata": {},
667
+ "output_type": "execute_result"
668
+ }
669
+ ],
670
+ "source": [
671
+ "[line['s'] for line in all_lines[0:10]]"
672
+ ]
673
+ },
674
  {
675
  "cell_type": "code",
676
  "execution_count": null,
fine-tuning-for-casual-language-model.ipynb ADDED
@@ -0,0 +1,603 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "# https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 43,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import transformers\n",
19
+ "from transformers import (\n",
20
+ " CONFIG_MAPPING,\n",
21
+ " MODEL_FOR_CAUSAL_LM_MAPPING,\n",
22
+ " AutoConfig,\n",
23
+ " AutoModelForCausalLM,\n",
24
+ " AutoTokenizer,\n",
25
+ " HfArgumentParser,\n",
26
+ " Trainer,\n",
27
+ " TrainingArguments,\n",
28
+ " default_data_collator,\n",
29
+ " is_torch_tpu_available,\n",
30
+ " set_seed,\n",
31
+ ")\n",
32
+ "\n",
33
+ "from itertools import chain\n",
34
+ "\n",
35
+ "from transformers.testing_utils import CaptureLogger\n",
36
+ "from transformers.trainer_utils import get_last_checkpoint\n",
37
+ "# from transformers.utils import check_min_version, send_example_telemetry\n",
38
+ "from transformers.utils.versions import require_version\n",
39
+ "\n",
40
+ "import datasets\n",
41
+ "from datasets import load_dataset"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 4,
47
+ "metadata": {},
48
+ "outputs": [
49
+ {
50
+ "ename": "ImportError",
51
+ "evalue": "This example requires a source install from HuggingFace Transformers (see `https://huggingface.co/transformers/installation.html#installing-from-source`), but the version found is 4.11.3.\nCheck out https://huggingface.co/transformers/examples.html for the examples corresponding to other versions of HuggingFace Transformers.",
52
+ "output_type": "error",
53
+ "traceback": [
54
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
55
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
56
+ "Cell \u001b[0;32mIn [4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mcheck_min_version\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m4.23.0.dev0\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
57
+ "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/augmented_poetry/lib/python3.8/site-packages/transformers/utils/__init__.py:32\u001b[0m, in \u001b[0;36mcheck_min_version\u001b[0;34m(min_version)\u001b[0m\n\u001b[1;32m 30\u001b[0m error_message \u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mThis example requires a minimum version of \u001b[39m\u001b[39m{\u001b[39;00mmin_version\u001b[39m}\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 31\u001b[0m error_message \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m but the version found is \u001b[39m\u001b[39m{\u001b[39;00m__version__\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[0;32m---> 32\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mImportError\u001b[39;00m(\n\u001b[1;32m 33\u001b[0m error_message\n\u001b[1;32m 34\u001b[0m \u001b[39m+\u001b[39m (\n\u001b[1;32m 35\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mCheck out https://huggingface.co/transformers/examples.html for the examples corresponding to other \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 36\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mversions of HuggingFace Transformers.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 37\u001b[0m )\n\u001b[1;32m 38\u001b[0m )\n",
58
+ "\u001b[0;31mImportError\u001b[0m: This example requires a source install from HuggingFace Transformers (see `https://huggingface.co/transformers/installation.html#installing-from-source`), but the version found is 4.11.3.\nCheck out https://huggingface.co/transformers/examples.html for the examples corresponding to other versions of HuggingFace Transformers."
59
+ ]
60
+ }
61
+ ],
62
+ "source": [
63
+ "# check_min_version(\"4.23.0.dev0\")"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 9,
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "require_version(\"datasets>=1.8.0\")"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 5,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "set_seed(37)"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "markdown",
86
+ "metadata": {},
87
+ "source": [
88
+ "##### Get all of the huggingface objects that we need: tokenzier, gpt2 model, poetry dataset."
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 10,
94
+ "metadata": {},
95
+ "outputs": [
96
+ {
97
+ "name": "stderr",
98
+ "output_type": "stream",
99
+ "text": [
100
+ "Using custom data configuration merve--poetry-ca9a13ef5858cc3a\n"
101
+ ]
102
+ },
103
+ {
104
+ "name": "stdout",
105
+ "output_type": "stream",
106
+ "text": [
107
+ "Downloading and preparing dataset csv/merve--poetry to /Users/matth/.cache/huggingface/datasets/merve___csv/merve--poetry-ca9a13ef5858cc3a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a...\n"
108
+ ]
109
+ },
110
+ {
111
+ "data": {
112
+ "application/vnd.jupyter.widget-view+json": {
113
+ "model_id": "ed56ee6b324647798b19ac7bf5accc40",
114
+ "version_major": 2,
115
+ "version_minor": 0
116
+ },
117
+ "text/plain": [
118
+ "Downloading data files: 0%| | 0/1 [00:00<?, ?it/s]"
119
+ ]
120
+ },
121
+ "metadata": {},
122
+ "output_type": "display_data"
123
+ },
124
+ {
125
+ "data": {
126
+ "application/vnd.jupyter.widget-view+json": {
127
+ "model_id": "32c10441ff20404cb153f6b27f16a829",
128
+ "version_major": 2,
129
+ "version_minor": 0
130
+ },
131
+ "text/plain": [
132
+ "Downloading data: 0%| | 0.00/606k [00:00<?, ?B/s]"
133
+ ]
134
+ },
135
+ "metadata": {},
136
+ "output_type": "display_data"
137
+ },
138
+ {
139
+ "data": {
140
+ "application/vnd.jupyter.widget-view+json": {
141
+ "model_id": "7ca47bc06937463e91d3948d7703ac64",
142
+ "version_major": 2,
143
+ "version_minor": 0
144
+ },
145
+ "text/plain": [
146
+ "Extracting data files: 0%| | 0/1 [00:00<?, ?it/s]"
147
+ ]
148
+ },
149
+ "metadata": {},
150
+ "output_type": "display_data"
151
+ },
152
+ {
153
+ "data": {
154
+ "application/vnd.jupyter.widget-view+json": {
155
+ "model_id": "1631dbdc53d04b14a8a7733883bbd1cc",
156
+ "version_major": 2,
157
+ "version_minor": 0
158
+ },
159
+ "text/plain": [
160
+ "0 tables [00:00, ? tables/s]"
161
+ ]
162
+ },
163
+ "metadata": {},
164
+ "output_type": "display_data"
165
+ },
166
+ {
167
+ "name": "stdout",
168
+ "output_type": "stream",
169
+ "text": [
170
+ "Dataset csv downloaded and prepared to /Users/matth/.cache/huggingface/datasets/merve___csv/merve--poetry-ca9a13ef5858cc3a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a. Subsequent calls will reuse this data.\n"
171
+ ]
172
+ },
173
+ {
174
+ "data": {
175
+ "application/vnd.jupyter.widget-view+json": {
176
+ "model_id": "3c93229d66ad46d9a88da5f6a9528f2e",
177
+ "version_major": 2,
178
+ "version_minor": 0
179
+ },
180
+ "text/plain": [
181
+ " 0%| | 0/1 [00:00<?, ?it/s]"
182
+ ]
183
+ },
184
+ "metadata": {},
185
+ "output_type": "display_data"
186
+ }
187
+ ],
188
+ "source": [
189
+ "raw_datasets = load_dataset(\"merve/poetry\")"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 12,
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "tokenizer = AutoTokenizer.from_pretrained('gpt2')"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 13,
204
+ "metadata": {},
205
+ "outputs": [],
206
+ "source": [
207
+ "config = AutoConfig.from_pretrained('gpt2')"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": 16,
213
+ "metadata": {},
214
+ "outputs": [
215
+ {
216
+ "data": {
217
+ "text/plain": [
218
+ "Embedding(50257, 768)"
219
+ ]
220
+ },
221
+ "execution_count": 16,
222
+ "metadata": {},
223
+ "output_type": "execute_result"
224
+ }
225
+ ],
226
+ "source": [
227
+ "model = AutoModelForCausalLM.from_pretrained(\n",
228
+ " \"gpt2\",\n",
229
+ " config=config\n",
230
+ ")\n",
231
+ "model.resize_token_embeddings(len(tokenizer))"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": 24,
237
+ "metadata": {},
238
+ "outputs": [
239
+ {
240
+ "data": {
241
+ "text/plain": [
242
+ "Dataset({\n",
243
+ " features: ['author', 'content', 'poem name', 'age', 'type'],\n",
244
+ " num_rows: 573\n",
245
+ "})"
246
+ ]
247
+ },
248
+ "execution_count": 24,
249
+ "metadata": {},
250
+ "output_type": "execute_result"
251
+ }
252
+ ],
253
+ "source": [
254
+ "raw_datasets['train']"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": 26,
260
+ "metadata": {},
261
+ "outputs": [
262
+ {
263
+ "data": {
264
+ "text/plain": [
265
+ "'Mythology & Folklore'"
266
+ ]
267
+ },
268
+ "execution_count": 26,
269
+ "metadata": {},
270
+ "output_type": "execute_result"
271
+ }
272
+ ],
273
+ "source": [
274
+ "raw_datasets['train']['type'][0]"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": 28,
280
+ "metadata": {},
281
+ "outputs": [
282
+ {
283
+ "data": {
284
+ "text/plain": [
285
+ "DatasetDict({\n",
286
+ " train: Dataset({\n",
287
+ " features: ['author', 'content', 'poem name', 'age', 'type'],\n",
288
+ " num_rows: 573\n",
289
+ " })\n",
290
+ "})"
291
+ ]
292
+ },
293
+ "execution_count": 28,
294
+ "metadata": {},
295
+ "output_type": "execute_result"
296
+ }
297
+ ],
298
+ "source": [
299
+ "raw_datasets"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": 29,
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": [
308
+ "tok_logger = transformers.utils.logging.get_logger(\n",
309
+ " \"transformers.tokenization_utils_base\"\n",
310
+ ")"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": 30,
316
+ "metadata": {},
317
+ "outputs": [],
318
+ "source": [
319
+ "def tokenize_function(examples):\n",
320
+ " with CaptureLogger(tok_logger) as cl:\n",
321
+ " output = tokenizer(examples[text_column_name])\n",
322
+ " # clm input could be much much longer than block_size\n",
323
+ " if \"Token indices sequence length is longer than the\" in cl.out:\n",
324
+ " tok_logger.warning(\n",
325
+ " \"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits\"\n",
326
+ " \" before being passed to the model.\"\n",
327
+ " )\n",
328
+ " return output"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": 33,
334
+ "metadata": {},
335
+ "outputs": [],
336
+ "source": [
337
+ "column_names = raw_datasets[\"train\"].column_names\n",
338
+ "# text_column_name = \"text\" if \"text\" in column_names else column_names[0]\n",
339
+ "text_column_name = \"content\""
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": 34,
345
+ "metadata": {},
346
+ "outputs": [
347
+ {
348
+ "data": {
349
+ "application/vnd.jupyter.widget-view+json": {
350
+ "model_id": "82c09dbdfa1a47d79607a4c9729fb286",
351
+ "version_major": 2,
352
+ "version_minor": 0
353
+ },
354
+ "text/plain": [
355
+ "Running tokenizer on dataset: 0%| | 0/1 [00:00<?, ?ba/s]"
356
+ ]
357
+ },
358
+ "metadata": {},
359
+ "output_type": "display_data"
360
+ },
361
+ {
362
+ "name": "stderr",
363
+ "output_type": "stream",
364
+ "text": [
365
+ "Token indices sequence length is longer than the specified maximum sequence length for this model (7725 > 1024). Running this sequence through the model will result in indexing errors\n",
366
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model.\n"
367
+ ]
368
+ }
369
+ ],
370
+ "source": [
371
+ "tokenized_datasets = raw_datasets.map(\n",
372
+ " tokenize_function,\n",
373
+ " batched=True,\n",
374
+ " # num_proc=data_args.preprocessing_num_workers,\n",
375
+ " remove_columns=column_names,\n",
376
+ " # load_from_cache_file=not data_args.overwrite_cache,\n",
377
+ " desc=\"Running tokenizer on dataset\",\n",
378
+ ")"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "execution_count": 39,
384
+ "metadata": {},
385
+ "outputs": [],
386
+ "source": [
387
+ "block_size = tokenizer.model_max_length"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": 41,
393
+ "metadata": {},
394
+ "outputs": [],
395
+ "source": [
396
+ "# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.\n",
397
+ "def group_texts(examples):\n",
398
+ " # Concatenate all texts.\n",
399
+ " concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}\n",
400
+ " total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
401
+ " # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n",
402
+ " # customize this part to your needs.\n",
403
+ " if total_length >= block_size:\n",
404
+ " total_length = (total_length // block_size) * block_size\n",
405
+ " # Split by chunks of max_len.\n",
406
+ " result = {\n",
407
+ " k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
408
+ " for k, t in concatenated_examples.items()\n",
409
+ " }\n",
410
+ " result[\"labels\"] = result[\"input_ids\"].copy()\n",
411
+ " return result"
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "code",
416
+ "execution_count": 44,
417
+ "metadata": {},
418
+ "outputs": [
419
+ {
420
+ "data": {
421
+ "application/vnd.jupyter.widget-view+json": {
422
+ "model_id": "ca2f64461e304df6aecb16e8cfcd42ac",
423
+ "version_major": 2,
424
+ "version_minor": 0
425
+ },
426
+ "text/plain": [
427
+ "Grouping texts in chunks of 1024: 0%| | 0/1 [00:00<?, ?ba/s]"
428
+ ]
429
+ },
430
+ "metadata": {},
431
+ "output_type": "display_data"
432
+ }
433
+ ],
434
+ "source": [
435
+ "lm_datasets = tokenized_datasets.map(\n",
436
+ " group_texts,\n",
437
+ " batched=True,\n",
438
+ " # num_proc=data_args.preprocessing_num_workers,\n",
439
+ " # load_from_cache_file=not data_args.overwrite_cache,\n",
440
+ " desc=f\"Grouping texts in chunks of {block_size}\",\n",
441
+ ")"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": 46,
447
+ "metadata": {},
448
+ "outputs": [],
449
+ "source": [
450
+ "train_dataset = lm_datasets[\"train\"]"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "markdown",
455
+ "metadata": {},
456
+ "source": [
457
+ "#### Do the fine-tuning"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "code",
462
+ "execution_count": 47,
463
+ "metadata": {},
464
+ "outputs": [],
465
+ "source": [
466
+ "# Initialize our Trainer\n",
467
+ "trainer = Trainer(\n",
468
+ " model=model,\n",
469
+ " # args=training_args,\n",
470
+ " train_dataset=train_dataset,\n",
471
+ " # eval_dataset=eval_dataset,\n",
472
+ " tokenizer=tokenizer,\n",
473
+ " # Data collator will default to DataCollatorWithPadding, so we change it.\n",
474
+ " data_collator=default_data_collator,\n",
475
+ " # compute_metrics=compute_metrics\n",
476
+ " # if training_args.do_eval and not is_torch_tpu_available()\n",
477
+ " # else None,\n",
478
+ " # preprocess_logits_for_metrics=preprocess_logits_for_metrics\n",
479
+ " # if training_args.do_eval and not is_torch_tpu_available()\n",
480
+ " # else None,\n",
481
+ ")"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": 48,
487
+ "metadata": {},
488
+ "outputs": [
489
+ {
490
+ "name": "stderr",
491
+ "output_type": "stream",
492
+ "text": [
493
+ "***** Running training *****\n",
494
+ " Num examples = 171\n",
495
+ " Num Epochs = 3\n",
496
+ " Instantaneous batch size per device = 8\n",
497
+ " Total train batch size (w. parallel, distributed & accumulation) = 8\n",
498
+ " Gradient Accumulation steps = 1\n",
499
+ " Total optimization steps = 66\n"
500
+ ]
501
+ },
502
+ {
503
+ "data": {
504
+ "application/vnd.jupyter.widget-view+json": {
505
+ "model_id": "59ebc6f251bd42e4bd3474b574614d1f",
506
+ "version_major": 2,
507
+ "version_minor": 0
508
+ },
509
+ "text/plain": [
510
+ " 0%| | 0/66 [00:00<?, ?it/s]"
511
+ ]
512
+ },
513
+ "metadata": {},
514
+ "output_type": "display_data"
515
+ },
516
+ {
517
+ "name": "stderr",
518
+ "output_type": "stream",
519
+ "text": [
520
+ "\n",
521
+ "\n",
522
+ "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
523
+ "\n",
524
+ "\n",
525
+ "Saving model checkpoint to tmp_trainer\n",
526
+ "Configuration saved in tmp_trainer/config.json\n"
527
+ ]
528
+ },
529
+ {
530
+ "name": "stdout",
531
+ "output_type": "stream",
532
+ "text": [
533
+ "{'train_runtime': 2967.2818, 'train_samples_per_second': 0.173, 'train_steps_per_second': 0.022, 'train_loss': 4.249474265358665, 'epoch': 3.0}\n"
534
+ ]
535
+ },
536
+ {
537
+ "name": "stderr",
538
+ "output_type": "stream",
539
+ "text": [
540
+ "Model weights saved in tmp_trainer/pytorch_model.bin\n",
541
+ "tokenizer config file saved in tmp_trainer/tokenizer_config.json\n",
542
+ "Special tokens file saved in tmp_trainer/special_tokens_map.json\n"
543
+ ]
544
+ },
545
+ {
546
+ "name": "stdout",
547
+ "output_type": "stream",
548
+ "text": [
549
+ "***** train metrics *****\n",
550
+ " epoch = 3.0\n",
551
+ " train_loss = 4.2495\n",
552
+ " train_runtime = 0:49:27.28\n",
553
+ " train_samples = 171\n",
554
+ " train_samples_per_second = 0.173\n",
555
+ " train_steps_per_second = 0.022\n"
556
+ ]
557
+ }
558
+ ],
559
+ "source": [
560
+ "# Training\n",
561
+ "checkpoint = None\n",
562
+ "train_result = trainer.train(resume_from_checkpoint=checkpoint)\n",
563
+ "trainer.save_model() # Saves the tokenizer too for easy upload\n",
564
+ "\n",
565
+ "metrics = train_result.metrics\n",
566
+ "\n",
567
+ "max_train_samples = (len(train_dataset))\n",
568
+ "metrics[\"train_samples\"] = min(max_train_samples, len(train_dataset))\n",
569
+ "\n",
570
+ "trainer.log_metrics(\"train\", metrics)\n",
571
+ "trainer.save_metrics(\"train\", metrics)\n",
572
+ "trainer.save_state()"
573
+ ]
574
+ }
575
+ ],
576
+ "metadata": {
577
+ "kernelspec": {
578
+ "display_name": "Python 3.10.6 ('augmented_poetry')",
579
+ "language": "python",
580
+ "name": "python3"
581
+ },
582
+ "language_info": {
583
+ "codemirror_mode": {
584
+ "name": "ipython",
585
+ "version": 3
586
+ },
587
+ "file_extension": ".py",
588
+ "mimetype": "text/x-python",
589
+ "name": "python",
590
+ "nbconvert_exporter": "python",
591
+ "pygments_lexer": "ipython3",
592
+ "version": "3.8.13"
593
+ },
594
+ "orig_nbformat": 4,
595
+ "vscode": {
596
+ "interpreter": {
597
+ "hash": "00664817f4a09ab74dd392ee5a8d12e3606381c26df296db9ea5c334bb5d1b65"
598
+ }
599
+ }
600
+ },
601
+ "nbformat": 4,
602
+ "nbformat_minor": 2
603
+ }
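The Trainer in this notebook is built without TrainingArguments, which is why everything lands in the default tmp_trainer/ directory (the path app.py loads and .gitignore now excludes). A sketch of making those defaults explicit; the epoch count, batch size, and save strategy below are illustrative values, not tuned ones.

```python
from transformers import Trainer, TrainingArguments, default_data_collator

training_args = TrainingArguments(
    output_dir="tmp_trainer",          # keep the path that app.py already loads from
    num_train_epochs=3,
    per_device_train_batch_size=8,
    save_strategy="epoch",
)

trainer = Trainer(
    model=model,                       # objects defined earlier in the notebook
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
)
trainer.train()
trainer.save_model()
```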