Sijun He committed on
Commit
48d79f7
1 Parent(s): 03e08e7

upload spaces

Browse files
.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ saved_model/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from char_tokenizer import CharTokenizer
2
+ import gradio as gr
3
+ from transformers import GPT2LMHeadModel
4
+
5
# Load the character vocabulary and the trained GPT-2 weights once at import
# time so that every request served by the UI reuses the same objects.
tokenizer = CharTokenizer.load("saved_model/tokenizer.json")
model = GPT2LMHeadModel.from_pretrained("saved_model")


def generation(prompt, length):
    """Sample a continuation for ``prompt`` conditioned on ``length``.

    The model input is ``str(length) + prompt``, one token per character.

    Args:
        prompt: Text typed by the user.
        length: Value prepended to the prompt (5, 6 or 7 from the dropdown).

    Returns:
        The decoded text of the newly generated tokens only, with the
        ``<bos>``/``<eos>``/``<pad>`` markers removed by ``tokenizer.decode``.
    """
    tokens = tokenizer(prompt=str(length) + prompt)
    input_ids = tokens["input_ids"]
    output_ids = model.generate(
        input_ids,
        attention_mask=tokens["attention_mask"],
        do_sample=True,
        top_p=0.95,
        max_length=100,
    )
    # Drop the prompt by token count rather than by slicing the decoded
    # string: ``decode`` removes "<pad>", and characters missing from the
    # vocabulary are encoded as id 0, so a character-offset slice would cut
    # into the generated text whenever the prompt has an unknown character.
    generated_ids = output_ids[:, input_ids.size(1):]
    return tokenizer.decode(generated_ids)


# Gradio UI: a free-text prompt plus a dropdown for the length value.
input_prompt = gr.inputs.Textbox()
input_length = gr.inputs.Dropdown([5, 6, 7])
gr.Interface(fn=generation, inputs=[input_prompt, input_length], outputs="text").launch()
char_tokenizer.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, json
2
+
3
class CharTokenizer:
    """Character-level tokenizer mapping every character to an integer id.

    The vocabulary is a ``dict`` from token string to id. It is either passed
    in directly (``vocab``) or built from an iterable of strings (``corpus``).

    Attributes:
        vocab: Mapping from token string to integer id.
        id2vocab: Token strings ordered by their id, used by ``decode``.
    """

    # Marker tokens that ``decode`` strips from its output.
    _SPECIAL_TOKENS = frozenset({"<pad>", "<bos>", "<eos>"})
    # Id used for characters missing from the vocabulary; ``_build_vocab``
    # assigns this id to "<pad>".
    _UNKNOWN_ID = 0

    def __init__(self, corpus=None, vocab=None):
        """Create a tokenizer from an existing vocabulary or from a corpus.

        Args:
            corpus: Iterable of strings to build the vocabulary from. Only
                used when ``vocab`` is not given.
            vocab: Ready-made token -> id mapping; takes precedence.

        Raises:
            ValueError: If neither ``corpus`` nor ``vocab`` is supplied.
        """
        if vocab is not None:
            self.vocab = vocab
        elif corpus is not None:
            self.vocab = self._build_vocab(corpus)
        else:
            raise ValueError("Either corpus or vocab has to be supplied")
        self.id2vocab = [
            char for char, _ in sorted(self.vocab.items(), key=lambda item: item[1])
        ]

    def _tokenize(self, text):
        """Split ``text`` into a list of single characters."""
        return list(text)

    def _encode(self, text):
        """Return the ids of the characters of ``text`` (unknown -> id 0)."""
        return [self.vocab.get(token, self._UNKNOWN_ID) for token in self._tokenize(text)]

    def __call__(self, prompt, text=None, add_eos_token=False):
        """Encode ``prompt`` (and optionally ``text``) into model inputs.

        Args:
            prompt: String encoded first, one id per character.
            text: Optional string appended after a ``<bos>`` id.
            add_eos_token: When true, an ``<eos>`` id is appended at the end.

        Returns:
            A dict with ``input_ids`` (long tensor of shape ``(1, n)``) and an
            all-ones ``attention_mask`` of the same shape.
        """
        token_ids = self._encode(prompt)
        if text is not None:
            token_ids = token_ids + [self.vocab["<bos>"]] + self._encode(text)
        if add_eos_token:
            token_ids = token_ids + [self.vocab["<eos>"]]
        input_ids_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
        attention_masks = torch.ones_like(input_ids_tensor)
        return {"input_ids": input_ids_tensor, "attention_mask": attention_masks}

    def _build_vocab(self, corpus):
        """Build the token -> id mapping from an iterable of strings.

        Ids are assigned in this order: ``<pad>`` (0), the digit strings
        "3" to "9", every new character in order of first appearance, then
        ``<bos>`` and ``<eos>``.
        """
        vocab = {"<pad>": 0}
        for verse_lengths in range(3, 10):
            vocab[str(verse_lengths)] = len(vocab)
        for doc in corpus:
            for char in self._tokenize(doc):
                if char not in vocab:
                    vocab[char] = len(vocab)
        vocab["<bos>"] = len(vocab)
        vocab["<eos>"] = len(vocab)
        return vocab

    def decode(self, token_ids):
        """Turn token ids back into a string, dropping the special markers.

        Args:
            token_ids: A tensor of ids of any shape, or a (nested) sequence of
                ints; it is flattened before decoding.

        Returns:
            The concatenated characters without ``<pad>``/``<bos>``/``<eos>``.
        """
        flat_ids = torch.as_tensor(token_ids).flatten().tolist()
        chars = (self.id2vocab[token_id] for token_id in flat_ids)
        return "".join(char for char in chars if char not in self._SPECIAL_TOKENS)

    def save(self, filepath):
        """Write the vocabulary to ``filepath`` as JSON."""
        # Explicit encoding keeps the file independent of the platform locale.
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f)

    @classmethod
    def load(cls, filepath):
        """Create a tokenizer from a JSON vocabulary file written by ``save``."""
        # UTF-8 also covers vocabulary files that store non-ASCII characters
        # unescaped, which would fail under a non-UTF-8 default locale.
        with open(filepath, encoding="utf-8") as f:
            vocab = json.load(f)
        return cls(vocab=vocab)
poet-gpt2.ipynb ADDED
@@ -0,0 +1,1502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "2c3bb18a",
7
+ "metadata": {
8
+ "execution": {
9
+ "iopub.execute_input": "2022-04-18T01:49:05.484451Z",
10
+ "iopub.status.busy": "2022-04-18T01:49:05.482966Z",
11
+ "iopub.status.idle": "2022-04-18T01:49:22.249321Z",
12
+ "shell.execute_reply": "2022-04-18T01:49:22.248692Z",
13
+ "shell.execute_reply.started": "2022-04-16T12:16:29.630467Z"
14
+ },
15
+ "papermill": {
16
+ "duration": 16.788107,
17
+ "end_time": "2022-04-18T01:49:22.249468",
18
+ "exception": false,
19
+ "start_time": "2022-04-18T01:49:05.461361",
20
+ "status": "completed"
21
+ },
22
+ "tags": []
23
+ },
24
+ "outputs": [
25
+ {
26
+ "name": "stdout",
27
+ "output_type": "stream",
28
+ "text": [
29
+ "Cloning into 'Poetry'...\n",
30
+ "remote: Enumerating objects: 135, done.\u001b[K\n",
31
+ "remote: Total 135 (delta 0), reused 0 (delta 0), pack-reused 135\u001b[K\n",
32
+ "Receiving objects: 100% (135/135), 123.55 MiB | 12.33 MiB/s, done.\n",
33
+ "Resolving deltas: 100% (77/77), done.\n",
34
+ "Updating files: 100% (39/39), done.\n"
35
+ ]
36
+ }
37
+ ],
38
+ "source": [
39
+ "#!wget https://raw.githubusercontent.com/youyuge34/Poems_generator_Keras/master/dataset/poetry.txt\n",
40
+ "!git clone https://github.com/Werneror/Poetry.git"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 2,
46
+ "id": "d76b15a8",
47
+ "metadata": {
48
+ "execution": {
49
+ "iopub.execute_input": "2022-04-18T01:49:22.322907Z",
50
+ "iopub.status.busy": "2022-04-18T01:49:22.322113Z",
51
+ "iopub.status.idle": "2022-04-18T01:49:28.965795Z",
52
+ "shell.execute_reply": "2022-04-18T01:49:28.965246Z",
53
+ "shell.execute_reply.started": "2022-04-16T12:16:41.322744Z"
54
+ },
55
+ "papermill": {
56
+ "duration": 6.678735,
57
+ "end_time": "2022-04-18T01:49:28.965944",
58
+ "exception": false,
59
+ "start_time": "2022-04-18T01:49:22.287209",
60
+ "status": "completed"
61
+ },
62
+ "tags": []
63
+ },
64
+ "outputs": [],
65
+ "source": [
66
+ "import os\n",
67
+ "import pandas as pd\n",
68
+ "from sklearn.model_selection import train_test_split\n",
69
+ "from transformers import GPT2Config, GPT2LMHeadModel\n",
70
+ "from transformers import TrainingArguments, Trainer"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 3,
76
+ "id": "de8c9caa",
77
+ "metadata": {
78
+ "execution": {
79
+ "iopub.execute_input": "2022-04-18T01:49:29.041708Z",
80
+ "iopub.status.busy": "2022-04-18T01:49:29.040006Z",
81
+ "iopub.status.idle": "2022-04-18T01:49:34.168753Z",
82
+ "shell.execute_reply": "2022-04-18T01:49:34.169341Z",
83
+ "shell.execute_reply.started": "2022-04-16T12:16:48.16115Z"
84
+ },
85
+ "papermill": {
86
+ "duration": 5.16885,
87
+ "end_time": "2022-04-18T01:49:34.169515",
88
+ "exception": false,
89
+ "start_time": "2022-04-18T01:49:29.000665",
90
+ "status": "completed"
91
+ },
92
+ "tags": []
93
+ },
94
+ "outputs": [
95
+ {
96
+ "data": {
97
+ "text/html": [
98
+ "<div>\n",
99
+ "<style scoped>\n",
100
+ " .dataframe tbody tr th:only-of-type {\n",
101
+ " vertical-align: middle;\n",
102
+ " }\n",
103
+ "\n",
104
+ " .dataframe tbody tr th {\n",
105
+ " vertical-align: top;\n",
106
+ " }\n",
107
+ "\n",
108
+ " .dataframe thead th {\n",
109
+ " text-align: right;\n",
110
+ " }\n",
111
+ "</style>\n",
112
+ "<table border=\"1\" class=\"dataframe\">\n",
113
+ " <thead>\n",
114
+ " <tr style=\"text-align: right;\">\n",
115
+ " <th></th>\n",
116
+ " <th>题目</th>\n",
117
+ " <th>朝代</th>\n",
118
+ " <th>作者</th>\n",
119
+ " <th>内容</th>\n",
120
+ " </tr>\n",
121
+ " </thead>\n",
122
+ " <tbody>\n",
123
+ " <tr>\n",
124
+ " <th>0</th>\n",
125
+ " <td>彭生行</td>\n",
126
+ " <td>明</td>\n",
127
+ " <td>何景明</td>\n",
128
+ " <td>岷峨山根江水坼,万里波涛混吴越。倾湖倒海不可量,仰看一线青天上。郁蓝秀色盘三巴,间产锦石兼丹...</td>\n",
129
+ " </tr>\n",
130
+ " <tr>\n",
131
+ " <th>1</th>\n",
132
+ " <td>黄河篇</td>\n",
133
+ " <td>明</td>\n",
134
+ " <td>何景明</td>\n",
135
+ " <td>黄河昆崙源,九曲与天通。银汉贯箕尾,左盘日月宫。奔流下龙门,喷薄沙海风。三山万里倚穷发,鳖极...</td>\n",
136
+ " </tr>\n",
137
+ " <tr>\n",
138
+ " <th>2</th>\n",
139
+ " <td>三清山人歌</td>\n",
140
+ " <td>明</td>\n",
141
+ " <td>何景明</td>\n",
142
+ " <td>山人佩剑冠远游,腰间鞶囊垂虎头,七星照耀金银钩。东行策杖指卢霍,逝将沧海寻丹丘。三清西南龙虎...</td>\n",
143
+ " </tr>\n",
144
+ " <tr>\n",
145
+ " <th>3</th>\n",
146
+ " <td>昔游篇</td>\n",
147
+ " <td>明</td>\n",
148
+ " <td>何景明</td>\n",
149
+ " <td>三星烂夜河汉流,觞行瑟作中堂幽。李君勿叹息,薛君且停讴。英英孟夫子,听我当筵歌昔游。昔游少年...</td>\n",
150
+ " </tr>\n",
151
+ " <tr>\n",
152
+ " <th>4</th>\n",
153
+ " <td>赠商三</td>\n",
154
+ " <td>明</td>\n",
155
+ " <td>何景明</td>\n",
156
+ " <td>去冬雪雨留蓟门,开筵谑浪倒金樽。今春灯火到长安,过门不肯回银鞍。燕山花隔平山柳,马上东风几回首。</td>\n",
157
+ " </tr>\n",
158
+ " </tbody>\n",
159
+ "</table>\n",
160
+ "</div>"
161
+ ],
162
+ "text/plain": [
163
+ " 题目 朝代 作者 内容\n",
164
+ "0 彭生行 明 何景明 岷峨山根江水坼,万里波涛混吴越。倾湖倒海不可量,仰看一线青天上。郁蓝秀色盘三巴,间产锦石兼丹...\n",
165
+ "1 黄河篇 明 何景明 黄河昆崙源,九曲与天通。银汉贯箕尾,左盘日月宫。奔流下龙门,喷薄沙海风。三山万里倚穷发,鳖极...\n",
166
+ "2 三清山人歌 明 何景明 山人佩剑冠远游,腰间鞶囊垂虎头,七星照耀金银钩。东行策杖指卢霍,逝将沧海寻丹丘。三清西南龙虎...\n",
167
+ "3 昔游篇 明 何景明 三星烂夜河汉流,觞行瑟作中堂幽。李君勿叹息,薛君且停讴。英英孟夫子,听我当筵歌昔游。昔游少年...\n",
168
+ "4 赠商三 明 何景明 去冬雪雨留蓟门,开筵谑浪倒金樽。今春灯火到长安,过门不肯回银鞍。燕山花隔平山柳,马上东风几回首。"
169
+ ]
170
+ },
171
+ "execution_count": 3,
172
+ "metadata": {},
173
+ "output_type": "execute_result"
174
+ }
175
+ ],
176
+ "source": [
177
+ "data = None\n",
178
+ "for (dirpath, dirnames, filenames) in os.walk(\"Poetry\"):\n",
179
+ " for filename in filenames:\n",
180
+ " if filename.endswith(\"csv\"):\n",
181
+ " cur_data = pd.read_csv(f\"Poetry/{filename}\")\n",
182
+ " if data is None:\n",
183
+ " data = cur_data\n",
184
+ " else:\n",
185
+ " data = pd.concat([data, cur_data])\n",
186
+ "data.head()"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": 4,
192
+ "id": "40c84fbf",
193
+ "metadata": {
194
+ "execution": {
195
+ "iopub.execute_input": "2022-04-18T01:49:34.242596Z",
196
+ "iopub.status.busy": "2022-04-18T01:49:34.241754Z",
197
+ "iopub.status.idle": "2022-04-18T01:49:34.244196Z",
198
+ "shell.execute_reply": "2022-04-18T01:49:34.243782Z",
199
+ "shell.execute_reply.started": "2022-04-16T12:16:53.639047Z"
200
+ },
201
+ "papermill": {
202
+ "duration": 0.041531,
203
+ "end_time": "2022-04-18T01:49:34.244315",
204
+ "exception": false,
205
+ "start_time": "2022-04-18T01:49:34.202784",
206
+ "status": "completed"
207
+ },
208
+ "tags": []
209
+ },
210
+ "outputs": [],
211
+ "source": [
212
+ "import re\n",
213
+ "\n",
214
+ "def verse_length(verses):\n",
215
+ " return len(verses[0])\n",
216
+ "\n",
217
+ "def verse_heads(verses):\n",
218
+ " verse_heads = [verse[0] for verse in verses]\n",
219
+ " return \"\".join(verse_heads)\n",
220
+ "\n",
221
+ "def split_poem(poem):\n",
222
+ " return [verse for verse in re.split(\",|。\", poem) if len(verse)]\n",
223
+ " \n",
224
+ "def is_correct_length(poem, max_length, min_length):\n",
225
+ " return len(poem) < max_length and len(poem) > min_length\n",
226
+ " \n",
227
+ "def is_equal_length(verses):\n",
228
+ " verse_lengths = [len(verse) for verse in verses]\n",
229
+ " for length in verse_lengths:\n",
230
+ " if length != verse_lengths[0]:\n",
231
+ " return False\n",
232
+ " return True "
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": 5,
238
+ "id": "4fd4df65",
239
+ "metadata": {
240
+ "execution": {
241
+ "iopub.execute_input": "2022-04-18T01:49:34.407430Z",
242
+ "iopub.status.busy": "2022-04-18T01:49:34.406391Z",
243
+ "iopub.status.idle": "2022-04-18T01:49:47.517219Z",
244
+ "shell.execute_reply": "2022-04-18T01:49:47.516725Z",
245
+ "shell.execute_reply.started": "2022-04-16T12:16:53.648455Z"
246
+ },
247
+ "papermill": {
248
+ "duration": 13.240486,
249
+ "end_time": "2022-04-18T01:49:47.517350",
250
+ "exception": false,
251
+ "start_time": "2022-04-18T01:49:34.276864",
252
+ "status": "completed"
253
+ },
254
+ "tags": []
255
+ },
256
+ "outputs": [
257
+ {
258
+ "name": "stderr",
259
+ "output_type": "stream",
260
+ "text": [
261
+ "/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:6: SettingWithCopyWarning: \n",
262
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
263
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
264
+ "\n",
265
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
266
+ " \n",
267
+ "/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n",
268
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
269
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
270
+ "\n",
271
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
272
+ " import sys\n"
273
+ ]
274
+ },
275
+ {
276
+ "name": "stdout",
277
+ "output_type": "stream",
278
+ "text": [
279
+ "Number of valid poems: 617674\n"
280
+ ]
281
+ }
282
+ ],
283
+ "source": [
284
+ "data = data[~data[\"内容\"].isna()]\n",
285
+ "data['verses'] = [split_poem(poem) for poem in data['内容']]\n",
286
+ "data['equal_verse_lengths'] = [is_equal_length(verses) for verses in data['verses']]\n",
287
+ "data['meet_length_requirements'] = [is_correct_length(poem, 100, 20) for poem in data['内容']]\n",
288
+ "valid_poems = data[data['equal_verse_lengths'] & data['meet_length_requirements']]\n",
289
+ "valid_poems['verse_lengths'] = [verse_length(verses) for verses in valid_poems['verses']]\n",
290
+ "valid_poems['verse_heads'] = [verse_heads(verses) for verses in valid_poems['verses']]\n",
291
+ "valid_poems = valid_poems[valid_poems['verse_lengths'] < 10]\n",
292
+ "print(f\"Number of valid poems: {len(valid_poems)}\")"
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "execution_count": 6,
298
+ "id": "f86c5f9c",
299
+ "metadata": {
300
+ "execution": {
301
+ "iopub.execute_input": "2022-04-18T01:49:47.589515Z",
302
+ "iopub.status.busy": "2022-04-18T01:49:47.588746Z",
303
+ "iopub.status.idle": "2022-04-18T01:49:47.601086Z",
304
+ "shell.execute_reply": "2022-04-18T01:49:47.600657Z",
305
+ "shell.execute_reply.started": "2022-04-16T12:17:06.029609Z"
306
+ },
307
+ "papermill": {
308
+ "duration": 0.049888,
309
+ "end_time": "2022-04-18T01:49:47.601199",
310
+ "exception": false,
311
+ "start_time": "2022-04-18T01:49:47.551311",
312
+ "status": "completed"
313
+ },
314
+ "tags": []
315
+ },
316
+ "outputs": [
317
+ {
318
+ "data": {
319
+ "text/html": [
320
+ "<div>\n",
321
+ "<style scoped>\n",
322
+ " .dataframe tbody tr th:only-of-type {\n",
323
+ " vertical-align: middle;\n",
324
+ " }\n",
325
+ "\n",
326
+ " .dataframe tbody tr th {\n",
327
+ " vertical-align: top;\n",
328
+ " }\n",
329
+ "\n",
330
+ " .dataframe thead th {\n",
331
+ " text-align: right;\n",
332
+ " }\n",
333
+ "</style>\n",
334
+ "<table border=\"1\" class=\"dataframe\">\n",
335
+ " <thead>\n",
336
+ " <tr style=\"text-align: right;\">\n",
337
+ " <th></th>\n",
338
+ " <th>题目</th>\n",
339
+ " <th>朝代</th>\n",
340
+ " <th>作者</th>\n",
341
+ " <th>内容</th>\n",
342
+ " <th>verses</th>\n",
343
+ " <th>equal_verse_lengths</th>\n",
344
+ " <th>meet_length_requirements</th>\n",
345
+ " <th>verse_lengths</th>\n",
346
+ " <th>verse_heads</th>\n",
347
+ " </tr>\n",
348
+ " </thead>\n",
349
+ " <tbody>\n",
350
+ " <tr>\n",
351
+ " <th>4</th>\n",
352
+ " <td>赠商三</td>\n",
353
+ " <td>明</td>\n",
354
+ " <td>何景明</td>\n",
355
+ " <td>去冬雪雨留蓟门,开筵谑浪倒金樽。今春灯火到长安,过门不肯回银鞍。燕山花隔平山柳,马上东风几回首。</td>\n",
356
+ " <td>[去冬雪雨留蓟门, 开筵谑浪倒金樽, 今春灯火到长安, 过门不肯回银鞍, 燕山花隔平山柳, ...</td>\n",
357
+ " <td>True</td>\n",
358
+ " <td>True</td>\n",
359
+ " <td>7</td>\n",
360
+ " <td>去开今过燕马</td>\n",
361
+ " </tr>\n",
362
+ " <tr>\n",
363
+ " <th>14</th>\n",
364
+ " <td>送叶生还闽中兼怀郑继之</td>\n",
365
+ " <td>明</td>\n",
366
+ " <td>何景明</td>\n",
367
+ " <td>叶生行吟燕中市,葛巾麻鞋岁将晚。两都为客今始归,五岳寻仙不辞远。江南画舸春柳低,海上茅堂白云...</td>\n",
368
+ " <td>[叶生行吟燕中市, 葛巾麻鞋岁将晚, 两都为客今始归, 五岳寻仙不辞远, 江南画舸春柳低, ...</td>\n",
369
+ " <td>True</td>\n",
370
+ " <td>True</td>\n",
371
+ " <td>7</td>\n",
372
+ " <td>叶葛两五江海谷为</td>\n",
373
+ " </tr>\n",
374
+ " <tr>\n",
375
+ " <th>15</th>\n",
376
+ " <td>送林利正同知之潮阳</td>\n",
377
+ " <td>明</td>\n",
378
+ " <td>何景明</td>\n",
379
+ " <td>忆在成均共携手,泉山门下相知久。万里恩情若父兄,十年道义惭师友。君才岂孤一第名,佩刀今作岭南...</td>\n",
380
+ " <td>[忆在成均共携手, 泉山门下相知久, 万里恩情若父兄, 十年道义惭师友, 君才岂孤一第名, ...</td>\n",
381
+ " <td>True</td>\n",
382
+ " <td>True</td>\n",
383
+ " <td>7</td>\n",
384
+ " <td>忆泉万十君佩挂伐燕相过道</td>\n",
385
+ " </tr>\n",
386
+ " <tr>\n",
387
+ " <th>16</th>\n",
388
+ " <td>金陵歌送李先生</td>\n",
389
+ " <td>明</td>\n",
390
+ " <td>何景明</td>\n",
391
+ " <td>李公为舅有吕甥,甥舅四海皆知名。吕君关西昨日去,公自金陵来复行。金陵江水无断绝,金陵之山高巀...</td>\n",
392
+ " <td>[李公为舅有吕甥, 甥舅四海皆知名, 吕君关西昨日去, 公自金陵来复行, 金陵江水无断绝, ...</td>\n",
393
+ " <td>True</td>\n",
394
+ " <td>True</td>\n",
395
+ " <td>7</td>\n",
396
+ " <td>李甥吕公金金龙星白清燕</td>\n",
397
+ " </tr>\n",
398
+ " <tr>\n",
399
+ " <th>21</th>\n",
400
+ " <td>延津歌送韩令</td>\n",
401
+ " <td>明</td>\n",
402
+ " <td>何景明</td>\n",
403
+ " <td>延津寇过馀少男,延津县令莫停骖。双凫直向黄河北,一雁先飞清卫南。黄河岸边不种麦,浊浪滔天多贾...</td>\n",
404
+ " <td>[延津寇过馀少男, 延津县令莫停骖, 双凫直向黄河北, 一雁先飞清卫南, 黄河岸边不种麦, ...</td>\n",
405
+ " <td>True</td>\n",
406
+ " <td>True</td>\n",
407
+ " <td>7</td>\n",
408
+ " <td>延延双一黄浊城县</td>\n",
409
+ " </tr>\n",
410
+ " </tbody>\n",
411
+ "</table>\n",
412
+ "</div>"
413
+ ],
414
+ "text/plain": [
415
+ " 题目 朝代 作者 内容 \\\n",
416
+ "4 赠商三 明 何景明 去冬雪雨留蓟门,开筵谑浪倒金樽。今春灯火到长安,过门不肯回银鞍。燕山花隔平山柳,马上东风几回首。 \n",
417
+ "14 送叶生还闽中兼怀郑继之 明 何景明 叶生行吟燕中市,葛巾麻鞋岁将晚。两都为客今始归,五岳寻仙不辞远。江南画舸春柳低,海上茅堂白云... \n",
418
+ "15 送林利正同知之潮阳 明 何景明 忆在成均共携手,泉山门下相知久。万里恩情若父兄,十年道义惭师友。君才岂孤一第名,佩刀今作岭南... \n",
419
+ "16 金陵歌送李先生 明 何景明 李公为舅有吕甥,甥舅四海皆知名。吕君关西昨日去,公自金陵来复行。金陵江水无断绝,金陵之山高巀... \n",
420
+ "21 延津歌送韩令 明 何景明 延津寇过馀少男,延津县令莫停骖。双凫直向黄河北,一雁先飞清卫南。黄河岸边不种麦,浊浪滔天多贾... \n",
421
+ "\n",
422
+ " verses equal_verse_lengths \\\n",
423
+ "4 [去冬雪雨留蓟门, 开筵谑浪倒金樽, 今春灯火到长安, 过门不肯回银鞍, 燕山花隔平山柳, ... True \n",
424
+ "14 [叶生行吟燕中市, 葛巾麻鞋岁将晚, 两都为客今始归, 五岳寻仙不辞远, 江南画舸春柳低, ... True \n",
425
+ "15 [忆在成均共携手, 泉山门下相知久, 万里恩情若父兄, 十年道义惭师友, 君才岂孤一第名, ... True \n",
426
+ "16 [李公为舅有吕甥, 甥舅四海皆知名, 吕君关西昨日去, 公自金陵来复行, 金陵江水无断绝, ... True \n",
427
+ "21 [延津寇过馀少男, 延津县令莫停骖, 双凫直向黄河北, 一雁先飞清卫南, 黄河岸边不种麦, ... True \n",
428
+ "\n",
429
+ " meet_length_requirements verse_lengths verse_heads \n",
430
+ "4 True 7 去开今过燕马 \n",
431
+ "14 True 7 叶葛两五江海谷为 \n",
432
+ "15 True 7 忆泉万十君佩挂伐燕相过道 \n",
433
+ "16 True 7 李甥吕公金金龙星白清燕 \n",
434
+ "21 True 7 延延双一黄浊城县 "
435
+ ]
436
+ },
437
+ "execution_count": 6,
438
+ "metadata": {},
439
+ "output_type": "execute_result"
440
+ }
441
+ ],
442
+ "source": [
443
+ "valid_poems.head()"
444
+ ]
445
+ },
446
+ {
447
+ "cell_type": "code",
448
+ "execution_count": 7,
449
+ "id": "33140481",
450
+ "metadata": {
451
+ "execution": {
452
+ "iopub.execute_input": "2022-04-18T01:49:47.695680Z",
453
+ "iopub.status.busy": "2022-04-18T01:49:47.694684Z",
454
+ "iopub.status.idle": "2022-04-18T01:49:47.696696Z",
455
+ "shell.execute_reply": "2022-04-18T01:49:47.697169Z",
456
+ "shell.execute_reply.started": "2022-04-16T12:23:28.401922Z"
457
+ },
458
+ "papermill": {
459
+ "duration": 0.06126,
460
+ "end_time": "2022-04-18T01:49:47.697307",
461
+ "exception": false,
462
+ "start_time": "2022-04-18T01:49:47.636047",
463
+ "status": "completed"
464
+ },
465
+ "tags": []
466
+ },
467
+ "outputs": [],
468
+ "source": [
469
+ "import torch, json\n",
470
+ "\n",
471
+ "class CharTokenizer:\n",
472
+ " def __init__(self, corpus=None, vocab=None):\n",
473
+ " if vocab is not None:\n",
474
+ " self.vocab = vocab\n",
475
+ " elif corpus is not None:\n",
476
+ " self.vocab = self._build_vocab(corpus)\n",
477
+ " else:\n",
478
+ " raise Exception(\"Either corpus or vocab has to be supplied\")\n",
479
+ " self.id2vocab = [char for char, index in sorted(self.vocab.items(), key=lambda item: item[1])]\n",
480
+ " \n",
481
+ " def _tokenize(self, text):\n",
482
+ " return list(text)\n",
483
+ " \n",
484
+ " def __call__(self, prompt, text=None, add_eos_token=False):\n",
485
+ " token_ids = [self.vocab.get(token, 0) for token in self._tokenize(prompt)]\n",
486
+ " if text is not None:\n",
487
+ " text_token_ids = [self.vocab.get(token, 0) for token in self._tokenize(text)]\n",
488
+ " token_ids = token_ids + [self.vocab[\"<bos>\"]] + text_token_ids\n",
489
+ " if add_eos_token:\n",
490
+ " token_ids = token_ids + [self.vocab[\"<eos>\"]]\n",
491
+ " input_ids_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)\n",
492
+ " attention_masks = torch.ones_like(input_ids_tensor)\n",
493
+ " return {\"input_ids\": input_ids_tensor, \"attention_mask\": attention_masks}\n",
494
+ " \n",
495
+ " def _build_vocab(self, corpus):\n",
496
+ " vocab = {\"<pad>\": 0}\n",
497
+ " for verse_lengths in range(3, 10):\n",
498
+ " vocab[str(verse_lengths)] = len(vocab)\n",
499
+ " for doc in corpus:\n",
500
+ " chars = self._tokenize(doc)\n",
501
+ " for char in chars:\n",
502
+ " if char not in vocab:\n",
503
+ " vocab[char] = len(vocab)\n",
504
+ " vocab[\"<bos>\"] = len(vocab)\n",
505
+ " vocab[\"<eos>\"] = len(vocab)\n",
506
+ " return vocab\n",
507
+ " \n",
508
+ " def decode(self, token_ids):\n",
509
+ " chars = [self.id2vocab[token_id] for token_id in token_ids.flatten().tolist()]\n",
510
+ " filtered_chars = [char for char in chars if char not in [\"<eos>\", \"<bos>\", \"<pad>\"]]\n",
511
+ " return \"\".join(filtered_chars)\n",
512
+ " \n",
513
+ " def save(self, filepath):\n",
514
+ " with open(filepath, \"w\") as f:\n",
515
+ " json.dump(self.vocab, f)\n",
516
+ " \n",
517
+ " @classmethod\n",
518
+ " def load(cls, filepath):\n",
519
+ " with open(filepath) as f:\n",
520
+ " vocab = json.load(f)\n",
521
+ " return cls(vocab=vocab)"
522
+ ]
523
+ },
524
+ {
525
+ "cell_type": "code",
526
+ "execution_count": 8,
527
+ "id": "73f55174",
528
+ "metadata": {
529
+ "execution": {
530
+ "iopub.execute_input": "2022-04-18T01:49:47.806040Z",
531
+ "iopub.status.busy": "2022-04-18T01:49:47.795805Z",
532
+ "iopub.status.idle": "2022-04-18T01:49:51.506784Z",
533
+ "shell.execute_reply": "2022-04-18T01:49:51.506307Z",
534
+ "shell.execute_reply.started": "2022-04-16T12:23:28.57419Z"
535
+ },
536
+ "papermill": {
537
+ "duration": 3.770368,
538
+ "end_time": "2022-04-18T01:49:51.506972",
539
+ "exception": false,
540
+ "start_time": "2022-04-18T01:49:47.736604",
541
+ "status": "completed"
542
+ },
543
+ "tags": []
544
+ },
545
+ "outputs": [],
546
+ "source": [
547
+ "tokenizer = CharTokenizer(valid_poems['内容'])\n",
548
+ "tokenizer.save(\"/kaggle/working/tokenizer.json\")"
549
+ ]
550
+ },
551
+ {
552
+ "cell_type": "code",
553
+ "execution_count": 9,
554
+ "id": "2d0c4b52",
555
+ "metadata": {
556
+ "execution": {
557
+ "iopub.execute_input": "2022-04-18T01:49:51.587046Z",
558
+ "iopub.status.busy": "2022-04-18T01:49:51.578743Z",
559
+ "iopub.status.idle": "2022-04-18T01:50:13.120701Z",
560
+ "shell.execute_reply": "2022-04-18T01:50:13.121126Z",
561
+ "shell.execute_reply.started": "2022-04-16T12:35:45.273336Z"
562
+ },
563
+ "papermill": {
564
+ "duration": 21.579069,
565
+ "end_time": "2022-04-18T01:50:13.121274",
566
+ "exception": false,
567
+ "start_time": "2022-04-18T01:49:51.542205",
568
+ "status": "completed"
569
+ },
570
+ "tags": []
571
+ },
572
+ "outputs": [
573
+ {
574
+ "name": "stdout",
575
+ "output_type": "stream",
576
+ "text": [
577
+ "123\n"
578
+ ]
579
+ }
580
+ ],
581
+ "source": [
582
+ "tokenized_dataset = [tokenizer(prompt = str(length) + heads, text=poem, add_eos_token=True) for poem, length, heads in zip(valid_poems['内容'],\n",
583
+ " valid_poems['verse_lengths'],\n",
584
+ " valid_poems['verse_heads'])]\n",
585
+ "train_dataset, val_dataset = train_test_split(tokenized_dataset, test_size=0.02, random_state=1234)\n",
586
+ "max_lengths = max([tokenized[\"input_ids\"].size(1) for tokenized in tokenized_dataset])\n",
587
+ "print(max_lengths)"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "code",
592
+ "execution_count": 10,
593
+ "id": "a4e831ab",
594
+ "metadata": {
595
+ "execution": {
596
+ "iopub.execute_input": "2022-04-18T01:50:13.232157Z",
597
+ "iopub.status.busy": "2022-04-18T01:50:13.231258Z",
598
+ "iopub.status.idle": "2022-04-18T01:50:13.233058Z",
599
+ "shell.execute_reply": "2022-04-18T01:50:13.233434Z",
600
+ "shell.execute_reply.started": "2022-04-16T12:24:19.850932Z"
601
+ },
602
+ "papermill": {
603
+ "duration": 0.075455,
604
+ "end_time": "2022-04-18T01:50:13.233582",
605
+ "exception": false,
606
+ "start_time": "2022-04-18T01:50:13.158127",
607
+ "status": "completed"
608
+ },
609
+ "tags": []
610
+ },
611
+ "outputs": [],
612
+ "source": [
613
+ "PAD_TOKEN_ID = 0\n",
614
+ "\n",
615
+ "def collate_fn(batch_inputs):\n",
616
+ " seq_lengths = [i[\"input_ids\"].size(1) for i in batch_inputs]\n",
617
+ " max_length = max(seq_lengths)\n",
618
+ " input_ids = torch.full((len(batch_inputs), max_length), PAD_TOKEN_ID, dtype=torch.long)\n",
619
+ " attention_mask = torch.full((len(batch_inputs), max_length), 0, dtype=torch.long)\n",
620
+ " for idx, inputs in enumerate(batch_inputs):\n",
621
+ " input_ids[idx, :seq_lengths[idx]] = inputs[\"input_ids\"]\n",
622
+ " attention_mask[idx, :seq_lengths[idx]] = 1\n",
623
+ " labels = input_ids.clone()\n",
624
+ " labels[labels == PAD_TOKEN_ID] = -100\n",
625
+ " return {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"labels\": labels}"
626
+ ]
627
+ },
628
+ {
629
+ "cell_type": "code",
630
+ "execution_count": 11,
631
+ "id": "193e7672",
632
+ "metadata": {
633
+ "execution": {
634
+ "iopub.execute_input": "2022-04-18T01:50:13.312349Z",
635
+ "iopub.status.busy": "2022-04-18T01:50:13.308720Z",
636
+ "iopub.status.idle": "2022-04-18T01:50:16.181794Z",
637
+ "shell.execute_reply": "2022-04-18T01:50:16.182874Z",
638
+ "shell.execute_reply.started": "2022-04-16T12:33:23.688559Z"
639
+ },
640
+ "papermill": {
641
+ "duration": 2.914467,
642
+ "end_time": "2022-04-18T01:50:16.183073",
643
+ "exception": false,
644
+ "start_time": "2022-04-18T01:50:13.268606",
645
+ "status": "completed"
646
+ },
647
+ "tags": []
648
+ },
649
+ "outputs": [
650
+ {
651
+ "name": "stdout",
652
+ "output_type": "stream",
653
+ "text": [
654
+ "Number of trainable parameters: 50873088\n"
655
+ ]
656
+ }
657
+ ],
658
+ "source": [
659
+ "config = GPT2Config(vocab_size = len(tokenizer.vocab),\n",
660
+ " n_positions = max_lengths,\n",
661
+ " n_embd = 768,\n",
662
+ " n_layer = 6,\n",
663
+ " n_head = 12,\n",
664
+ " eos_token_id=tokenizer.vocab[\"<eos>\"],\n",
665
+ " bos_token_id=tokenizer.vocab[\"<bos>\"])\n",
666
+ "model = GPT2LMHeadModel(config)\n",
667
+ "num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
668
+ "print(f\"Number of trainable parameters: {num_parameters}\")"
669
+ ]
670
+ },
671
+ {
672
+ "cell_type": "code",
673
+ "execution_count": 12,
674
+ "id": "484c0fc2",
675
+ "metadata": {
676
+ "execution": {
677
+ "iopub.execute_input": "2022-04-18T01:50:16.302344Z",
678
+ "iopub.status.busy": "2022-04-18T01:50:16.301561Z",
679
+ "iopub.status.idle": "2022-04-18T01:50:21.013819Z",
680
+ "shell.execute_reply": "2022-04-18T01:50:21.014253Z",
681
+ "shell.execute_reply.started": "2022-04-16T12:24:46.722086Z"
682
+ },
683
+ "papermill": {
684
+ "duration": 4.776549,
685
+ "end_time": "2022-04-18T01:50:21.014420",
686
+ "exception": false,
687
+ "start_time": "2022-04-18T01:50:16.237871",
688
+ "status": "completed"
689
+ },
690
+ "tags": []
691
+ },
692
+ "outputs": [
693
+ {
694
+ "name": "stderr",
695
+ "output_type": "stream",
696
+ "text": [
697
+ "Using amp half precision backend\n"
698
+ ]
699
+ }
700
+ ],
701
+ "source": [
702
+ "from transformers import EarlyStoppingCallback\n",
703
+ "training_args = TrainingArguments(\n",
704
+ " output_dir=\"results\",\n",
705
+ " eval_steps=2000,\n",
706
+ " save_steps=2000,\n",
707
+ " evaluation_strategy=\"steps\",\n",
708
+ " learning_rate=3e-4,\n",
709
+ " per_device_train_batch_size=32,\n",
710
+ " per_device_eval_batch_size=64,\n",
711
+ " save_total_limit=2,\n",
712
+ " num_train_epochs=8,\n",
713
+ " fp16=True,\n",
714
+ " report_to=\"none\",\n",
715
+ " dataloader_num_workers=2,\n",
716
+ " group_by_length=True,\n",
717
+ " metric_for_best_model = 'loss',\n",
718
+ " load_best_model_at_end=True\n",
719
+ ")\n",
720
+ "\n",
721
+ "trainer = Trainer(\n",
722
+ " model=model,\n",
723
+ " args=training_args,\n",
724
+ " train_dataset=train_dataset,\n",
725
+ " eval_dataset=val_dataset,\n",
726
+ " data_collator=collate_fn,\n",
727
+ " callbacks = [EarlyStoppingCallback(early_stopping_patience=1)]\n",
728
+ ")"
729
+ ]
730
+ },
731
+ {
732
+ "cell_type": "code",
733
+ "execution_count": 13,
734
+ "id": "fbc93ddf",
735
+ "metadata": {
736
+ "execution": {
737
+ "iopub.execute_input": "2022-04-18T01:50:21.089679Z",
738
+ "iopub.status.busy": "2022-04-18T01:50:21.089153Z",
739
+ "iopub.status.idle": "2022-04-18T05:43:12.456180Z",
740
+ "shell.execute_reply": "2022-04-18T05:43:12.455654Z",
741
+ "shell.execute_reply.started": "2022-04-16T12:25:06.616641Z"
742
+ },
743
+ "papermill": {
744
+ "duration": 13971.40658,
745
+ "end_time": "2022-04-18T05:43:12.456310",
746
+ "exception": false,
747
+ "start_time": "2022-04-18T01:50:21.049730",
748
+ "status": "completed"
749
+ },
750
+ "tags": []
751
+ },
752
+ "outputs": [
753
+ {
754
+ "name": "stderr",
755
+ "output_type": "stream",
756
+ "text": [
757
+ "/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
758
+ " FutureWarning,\n",
759
+ "***** Running training *****\n",
760
+ " Num examples = 605320\n",
761
+ " Num Epochs = 8\n",
762
+ " Instantaneous batch size per device = 32\n",
763
+ " Total train batch size (w. parallel, distributed & accumulation) = 32\n",
764
+ " Gradient Accumulation steps = 1\n",
765
+ " Total optimization steps = 151336\n"
766
+ ]
767
+ },
768
+ {
769
+ "data": {
770
+ "text/html": [
771
+ "\n",
772
+ " <div>\n",
773
+ " \n",
774
+ " <progress value='58000' max='151336' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
775
+ " [ 58000/151336 3:52:48 < 6:14:39, 4.15 it/s, Epoch 3/8]\n",
776
+ " </div>\n",
777
+ " <table border=\"1\" class=\"dataframe\">\n",
778
+ " <thead>\n",
779
+ " <tr style=\"text-align: left;\">\n",
780
+ " <th>Step</th>\n",
781
+ " <th>Training Loss</th>\n",
782
+ " <th>Validation Loss</th>\n",
783
+ " </tr>\n",
784
+ " </thead>\n",
785
+ " <tbody>\n",
786
+ " <tr>\n",
787
+ " <td>2000</td>\n",
788
+ " <td>4.367700</td>\n",
789
+ " <td>4.235631</td>\n",
790
+ " </tr>\n",
791
+ " <tr>\n",
792
+ " <td>4000</td>\n",
793
+ " <td>3.953300</td>\n",
794
+ " <td>3.883913</td>\n",
795
+ " </tr>\n",
796
+ " <tr>\n",
797
+ " <td>6000</td>\n",
798
+ " <td>3.790700</td>\n",
799
+ " <td>3.730361</td>\n",
800
+ " </tr>\n",
801
+ " <tr>\n",
802
+ " <td>8000</td>\n",
803
+ " <td>3.699500</td>\n",
804
+ " <td>3.639758</td>\n",
805
+ " </tr>\n",
806
+ " <tr>\n",
807
+ " <td>10000</td>\n",
808
+ " <td>3.626500</td>\n",
809
+ " <td>3.581570</td>\n",
810
+ " </tr>\n",
811
+ " <tr>\n",
812
+ " <td>12000</td>\n",
813
+ " <td>3.575800</td>\n",
814
+ " <td>3.529477</td>\n",
815
+ " </tr>\n",
816
+ " <tr>\n",
817
+ " <td>14000</td>\n",
818
+ " <td>3.539500</td>\n",
819
+ " <td>3.490788</td>\n",
820
+ " </tr>\n",
821
+ " <tr>\n",
822
+ " <td>16000</td>\n",
823
+ " <td>3.506100</td>\n",
824
+ " <td>3.457211</td>\n",
825
+ " </tr>\n",
826
+ " <tr>\n",
827
+ " <td>18000</td>\n",
828
+ " <td>3.471100</td>\n",
829
+ " <td>3.427910</td>\n",
830
+ " </tr>\n",
831
+ " <tr>\n",
832
+ " <td>20000</td>\n",
833
+ " <td>3.411700</td>\n",
834
+ " <td>3.404946</td>\n",
835
+ " </tr>\n",
836
+ " <tr>\n",
837
+ " <td>22000</td>\n",
838
+ " <td>3.388500</td>\n",
839
+ " <td>3.384355</td>\n",
840
+ " </tr>\n",
841
+ " <tr>\n",
842
+ " <td>24000</td>\n",
843
+ " <td>3.384500</td>\n",
844
+ " <td>3.362393</td>\n",
845
+ " </tr>\n",
846
+ " <tr>\n",
847
+ " <td>26000</td>\n",
848
+ " <td>3.363900</td>\n",
849
+ " <td>3.345612</td>\n",
850
+ " </tr>\n",
851
+ " <tr>\n",
852
+ " <td>28000</td>\n",
853
+ " <td>3.350600</td>\n",
854
+ " <td>3.330873</td>\n",
855
+ " </tr>\n",
856
+ " <tr>\n",
857
+ " <td>30000</td>\n",
858
+ " <td>3.339300</td>\n",
859
+ " <td>3.316820</td>\n",
860
+ " </tr>\n",
861
+ " <tr>\n",
862
+ " <td>32000</td>\n",
863
+ " <td>3.320600</td>\n",
864
+ " <td>3.303108</td>\n",
865
+ " </tr>\n",
866
+ " <tr>\n",
867
+ " <td>34000</td>\n",
868
+ " <td>3.316600</td>\n",
869
+ " <td>3.286899</td>\n",
870
+ " </tr>\n",
871
+ " <tr>\n",
872
+ " <td>36000</td>\n",
873
+ " <td>3.312900</td>\n",
874
+ " <td>3.277738</td>\n",
875
+ " </tr>\n",
876
+ " <tr>\n",
877
+ " <td>38000</td>\n",
878
+ " <td>3.272500</td>\n",
879
+ " <td>3.271317</td>\n",
880
+ " </tr>\n",
881
+ " <tr>\n",
882
+ " <td>40000</td>\n",
883
+ " <td>3.228100</td>\n",
884
+ " <td>3.260200</td>\n",
885
+ " </tr>\n",
886
+ " <tr>\n",
887
+ " <td>42000</td>\n",
888
+ " <td>3.232000</td>\n",
889
+ " <td>3.252335</td>\n",
890
+ " </tr>\n",
891
+ " <tr>\n",
892
+ " <td>44000</td>\n",
893
+ " <td>3.220500</td>\n",
894
+ " <td>3.247865</td>\n",
895
+ " </tr>\n",
896
+ " <tr>\n",
897
+ " <td>46000</td>\n",
898
+ " <td>3.219700</td>\n",
899
+ " <td>3.236358</td>\n",
900
+ " </tr>\n",
901
+ " <tr>\n",
902
+ " <td>48000</td>\n",
903
+ " <td>3.218000</td>\n",
904
+ " <td>3.228396</td>\n",
905
+ " </tr>\n",
906
+ " <tr>\n",
907
+ " <td>50000</td>\n",
908
+ " <td>3.214900</td>\n",
909
+ " <td>3.219474</td>\n",
910
+ " </tr>\n",
911
+ " <tr>\n",
912
+ " <td>52000</td>\n",
913
+ " <td>3.207100</td>\n",
914
+ " <td>3.213028</td>\n",
915
+ " </tr>\n",
916
+ " <tr>\n",
917
+ " <td>54000</td>\n",
918
+ " <td>3.206800</td>\n",
919
+ " <td>3.206626</td>\n",
920
+ " </tr>\n",
921
+ " <tr>\n",
922
+ " <td>56000</td>\n",
923
+ " <td>3.196200</td>\n",
924
+ " <td>3.197654</td>\n",
925
+ " </tr>\n",
926
+ " <tr>\n",
927
+ " <td>58000</td>\n",
928
+ " <td>3.125000</td>\n",
929
+ " <td>3.197687</td>\n",
930
+ " </tr>\n",
931
+ " </tbody>\n",
932
+ "</table><p>"
933
+ ],
934
+ "text/plain": [
935
+ "<IPython.core.display.HTML object>"
936
+ ]
937
+ },
938
+ "metadata": {},
939
+ "output_type": "display_data"
940
+ },
941
+ {
942
+ "name": "stderr",
943
+ "output_type": "stream",
944
+ "text": [
945
+ "***** Running Evaluation *****\n",
946
+ " Num examples = 12354\n",
947
+ " Batch size = 64\n",
948
+ "Saving model checkpoint to results/checkpoint-2000\n",
949
+ "Configuration saved in results/checkpoint-2000/config.json\n",
950
+ "Model weights saved in results/checkpoint-2000/pytorch_model.bin\n",
951
+ "***** Running Evaluation *****\n",
952
+ " Num examples = 12354\n",
953
+ " Batch size = 64\n",
954
+ "Saving model checkpoint to results/checkpoint-4000\n",
955
+ "Configuration saved in results/checkpoint-4000/config.json\n",
956
+ "Model weights saved in results/checkpoint-4000/pytorch_model.bin\n",
957
+ "***** Running Evaluation *****\n",
958
+ " Num examples = 12354\n",
959
+ " Batch size = 64\n",
960
+ "Saving model checkpoint to results/checkpoint-6000\n",
961
+ "Configuration saved in results/checkpoint-6000/config.json\n",
962
+ "Model weights saved in results/checkpoint-6000/pytorch_model.bin\n",
963
+ "Deleting older checkpoint [results/checkpoint-2000] due to args.save_total_limit\n",
964
+ "***** Running Evaluation *****\n",
965
+ " Num examples = 12354\n",
966
+ " Batch size = 64\n",
967
+ "Saving model checkpoint to results/checkpoint-8000\n",
968
+ "Configuration saved in results/checkpoint-8000/config.json\n",
969
+ "Model weights saved in results/checkpoint-8000/pytorch_model.bin\n",
970
+ "Deleting older checkpoint [results/checkpoint-4000] due to args.save_total_limit\n",
971
+ "***** Running Evaluation *****\n",
972
+ " Num examples = 12354\n",
973
+ " Batch size = 64\n",
974
+ "Saving model checkpoint to results/checkpoint-10000\n",
975
+ "Configuration saved in results/checkpoint-10000/config.json\n",
976
+ "Model weights saved in results/checkpoint-10000/pytorch_model.bin\n",
977
+ "Deleting older checkpoint [results/checkpoint-6000] due to args.save_total_limit\n",
978
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
979
+ " args.max_grad_norm,\n",
980
+ "***** Running Evaluation *****\n",
981
+ " Num examples = 12354\n",
982
+ " Batch size = 64\n",
983
+ "Saving model checkpoint to results/checkpoint-12000\n",
984
+ "Configuration saved in results/checkpoint-12000/config.json\n",
985
+ "Model weights saved in results/checkpoint-12000/pytorch_model.bin\n",
986
+ "Deleting older checkpoint [results/checkpoint-8000] due to args.save_total_limit\n",
987
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
988
+ " args.max_grad_norm,\n",
989
+ "***** Running Evaluation *****\n",
990
+ " Num examples = 12354\n",
991
+ " Batch size = 64\n",
992
+ "Saving model checkpoint to results/checkpoint-14000\n",
993
+ "Configuration saved in results/checkpoint-14000/config.json\n",
994
+ "Model weights saved in results/checkpoint-14000/pytorch_model.bin\n",
995
+ "Deleting older checkpoint [results/checkpoint-10000] due to args.save_total_limit\n",
996
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
997
+ " args.max_grad_norm,\n",
998
+ "***** Running Evaluation *****\n",
999
+ " Num examples = 12354\n",
1000
+ " Batch size = 64\n",
1001
+ "Saving model checkpoint to results/checkpoint-16000\n",
1002
+ "Configuration saved in results/checkpoint-16000/config.json\n",
1003
+ "Model weights saved in results/checkpoint-16000/pytorch_model.bin\n",
1004
+ "Deleting older checkpoint [results/checkpoint-12000] due to args.save_total_limit\n",
1005
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1006
+ " args.max_grad_norm,\n",
1007
+ "***** Running Evaluation *****\n",
1008
+ " Num examples = 12354\n",
1009
+ " Batch size = 64\n",
1010
+ "Saving model checkpoint to results/checkpoint-18000\n",
1011
+ "Configuration saved in results/checkpoint-18000/config.json\n",
1012
+ "Model weights saved in results/checkpoint-18000/pytorch_model.bin\n",
1013
+ "Deleting older checkpoint [results/checkpoint-14000] due to args.save_total_limit\n",
1014
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1015
+ " args.max_grad_norm,\n",
1016
+ "***** Running Evaluation *****\n",
1017
+ " Num examples = 12354\n",
1018
+ " Batch size = 64\n",
1019
+ "Saving model checkpoint to results/checkpoint-20000\n",
1020
+ "Configuration saved in results/checkpoint-20000/config.json\n",
1021
+ "Model weights saved in results/checkpoint-20000/pytorch_model.bin\n",
1022
+ "Deleting older checkpoint [results/checkpoint-16000] due to args.save_total_limit\n",
1023
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1024
+ " args.max_grad_norm,\n",
1025
+ "***** Running Evaluation *****\n",
1026
+ " Num examples = 12354\n",
1027
+ " Batch size = 64\n",
1028
+ "Saving model checkpoint to results/checkpoint-22000\n",
1029
+ "Configuration saved in results/checkpoint-22000/config.json\n",
1030
+ "Model weights saved in results/checkpoint-22000/pytorch_model.bin\n",
1031
+ "Deleting older checkpoint [results/checkpoint-18000] due to args.save_total_limit\n",
1032
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1033
+ " args.max_grad_norm,\n",
1034
+ "***** Running Evaluation *****\n",
1035
+ " Num examples = 12354\n",
1036
+ " Batch size = 64\n",
1037
+ "Saving model checkpoint to results/checkpoint-24000\n",
1038
+ "Configuration saved in results/checkpoint-24000/config.json\n",
1039
+ "Model weights saved in results/checkpoint-24000/pytorch_model.bin\n",
1040
+ "Deleting older checkpoint [results/checkpoint-20000] due to args.save_total_limit\n",
1041
+ "***** Running Evaluation *****\n",
1042
+ " Num examples = 12354\n",
1043
+ " Batch size = 64\n",
1044
+ "Saving model checkpoint to results/checkpoint-26000\n",
1045
+ "Configuration saved in results/checkpoint-26000/config.json\n",
1046
+ "Model weights saved in results/checkpoint-26000/pytorch_model.bin\n",
1047
+ "Deleting older checkpoint [results/checkpoint-22000] due to args.save_total_limit\n",
1048
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1049
+ " args.max_grad_norm,\n",
1050
+ "***** Running Evaluation *****\n",
1051
+ " Num examples = 12354\n",
1052
+ " Batch size = 64\n",
1053
+ "Saving model checkpoint to results/checkpoint-28000\n",
1054
+ "Configuration saved in results/checkpoint-28000/config.json\n",
1055
+ "Model weights saved in results/checkpoint-28000/pytorch_model.bin\n",
1056
+ "Deleting older checkpoint [results/checkpoint-24000] due to args.save_total_limit\n",
1057
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1058
+ " args.max_grad_norm,\n",
1059
+ "***** Running Evaluation *****\n",
1060
+ " Num examples = 12354\n",
1061
+ " Batch size = 64\n",
1062
+ "Saving model checkpoint to results/checkpoint-30000\n",
1063
+ "Configuration saved in results/checkpoint-30000/config.json\n",
1064
+ "Model weights saved in results/checkpoint-30000/pytorch_model.bin\n",
1065
+ "Deleting older checkpoint [results/checkpoint-26000] due to args.save_total_limit\n",
1066
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1067
+ " args.max_grad_norm,\n",
1068
+ "***** Running Evaluation *****\n",
1069
+ " Num examples = 12354\n",
1070
+ " Batch size = 64\n",
1071
+ "Saving model checkpoint to results/checkpoint-32000\n",
1072
+ "Configuration saved in results/checkpoint-32000/config.json\n",
1073
+ "Model weights saved in results/checkpoint-32000/pytorch_model.bin\n",
1074
+ "Deleting older checkpoint [results/checkpoint-28000] due to args.save_total_limit\n",
1075
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1076
+ " args.max_grad_norm,\n",
1077
+ "***** Running Evaluation *****\n",
1078
+ " Num examples = 12354\n",
1079
+ " Batch size = 64\n",
1080
+ "Saving model checkpoint to results/checkpoint-34000\n",
1081
+ "Configuration saved in results/checkpoint-34000/config.json\n",
1082
+ "Model weights saved in results/checkpoint-34000/pytorch_model.bin\n",
1083
+ "Deleting older checkpoint [results/checkpoint-30000] due to args.save_total_limit\n",
1084
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1085
+ " args.max_grad_norm,\n",
1086
+ "***** Running Evaluation *****\n",
1087
+ " Num examples = 12354\n",
1088
+ " Batch size = 64\n",
1089
+ "Saving model checkpoint to results/checkpoint-36000\n",
1090
+ "Configuration saved in results/checkpoint-36000/config.json\n",
1091
+ "Model weights saved in results/checkpoint-36000/pytorch_model.bin\n",
1092
+ "Deleting older checkpoint [results/checkpoint-32000] due to args.save_total_limit\n",
1093
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1094
+ " args.max_grad_norm,\n",
1095
+ "***** Running Evaluation *****\n",
1096
+ " Num examples = 12354\n",
1097
+ " Batch size = 64\n",
1098
+ "Saving model checkpoint to results/checkpoint-38000\n",
1099
+ "Configuration saved in results/checkpoint-38000/config.json\n",
1100
+ "Model weights saved in results/checkpoint-38000/pytorch_model.bin\n",
1101
+ "Deleting older checkpoint [results/checkpoint-34000] due to args.save_total_limit\n",
1102
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1103
+ " args.max_grad_norm,\n",
1104
+ "***** Running Evaluation *****\n",
1105
+ " Num examples = 12354\n",
1106
+ " Batch size = 64\n",
1107
+ "Saving model checkpoint to results/checkpoint-40000\n",
1108
+ "Configuration saved in results/checkpoint-40000/config.json\n",
1109
+ "Model weights saved in results/checkpoint-40000/pytorch_model.bin\n",
1110
+ "Deleting older checkpoint [results/checkpoint-36000] due to args.save_total_limit\n",
1111
+ "***** Running Evaluation *****\n",
1112
+ " Num examples = 12354\n",
1113
+ " Batch size = 64\n",
1114
+ "Saving model checkpoint to results/checkpoint-42000\n",
1115
+ "Configuration saved in results/checkpoint-42000/config.json\n",
1116
+ "Model weights saved in results/checkpoint-42000/pytorch_model.bin\n",
1117
+ "Deleting older checkpoint [results/checkpoint-38000] due to args.save_total_limit\n",
1118
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1119
+ " args.max_grad_norm,\n",
1120
+ "***** Running Evaluation *****\n",
1121
+ " Num examples = 12354\n",
1122
+ " Batch size = 64\n",
1123
+ "Saving model checkpoint to results/checkpoint-44000\n",
1124
+ "Configuration saved in results/checkpoint-44000/config.json\n",
1125
+ "Model weights saved in results/checkpoint-44000/pytorch_model.bin\n",
1126
+ "Deleting older checkpoint [results/checkpoint-40000] due to args.save_total_limit\n",
1127
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1128
+ " args.max_grad_norm,\n",
1129
+ "***** Running Evaluation *****\n",
1130
+ " Num examples = 12354\n",
1131
+ " Batch size = 64\n",
1132
+ "Saving model checkpoint to results/checkpoint-46000\n",
1133
+ "Configuration saved in results/checkpoint-46000/config.json\n",
1134
+ "Model weights saved in results/checkpoint-46000/pytorch_model.bin\n",
1135
+ "Deleting older checkpoint [results/checkpoint-42000] due to args.save_total_limit\n",
1136
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1137
+ " args.max_grad_norm,\n",
1138
+ "***** Running Evaluation *****\n",
1139
+ " Num examples = 12354\n",
1140
+ " Batch size = 64\n",
1141
+ "Saving model checkpoint to results/checkpoint-48000\n",
1142
+ "Configuration saved in results/checkpoint-48000/config.json\n",
1143
+ "Model weights saved in results/checkpoint-48000/pytorch_model.bin\n",
1144
+ "Deleting older checkpoint [results/checkpoint-44000] due to args.save_total_limit\n",
1145
+ "***** Running Evaluation *****\n",
1146
+ " Num examples = 12354\n",
1147
+ " Batch size = 64\n",
1148
+ "Saving model checkpoint to results/checkpoint-50000\n",
1149
+ "Configuration saved in results/checkpoint-50000/config.json\n",
1150
+ "Model weights saved in results/checkpoint-50000/pytorch_model.bin\n",
1151
+ "Deleting older checkpoint [results/checkpoint-46000] due to args.save_total_limit\n",
1152
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1153
+ " args.max_grad_norm,\n",
1154
+ "***** Running Evaluation *****\n",
1155
+ " Num examples = 12354\n",
1156
+ " Batch size = 64\n",
1157
+ "Saving model checkpoint to results/checkpoint-52000\n",
1158
+ "Configuration saved in results/checkpoint-52000/config.json\n",
1159
+ "Model weights saved in results/checkpoint-52000/pytorch_model.bin\n",
1160
+ "Deleting older checkpoint [results/checkpoint-48000] due to args.save_total_limit\n",
1161
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1162
+ " args.max_grad_norm,\n",
1163
+ "***** Running Evaluation *****\n",
1164
+ " Num examples = 12354\n",
1165
+ " Batch size = 64\n",
1166
+ "Saving model checkpoint to results/checkpoint-54000\n",
1167
+ "Configuration saved in results/checkpoint-54000/config.json\n",
1168
+ "Model weights saved in results/checkpoint-54000/pytorch_model.bin\n",
1169
+ "Deleting older checkpoint [results/checkpoint-50000] due to args.save_total_limit\n",
1170
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1171
+ " args.max_grad_norm,\n",
1172
+ "***** Running Evaluation *****\n",
1173
+ " Num examples = 12354\n",
1174
+ " Batch size = 64\n",
1175
+ "Saving model checkpoint to results/checkpoint-56000\n",
1176
+ "Configuration saved in results/checkpoint-56000/config.json\n",
1177
+ "Model weights saved in results/checkpoint-56000/pytorch_model.bin\n",
1178
+ "Deleting older checkpoint [results/checkpoint-52000] due to args.save_total_limit\n",
1179
+ "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n",
1180
+ " args.max_grad_norm,\n",
1181
+ "***** Running Evaluation *****\n",
1182
+ " Num examples = 12354\n",
1183
+ " Batch size = 64\n",
1184
+ "Saving model checkpoint to results/checkpoint-58000\n",
1185
+ "Configuration saved in results/checkpoint-58000/config.json\n",
1186
+ "Model weights saved in results/checkpoint-58000/pytorch_model.bin\n",
1187
+ "Deleting older checkpoint [results/checkpoint-54000] due to args.save_total_limit\n",
1188
+ "\n",
1189
+ "\n",
1190
+ "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
1191
+ "\n",
1192
+ "\n",
1193
+ "Loading best model from results/checkpoint-56000 (score: 3.1976535320281982).\n"
1194
+ ]
1195
+ },
1196
+ {
1197
+ "data": {
1198
+ "text/plain": [
1199
+ "TrainOutput(global_step=58000, training_loss=3.448922660038389, metrics={'train_runtime': 13970.1599, 'train_samples_per_second': 346.636, 'train_steps_per_second': 10.833, 'total_flos': 5.124009885990912e+16, 'train_loss': 3.448922660038389, 'epoch': 3.07})"
1200
+ ]
1201
+ },
1202
+ "execution_count": 13,
1203
+ "metadata": {},
1204
+ "output_type": "execute_result"
1205
+ }
1206
+ ],
1207
+ "source": [
1208
+ "# n_embd = 768, n_layer = 12, n_head = 12, 58k steps, 93.4 M parameters, train loss 3.150600, val loss 3.163932\n",
1209
+ "# n_embd = 768, n_layer = 6, n_head = 12, steps, 50.9 M parameters, train loss , val loss \n",
1210
+ "# n_embd = 256, n_layer = 4, n_head = 8, steps, 5.94M parameters, train loss 3.374200, val loss 3.339147\n",
1211
+ "# n_embd = 128, n_layer = 2, n_head = 4, 54k steps, 1.78M parameters, train loss 3.819500, val loss 3.694196\n",
1212
+ "trainer.train()"
1213
+ ]
1214
+ },
1215
+ {
1216
+ "cell_type": "code",
1217
+ "execution_count": 14,
1218
+ "id": "127bea6d",
1219
+ "metadata": {
1220
+ "execution": {
1221
+ "iopub.execute_input": "2022-04-18T05:43:12.684274Z",
1222
+ "iopub.status.busy": "2022-04-18T05:43:12.683525Z",
1223
+ "iopub.status.idle": "2022-04-18T05:43:12.685531Z",
1224
+ "shell.execute_reply": "2022-04-18T05:43:12.685926Z",
1225
+ "shell.execute_reply.started": "2022-04-16T12:29:27.832584Z"
1226
+ },
1227
+ "papermill": {
1228
+ "duration": 0.122187,
1229
+ "end_time": "2022-04-18T05:43:12.686065",
1230
+ "exception": false,
1231
+ "start_time": "2022-04-18T05:43:12.563878",
1232
+ "status": "completed"
1233
+ },
1234
+ "tags": []
1235
+ },
1236
+ "outputs": [],
1237
+ "source": [
1238
+ "def generation(prompt, length):\n",
1239
+ " tokens = tokenizer(prompt=str(length) + prompt)\n",
1240
+ " output_ids = model.generate(tokens['input_ids'].to(\"cuda\"),\n",
1241
+ " do_sample=True, \n",
1242
+ " top_k=50,\n",
1243
+ " top_p=0.95,\n",
1244
+ " max_length=100)\n",
1245
+ " decoded_verse = tokenizer.decode(output_ids)[5:]\n",
1246
+ " return decoded_verse"
1247
+ ]
1248
+ },
1249
+ {
1250
+ "cell_type": "code",
1251
+ "execution_count": 15,
1252
+ "id": "e7f22169",
1253
+ "metadata": {
1254
+ "execution": {
1255
+ "iopub.execute_input": "2022-04-18T05:43:12.909172Z",
1256
+ "iopub.status.busy": "2022-04-18T05:43:12.908333Z",
1257
+ "iopub.status.idle": "2022-04-18T05:43:13.116636Z",
1258
+ "shell.execute_reply": "2022-04-18T05:43:13.117086Z",
1259
+ "shell.execute_reply.started": "2022-04-16T12:30:03.02288Z"
1260
+ },
1261
+ "papermill": {
1262
+ "duration": 0.325253,
1263
+ "end_time": "2022-04-18T05:43:13.117240",
1264
+ "exception": false,
1265
+ "start_time": "2022-04-18T05:43:12.791987",
1266
+ "status": "completed"
1267
+ },
1268
+ "tags": []
1269
+ },
1270
+ "outputs": [
1271
+ {
1272
+ "name": "stderr",
1273
+ "output_type": "stream",
1274
+ "text": [
1275
+ "Setting `pad_token_id` to `eos_token_id`:10741 for open-end generation.\n"
1276
+ ]
1277
+ },
1278
+ {
1279
+ "data": {
1280
+ "text/plain": [
1281
+ "'花明水在溪,好在波上得。月光忽在溪,圆明了不蚀。'"
1282
+ ]
1283
+ },
1284
+ "execution_count": 15,
1285
+ "metadata": {},
1286
+ "output_type": "execute_result"
1287
+ }
1288
+ ],
1289
+ "source": [
1290
+ "generation(\"花好月圆\", length=5)"
1291
+ ]
1292
+ },
1293
+ {
1294
+ "cell_type": "code",
1295
+ "execution_count": 16,
1296
+ "id": "536bd1dd",
1297
+ "metadata": {
1298
+ "execution": {
1299
+ "iopub.execute_input": "2022-04-18T05:43:13.336560Z",
1300
+ "iopub.status.busy": "2022-04-18T05:43:13.335672Z",
1301
+ "iopub.status.idle": "2022-04-18T05:43:13.521122Z",
1302
+ "shell.execute_reply": "2022-04-18T05:43:13.521536Z",
1303
+ "shell.execute_reply.started": "2022-04-16T12:29:42.949166Z"
1304
+ },
1305
+ "papermill": {
1306
+ "duration": 0.298044,
1307
+ "end_time": "2022-04-18T05:43:13.521677",
1308
+ "exception": false,
1309
+ "start_time": "2022-04-18T05:43:13.223633",
1310
+ "status": "completed"
1311
+ },
1312
+ "tags": []
1313
+ },
1314
+ "outputs": [
1315
+ {
1316
+ "name": "stderr",
1317
+ "output_type": "stream",
1318
+ "text": [
1319
+ "Setting `pad_token_id` to `eos_token_id`:10741 for open-end generation.\n"
1320
+ ]
1321
+ },
1322
+ {
1323
+ "data": {
1324
+ "text/plain": [
1325
+ "'下山来访小园中,楼阁清幽景物同。吃吃僧斋分数宿,饭松茶灶有馀功。'"
1326
+ ]
1327
+ },
1328
+ "execution_count": 16,
1329
+ "metadata": {},
1330
+ "output_type": "execute_result"
1331
+ }
1332
+ ],
1333
+ "source": [
1334
+ "generation(\"下楼吃饭\", length=7)"
1335
+ ]
1336
+ },
1337
+ {
1338
+ "cell_type": "code",
1339
+ "execution_count": 17,
1340
+ "id": "dd75f0be",
1341
+ "metadata": {
1342
+ "execution": {
1343
+ "iopub.execute_input": "2022-04-18T05:43:13.745410Z",
1344
+ "iopub.status.busy": "2022-04-18T05:43:13.744513Z",
1345
+ "iopub.status.idle": "2022-04-18T05:43:14.123442Z",
1346
+ "shell.execute_reply": "2022-04-18T05:43:14.123883Z",
1347
+ "shell.execute_reply.started": "2022-04-16T12:29:44.683058Z"
1348
+ },
1349
+ "papermill": {
1350
+ "duration": 0.490314,
1351
+ "end_time": "2022-04-18T05:43:14.124043",
1352
+ "exception": false,
1353
+ "start_time": "2022-04-18T05:43:13.633729",
1354
+ "status": "completed"
1355
+ },
1356
+ "tags": []
1357
+ },
1358
+ "outputs": [
1359
+ {
1360
+ "name": "stderr",
1361
+ "output_type": "stream",
1362
+ "text": [
1363
+ "Setting `pad_token_id` to `eos_token_id`:10741 for open-end generation.\n"
1364
+ ]
1365
+ },
1366
+ {
1367
+ "data": {
1368
+ "text/plain": [
1369
+ "'大深无坐今夕分明是别年,晚陪花下醉清眠。加餐我自能高咏,班列君应似谪仙。大地星河连太皞,深宵星斗下华躔。无言独向閒庭静,坐对西南又一天。'"
1370
+ ]
1371
+ },
1372
+ "execution_count": 17,
1373
+ "metadata": {},
1374
+ "output_type": "execute_result"
1375
+ }
1376
+ ],
1377
+ "source": [
1378
+ "generation(\"今晚加班\", length=7)"
1379
+ ]
1380
+ },
1381
+ {
1382
+ "cell_type": "code",
1383
+ "execution_count": 18,
1384
+ "id": "393331e4",
1385
+ "metadata": {
1386
+ "execution": {
1387
+ "iopub.execute_input": "2022-04-18T05:43:14.346788Z",
1388
+ "iopub.status.busy": "2022-04-18T05:43:14.345916Z",
1389
+ "iopub.status.idle": "2022-04-18T05:43:14.539457Z",
1390
+ "shell.execute_reply": "2022-04-18T05:43:14.539890Z",
1391
+ "shell.execute_reply.started": "2022-04-16T12:29:56.371973Z"
1392
+ },
1393
+ "papermill": {
1394
+ "duration": 0.307929,
1395
+ "end_time": "2022-04-18T05:43:14.540041",
1396
+ "exception": false,
1397
+ "start_time": "2022-04-18T05:43:14.232112",
1398
+ "status": "completed"
1399
+ },
1400
+ "tags": []
1401
+ },
1402
+ "outputs": [
1403
+ {
1404
+ "name": "stderr",
1405
+ "output_type": "stream",
1406
+ "text": [
1407
+ "Setting `pad_token_id` to `eos_token_id`:10741 for open-end generation.\n"
1408
+ ]
1409
+ },
1410
+ {
1411
+ "data": {
1412
+ "text/plain": [
1413
+ "'加餐未暇望天颜,班列群仙戏綵幡。内史赐花频赐宴,卷帘先为看朝元。'"
1414
+ ]
1415
+ },
1416
+ "execution_count": 18,
1417
+ "metadata": {},
1418
+ "output_type": "execute_result"
1419
+ }
1420
+ ],
1421
+ "source": [
1422
+ "generation(\"加班内卷\", length=7)"
1423
+ ]
1424
+ },
1425
+ {
1426
+ "cell_type": "code",
1427
+ "execution_count": 19,
1428
+ "id": "ea886add",
1429
+ "metadata": {
1430
+ "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
1431
+ "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
1432
+ "execution": {
1433
+ "iopub.execute_input": "2022-04-18T05:43:14.760813Z",
1434
+ "iopub.status.busy": "2022-04-18T05:43:14.759955Z",
1435
+ "iopub.status.idle": "2022-04-18T05:43:14.761716Z",
1436
+ "shell.execute_reply": "2022-04-18T05:43:14.762174Z"
1437
+ },
1438
+ "papermill": {
1439
+ "duration": 0.113971,
1440
+ "end_time": "2022-04-18T05:43:14.762305",
1441
+ "exception": false,
1442
+ "start_time": "2022-04-18T05:43:14.648334",
1443
+ "status": "completed"
1444
+ },
1445
+ "tags": []
1446
+ },
1447
+ "outputs": [],
1448
+ "source": [
1449
+ "# # This Python 3 environment comes with many helpful analytics libraries installed\n",
1450
+ "# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n",
1451
+ "# # For example, here's several helpful packages to load\n",
1452
+ "\n",
1453
+ "# import numpy as np # linear algebra\n",
1454
+ "# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
1455
+ "\n",
1456
+ "# # Input data files are available in the read-only \"../input/\" directory\n",
1457
+ "# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n",
1458
+ "\n",
1459
+ "# import os\n",
1460
+ "# for dirname, _, filenames in os.walk('/kaggle/input'):\n",
1461
+ "# for filename in filenames:\n",
1462
+ "# print(os.path.join(dirname, filename))\n",
1463
+ "\n",
1464
+ "# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n",
1465
+ "# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session"
1466
+ ]
1467
+ }
1468
+ ],
1469
+ "metadata": {
1470
+ "kernelspec": {
1471
+ "display_name": "Python 3 (ipykernel)",
1472
+ "language": "python",
1473
+ "name": "python3"
1474
+ },
1475
+ "language_info": {
1476
+ "codemirror_mode": {
1477
+ "name": "ipython",
1478
+ "version": 3
1479
+ },
1480
+ "file_extension": ".py",
1481
+ "mimetype": "text/x-python",
1482
+ "name": "python",
1483
+ "nbconvert_exporter": "python",
1484
+ "pygments_lexer": "ipython3",
1485
+ "version": "3.9.10"
1486
+ },
1487
+ "papermill": {
1488
+ "default_parameters": {},
1489
+ "duration": 14060.414143,
1490
+ "end_time": "2022-04-18T05:43:17.806051",
1491
+ "environment_variables": {},
1492
+ "exception": null,
1493
+ "input_path": "__notebook__.ipynb",
1494
+ "output_path": "__notebook__.ipynb",
1495
+ "parameters": {},
1496
+ "start_time": "2022-04-18T01:48:57.391908",
1497
+ "version": "2.3.3"
1498
+ }
1499
+ },
1500
+ "nbformat": 4,
1501
+ "nbformat_minor": 5
1502
+ }
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
saved_model/.DS_Store ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60aad9e84184acf739b198f050b82fbf5ae133be5b284e3b9d99c823c916b132
3
+ size 6148
saved_model/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6ef7db0cc72a8278be6c82773090292ee4be0957b0a75161241b817b010439e
3
+ size 748
saved_model/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad28f6bf811a2cde91efae753a8769a83461ef0e7c0ee740b584a671a9519f4a
3
+ size 203614109
saved_model/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f54f62ade9dfcf56bd08cc74bfa9ec22ef590bc9508a381e42d21dc95a89c2f
3
+ size 171436