ACMC
commited on
Commit
•
7e73556
0
Parent(s):
initial commit
Browse files- .gitattributes +35 -0
- .gitignore +2 -0
- README.md +12 -0
- app.py +204 -0
- requirements.txt +13 -0
- utils.py +342 -0
- validation.py +174 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*.jsonl
|
2 |
+
__pycache__/
|
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Whatsapp Chats Finetuning Formatter
|
3 |
+
emoji: 👀
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.20.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
from uuid import uuid4
|
3 |
+
import gradio as gr
|
4 |
+
import datasets
|
5 |
+
import json
|
6 |
+
import io
|
7 |
+
from utils import (
|
8 |
+
process_chat_file,
|
9 |
+
transform_conversations_dataset_into_training_examples,
|
10 |
+
)
|
11 |
+
from validation import (
|
12 |
+
check_format_errors,
|
13 |
+
check_token_counts,
|
14 |
+
estimate_cost,
|
15 |
+
get_distributions,
|
16 |
+
)
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
|
19 |
+
|
20 |
+
def convert_to_dataset(files, do_spelling_correction, progress):
|
21 |
+
modified_dataset = None
|
22 |
+
for file in progress.tqdm(files, desc="Processing files"):
|
23 |
+
if modified_dataset is None:
|
24 |
+
# First file
|
25 |
+
modified_dataset = process_chat_file(file, do_spelling_correction=do_spelling_correction)
|
26 |
+
else:
|
27 |
+
# Concatenate the datasets
|
28 |
+
this_file_dataset = process_chat_file(file, do_spelling_correction=do_spelling_correction)
|
29 |
+
modified_dataset = datasets.concatenate_datasets(
|
30 |
+
[modified_dataset, this_file_dataset]
|
31 |
+
)
|
32 |
+
return modified_dataset
|
33 |
+
|
34 |
+
|
35 |
+
def file_upload_callback(files, system_prompt, do_spelling_correction, validation_split, progress=gr.Progress()):
|
36 |
+
print(f"Processing {files}")
|
37 |
+
full_system_prompt = f"""You are a chatbot. Your goal is to simulate realistic, natural chat conversations as if you were me.
|
38 |
+
# Task
|
39 |
+
A participant can send multiple messages in a row, delimited by '\"', in the following schema:
|
40 |
+
{{string}}[]. Your answer always needs to be JSON compliant. Always start your answer with [\"
|
41 |
+
# Information about me
|
42 |
+
You should use the following information about me to answer:
|
43 |
+
{system_prompt}
|
44 |
+
# Example
|
45 |
+
[{{\"role\":\"user\",\"content\":\"[\"Hello!\",\"How are you?\"]\"}},{{\"role\":\"assistant\",\"content\":\"[\"Hi!\",\"I'm doing great.\",\"What about you?\"]\"}},{{\"role\":\"user\",\"content\":\"[\"I'm doing well.\",\"Have you been travelling?\"]\"}}]
|
46 |
+
Response:
|
47 |
+
[{{\"role\":\"assistant\",\"content\":\"[\"Yes, I've been to many places.\",\"I love travelling.\"]\"}}]"""
|
48 |
+
|
49 |
+
# Avoid using the full system prompt for now, as it is too long and increases the cost of the training
|
50 |
+
full_system_prompt = system_prompt
|
51 |
+
dataset = convert_to_dataset(files=files, progress=progress, do_spelling_correction=do_spelling_correction)
|
52 |
+
training_examples_ds = transform_conversations_dataset_into_training_examples(
|
53 |
+
conversations_ds=dataset, system_prompt=full_system_prompt
|
54 |
+
)
|
55 |
+
|
56 |
+
# Split into training and validation datasets (80% and 20%)
|
57 |
+
training_examples_ds = training_examples_ds.train_test_split(test_size=validation_split, seed=42)
|
58 |
+
training_examples_ds, validation_examples_ds = training_examples_ds["train"], training_examples_ds["test"]
|
59 |
+
|
60 |
+
format_errors = check_format_errors(training_examples_ds)
|
61 |
+
distributions = get_distributions(training_examples_ds)
|
62 |
+
cost_stats = estimate_cost(training_examples_ds)
|
63 |
+
|
64 |
+
stats = {
|
65 |
+
"Format Errors": format_errors,
|
66 |
+
"Number of examples missing system message": distributions["n_missing_system"],
|
67 |
+
"Number of examples missing user message": distributions["n_missing_user"],
|
68 |
+
"Cost Statistics": cost_stats,
|
69 |
+
}
|
70 |
+
|
71 |
+
fig_num_messages_distribution_plot = plt.figure()
|
72 |
+
num_messages_distribution_plot = plt.hist(distributions["n_messages"], bins=20)
|
73 |
+
|
74 |
+
fig_num_total_tokens_per_example_plot = plt.figure()
|
75 |
+
num_total_tokens_per_example_plot = plt.hist(distributions["convo_lens"], bins=20)
|
76 |
+
|
77 |
+
fig_num_assistant_tokens_per_example_plot = plt.figure()
|
78 |
+
num_assistant_tokens_per_example_plot = plt.hist(
|
79 |
+
distributions["assistant_message_lens"],
|
80 |
+
bins=20
|
81 |
+
)
|
82 |
+
|
83 |
+
# The DownloadFile component requires a path to the file, it can't accept a buffer to keep the file in memory.
|
84 |
+
# Therefore, we need to save the buffer to a file and then pass the path to the DownloadFile component.
|
85 |
+
# However, if different users are using the app at the same time, we need to make sure that the file is unique AND that no user can access the file of another user.
|
86 |
+
# We can use a UUID generator to create a unique file name.
|
87 |
+
uuid = str(uuid4())
|
88 |
+
file_path = f"training_examples_{uuid}.jsonl"
|
89 |
+
training_examples_ds.to_json(path_or_buf=file_path, force_ascii=False)
|
90 |
+
|
91 |
+
file_path_validation = f"validation_examples_{uuid}.jsonl"
|
92 |
+
validation_examples_ds.to_json(path_or_buf=file_path_validation, force_ascii=False)
|
93 |
+
|
94 |
+
return (
|
95 |
+
file_path,
|
96 |
+
gr.update(visible=True),
|
97 |
+
file_path_validation,
|
98 |
+
gr.update(visible=True),
|
99 |
+
stats,
|
100 |
+
fig_num_messages_distribution_plot,
|
101 |
+
fig_num_total_tokens_per_example_plot,
|
102 |
+
fig_num_assistant_tokens_per_example_plot
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
def remove_file_and_hide_button(file_path):
|
107 |
+
import os
|
108 |
+
|
109 |
+
try:
|
110 |
+
os.remove(file_path)
|
111 |
+
except Exception as e:
|
112 |
+
print(f"Error removing file {file_path}: {e}")
|
113 |
+
|
114 |
+
return gr.update(visible=False)
|
115 |
+
|
116 |
+
|
117 |
+
theme = gr.themes.Default(primary_hue="cyan", secondary_hue="fuchsia")
|
118 |
+
|
119 |
+
with gr.Blocks(theme=theme) as demo:
|
120 |
+
gr.Markdown(
|
121 |
+
"""
|
122 |
+
# WhatsApp Chat to Dataset Converter
|
123 |
+
Upload your WhatsApp chat files and convert them into a Dataset.
|
124 |
+
"""
|
125 |
+
)
|
126 |
+
gr.Markdown(
|
127 |
+
"""
|
128 |
+
## Instructions
|
129 |
+
1. Click on the "Upload WhatsApp Chat Files" button.
|
130 |
+
2. Select the WhatsApp chat files you want to convert.
|
131 |
+
3. Write a prompt about you to give context to the training examples.
|
132 |
+
4. Click on the "Submit" button.
|
133 |
+
5. Wait for the process to finish.
|
134 |
+
6. Download the generated training examples as a JSONL file.
|
135 |
+
7. Use the training examples to train your own model.
|
136 |
+
"""
|
137 |
+
)
|
138 |
+
|
139 |
+
input_files = gr.File(
|
140 |
+
label="Upload WhatsApp Chat Files",
|
141 |
+
type="filepath",
|
142 |
+
file_count="multiple",
|
143 |
+
file_types=["txt"],
|
144 |
+
)
|
145 |
+
|
146 |
+
system_prompt = gr.Textbox(
|
147 |
+
label="System Prompt",
|
148 |
+
placeholder="Background information about you.",
|
149 |
+
lines=5,
|
150 |
+
info="Enter the system prompt to be used for the training examples generation. This is the background information about you that will be used to generate the training examples.",
|
151 |
+
value="""Aldan is an AI researcher who loves to play around with AI systems, travelling and learning new things.""",
|
152 |
+
)
|
153 |
+
|
154 |
+
do_spelling_correction = gr.Checkbox(
|
155 |
+
label="Do Spelling Correction (English)",
|
156 |
+
info="Check this box if you want to perform spelling correction on the chat messages before generating the training examples.",
|
157 |
+
)
|
158 |
+
|
159 |
+
# Allow the user to choose the validation split size
|
160 |
+
validation_split = gr.Slider(
|
161 |
+
minimum=0.0,
|
162 |
+
maximum=0.5,
|
163 |
+
value=0.2,
|
164 |
+
interactive=True,
|
165 |
+
label="Validation Split",
|
166 |
+
info="Choose the percentage of the dataset to be used for validation. For example, if you choose 0.2, 20% of the dataset will be used for validation and 80% for training.",
|
167 |
+
)
|
168 |
+
|
169 |
+
submit = gr.Button(value="Submit", variant="primary")
|
170 |
+
|
171 |
+
output_file = gr.DownloadButton(label="Download Generated Training Examples", visible=False, variant="primary")
|
172 |
+
output_file_validation = gr.DownloadButton(label="Download Generated Validation Examples", visible=False, variant="secondary")
|
173 |
+
# output_example = gr.JSON(label="Example Training Example")
|
174 |
+
|
175 |
+
with gr.Group():
|
176 |
+
# Statistics about the dataset
|
177 |
+
gr.Markdown("## Statistics")
|
178 |
+
written_stats = gr.JSON()
|
179 |
+
num_messages_distribution_plot = gr.Plot(label="Number of Messages Distribution")
|
180 |
+
num_total_tokens_per_example_plot = gr.Plot(label="Total Number of Tokens per Example")
|
181 |
+
num_assistant_tokens_per_example_plot = gr.Plot(
|
182 |
+
label="Number of Assistant Tokens per Example"
|
183 |
+
)
|
184 |
+
|
185 |
+
submit.click(
|
186 |
+
file_upload_callback,
|
187 |
+
inputs=[input_files, system_prompt, do_spelling_correction, validation_split],
|
188 |
+
outputs=[
|
189 |
+
output_file,
|
190 |
+
output_file,
|
191 |
+
output_file_validation,
|
192 |
+
output_file_validation,
|
193 |
+
written_stats,
|
194 |
+
num_messages_distribution_plot,
|
195 |
+
num_total_tokens_per_example_plot,
|
196 |
+
num_assistant_tokens_per_example_plot,
|
197 |
+
]
|
198 |
+
)
|
199 |
+
|
200 |
+
output_file.click(remove_file_and_hide_button, inputs=[output_file], outputs=[output_file])
|
201 |
+
output_file_validation.click(remove_file_and_hide_button, inputs=[output_file_validation], outputs=[output_file_validation])
|
202 |
+
|
203 |
+
if __name__ == "__main__":
|
204 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
contextualSpellCheck==0.4.4
|
2 |
+
datasets==2.18.0
|
3 |
+
es-core-news-sm @ https://github.com/explosion/spacy-models/releases/download/es_core_news_sm-3.7.0/es_core_news_sm-3.7.0-py3-none-any.whl#sha256=61e6e5530941f5880166855f09f60d7e6ba79ec1e8e45f96244bdb1eb169eb1d
|
4 |
+
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
|
5 |
+
gradio==4.20.1
|
6 |
+
matplotlib==3.8.3
|
7 |
+
numpy==1.26.4
|
8 |
+
pandas==2.2.1
|
9 |
+
spacy==3.7.4
|
10 |
+
tiktoken==0.6.0
|
11 |
+
torch==2.2.1
|
12 |
+
transformers==4.38.2
|
13 |
+
pyspellchecker==0.8.1
|
utils.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datasets
|
2 |
+
import datetime
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
|
6 |
+
import re
|
7 |
+
|
8 |
+
exp = re.compile(
|
9 |
+
r"(?P<month>\d+)/(?P<day>\d+)/(?P<year>\d+), (?P<hour>\d+):(?P<minute>\d+) - (?P<contact_name>.+): (?P<message>.+)"
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
def process_line(example):
|
14 |
+
# The lines have this format: dd/mm/yy, hh:mm - <person>: <msg>
|
15 |
+
try:
|
16 |
+
groups = exp.match(example["text"]).groupdict()
|
17 |
+
timestamp = datetime.datetime(
|
18 |
+
int(groups["year"]),
|
19 |
+
int(groups["month"]),
|
20 |
+
int(groups["day"]),
|
21 |
+
int(groups["hour"]),
|
22 |
+
int(groups["minute"]),
|
23 |
+
).timestamp()
|
24 |
+
return {
|
25 |
+
"message": groups["message"],
|
26 |
+
"contact_name": groups["contact_name"],
|
27 |
+
"timestamp": timestamp,
|
28 |
+
}
|
29 |
+
except Exception as e:
|
30 |
+
print(e)
|
31 |
+
print(example["text"])
|
32 |
+
raise e
|
33 |
+
|
34 |
+
|
35 |
+
# %%
|
36 |
+
# Now, create message groups ('conversations')
|
37 |
+
# The idea is to group messages that are close in time
|
38 |
+
# We'll use a 240 minute threshold
|
39 |
+
MINUTES_THRESHOLD = 240
|
40 |
+
|
41 |
+
|
42 |
+
def group_messages(messages_iterable):
|
43 |
+
groups = []
|
44 |
+
current_group = [next(messages_iterable)]
|
45 |
+
for message in messages_iterable:
|
46 |
+
assert len(current_group) > 0 # We should never have an empty group
|
47 |
+
if (
|
48 |
+
message["timestamp"] - current_group[-1]["timestamp"]
|
49 |
+
< MINUTES_THRESHOLD * 60
|
50 |
+
):
|
51 |
+
current_group.append(message)
|
52 |
+
else:
|
53 |
+
groups.append(current_group)
|
54 |
+
current_group = [message]
|
55 |
+
groups.append(current_group)
|
56 |
+
return groups
|
57 |
+
|
58 |
+
|
59 |
+
def printable_conversation(conversation):
|
60 |
+
return "\n".join(
|
61 |
+
[f"{message['contact_name']}: {message['message']}" for message in conversation]
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
# %%
|
66 |
+
# Use spacy to spell check the messages
|
67 |
+
import spacy
|
68 |
+
import contextualSpellCheck
|
69 |
+
from spellchecker import SpellChecker
|
70 |
+
spell = SpellChecker()
|
71 |
+
#nlp = spacy.load("es_core_news_sm")
|
72 |
+
nlp = spacy.load("en_core_web_sm")
|
73 |
+
|
74 |
+
|
75 |
+
def spell_check_conversation(conversation):
|
76 |
+
for i, message in enumerate(conversation["conversations"]):
|
77 |
+
# Use SpaCy to get the words
|
78 |
+
words = spell.split_words(message["message"])
|
79 |
+
print(f"Words: {words}")
|
80 |
+
corrected_message = []
|
81 |
+
for word in words:
|
82 |
+
correction = spell.correction(word)
|
83 |
+
if (correction != None) and (correction != word):
|
84 |
+
print(f"Spell check: {word} -> {correction}")
|
85 |
+
corrected_message.append(correction)
|
86 |
+
else:
|
87 |
+
corrected_message.append(word)
|
88 |
+
|
89 |
+
print(f"Corrected message: {corrected_message}")
|
90 |
+
joined_message = " ".join(corrected_message)
|
91 |
+
conversation["conversations"][i]["message"] = joined_message
|
92 |
+
|
93 |
+
return conversation
|
94 |
+
|
95 |
+
|
96 |
+
def spell_check_conversation_spacy(conversation):
|
97 |
+
|
98 |
+
nlp.add_pipe(
|
99 |
+
"contextual spellchecker",
|
100 |
+
config={
|
101 |
+
"model_name": "bert-base-multilingual-uncased",
|
102 |
+
"max_edit_dist": 2,
|
103 |
+
},
|
104 |
+
)
|
105 |
+
docs = list(nlp.pipe([msg["message"] for msg in conversation["conversations"]]))
|
106 |
+
for i, doc in enumerate(docs):
|
107 |
+
if doc._.performed_spellCheck:
|
108 |
+
print(f"Spell checked: {doc.text} -> {doc._.outcome_spellCheck}")
|
109 |
+
conversation["conversations"][i]["message"] = doc._.outcome_spellCheck
|
110 |
+
|
111 |
+
return conversation
|
112 |
+
|
113 |
+
|
114 |
+
def remove_whatapp_annotations(conversation):
|
115 |
+
"""
|
116 |
+
Removes the following annotations from the messages:
|
117 |
+
- <This message was edited>
|
118 |
+
"""
|
119 |
+
for message in conversation["conversations"]:
|
120 |
+
message["message"] = re.sub(
|
121 |
+
r"<This message was edited>", "", message["message"]
|
122 |
+
)
|
123 |
+
return conversation
|
124 |
+
|
125 |
+
|
126 |
+
# %%
|
127 |
+
"""
|
128 |
+
Sometimes, people write concurrently in the same conversation. We'll try to detect that and reorder the messages.
|
129 |
+
For example, if we have a conversation like this:
|
130 |
+
A: Hi
|
131 |
+
A: How are you?
|
132 |
+
B: Hi
|
133 |
+
B: I'm fine, thanks
|
134 |
+
A: I'm fine too
|
135 |
+
We'll reorder it to:
|
136 |
+
A: Hi
|
137 |
+
B: Hi
|
138 |
+
A: How are you?
|
139 |
+
B: I'm fine, thanks
|
140 |
+
A: I'm fine too
|
141 |
+
|
142 |
+
To do it, we'll use MobileBERT with the next sentence prediction head. We'll use the first message as the first sentence, and the second message as the second sentence. If the model predicts that the second sentence is more likely to be the next sentence, we'll swap the messages.
|
143 |
+
"""
|
144 |
+
|
145 |
+
from transformers import AutoTokenizer, AutoModelForNextSentencePrediction
|
146 |
+
import torch
|
147 |
+
|
148 |
+
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
149 |
+
model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
|
150 |
+
if torch.cuda.is_available():
|
151 |
+
model.cuda()
|
152 |
+
|
153 |
+
|
154 |
+
def swap_messages_if_needed(message1, message2):
|
155 |
+
# If the messages have the same contact, we don't swap them
|
156 |
+
if message1["contact_name"] == message2["contact_name"]:
|
157 |
+
return message1, message2
|
158 |
+
# The timestamp must have a difference of less than 2 minutes. First, convert to datetime
|
159 |
+
datetime1 = datetime.datetime.fromtimestamp(message1["timestamp"])
|
160 |
+
datetime2 = datetime.datetime.fromtimestamp(message2["timestamp"])
|
161 |
+
if (datetime2 - datetime1).total_seconds() > 2 * 60:
|
162 |
+
return message1, message2
|
163 |
+
# If one of the messages has less than 3 words, we don't swap them
|
164 |
+
if len(message1["message"].split()) < 3 or len(message2["message"].split()) < 3:
|
165 |
+
return message1, message2
|
166 |
+
# We'll use the first message as the first sentence, and the second message as the second sentence
|
167 |
+
inputs = tokenizer(message1["message"], message2["message"], return_tensors="pt")
|
168 |
+
reverse_inputs = tokenizer(
|
169 |
+
message2["message"], message1["message"], return_tensors="pt"
|
170 |
+
)
|
171 |
+
# Join them in a single batch
|
172 |
+
joined_inputs = torch.cat([inputs["input_ids"], reverse_inputs["input_ids"]], dim=0)
|
173 |
+
if torch.cuda.is_available():
|
174 |
+
joined_inputs = joined_inputs.cuda()
|
175 |
+
with torch.no_grad():
|
176 |
+
outputs = model(input_ids=joined_inputs)
|
177 |
+
# The output is a tuple with the logits for each class (next sentence or not)
|
178 |
+
# We'll take the first one (next sentence)
|
179 |
+
logits = outputs[0]
|
180 |
+
# Apply softmax
|
181 |
+
logits = torch.softmax(logits, dim=1)
|
182 |
+
# We have two probabilities: the probability of 1 -> 2, and the probability of 2 -> 1
|
183 |
+
# We'll take the difference
|
184 |
+
swap = logits[0, 0] - logits[1, 0] < -0.2
|
185 |
+
if swap:
|
186 |
+
# Swap the messages
|
187 |
+
print(f"YES Swapping messages: {message1['message']} <-> {message2['message']}")
|
188 |
+
return message2, message1
|
189 |
+
else:
|
190 |
+
# print(f"NOT swapping messages: {message1['message']} <-> {message2['message']}")
|
191 |
+
return message1, message2
|
192 |
+
|
193 |
+
|
194 |
+
def swap_messages_if_needed_in_conversation(conversation):
|
195 |
+
# We'll use the first message as the first sentence, and the second message as the second sentence
|
196 |
+
if len(conversation) <= 2:
|
197 |
+
return conversation
|
198 |
+
new_conversation = [
|
199 |
+
conversation[0],
|
200 |
+
conversation[1],
|
201 |
+
] # We'll always keep the first message in the same position
|
202 |
+
for i in range(2, len(conversation)):
|
203 |
+
message1 = new_conversation[-1]
|
204 |
+
message2 = conversation[i]
|
205 |
+
message1, message2 = swap_messages_if_needed(message1, message2)
|
206 |
+
new_conversation[-1] = message1
|
207 |
+
new_conversation.append(message2)
|
208 |
+
|
209 |
+
# print(f"\nOriginal conversation:\n{printable_conversation(conversation)}")
|
210 |
+
# print(f"\nNew conversation:\n{printable_conversation(new_conversation)}")
|
211 |
+
return new_conversation
|
212 |
+
|
213 |
+
|
214 |
+
test_conversation = [
|
215 |
+
{"message": "Hola!", "contact_name": "A", "timestamp": 1},
|
216 |
+
{
|
217 |
+
"message": "Está todo bien, gracias por preguntar!",
|
218 |
+
"contact_name": "B",
|
219 |
+
"timestamp": 2,
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"message": "Hola, qué tal estás? Espero que vaya todo bien por España.",
|
223 |
+
"contact_name": "A",
|
224 |
+
"timestamp": 3,
|
225 |
+
},
|
226 |
+
]
|
227 |
+
# print(swap_messages_if_needed_in_conversation(test_conversation))
|
228 |
+
|
229 |
+
# %%
|
230 |
+
# Now, we'll train an mT5 model to generate the next message in a conversation
|
231 |
+
import os
|
232 |
+
|
233 |
+
|
234 |
+
# For the contact_name, rewrite everything that is not 'Aldi' to 'Other'
|
235 |
+
def rewrite_contact_name(conversation):
|
236 |
+
for message in conversation["conversations"]:
|
237 |
+
if message["contact_name"] != "Aldi":
|
238 |
+
message["contact_name"] = "Other"
|
239 |
+
return conversation
|
240 |
+
|
241 |
+
|
242 |
+
# %%
|
243 |
+
def process_chat_file(file, do_spelling_correction, do_reordering=False):
|
244 |
+
"""
|
245 |
+
Process a chat file and return a dataset with the conversations.
|
246 |
+
"""
|
247 |
+
ds = (
|
248 |
+
datasets.load_dataset("text", data_files=[file])["train"]
|
249 |
+
.filter(
|
250 |
+
# Has to begin by date, time, contact name, and contain at least a ':' symbol
|
251 |
+
lambda x: re.match(
|
252 |
+
r"^\d{1,2}/\d{1,2}/\d{1,2},\s\d{2}:\d{2}\s-\s.+:", x["text"]
|
253 |
+
)
|
254 |
+
)
|
255 |
+
.map(process_line, remove_columns=["text"])
|
256 |
+
)
|
257 |
+
|
258 |
+
# Filter out messages that just say '<Media omitted>'
|
259 |
+
ds = ds.filter(lambda x: x["message"] != "<Media omitted>")
|
260 |
+
|
261 |
+
groups = group_messages(iter(ds))
|
262 |
+
# Generate the dataset
|
263 |
+
conversations_ds = datasets.Dataset.from_dict({"conversations": groups})
|
264 |
+
|
265 |
+
# Filter out conversations with less than 10 messages
|
266 |
+
conversations_ds = conversations_ds.filter(lambda x: len(x["conversations"]) >= 10)
|
267 |
+
|
268 |
+
conversations_ds_without_whatsapp_annotations = conversations_ds.map(
|
269 |
+
remove_whatapp_annotations,
|
270 |
+
num_proc=os.cpu_count() - 1,
|
271 |
+
)
|
272 |
+
|
273 |
+
if do_spelling_correction:
|
274 |
+
spell_checked_conversations_ds = (
|
275 |
+
conversations_ds_without_whatsapp_annotations.map(spell_check_conversation)
|
276 |
+
)
|
277 |
+
else:
|
278 |
+
spell_checked_conversations_ds = conversations_ds_without_whatsapp_annotations
|
279 |
+
|
280 |
+
if do_reordering:
|
281 |
+
reordered_conversations_ds = spell_checked_conversations_ds.map(
|
282 |
+
swap_messages_if_needed_in_conversation
|
283 |
+
)
|
284 |
+
else:
|
285 |
+
reordered_conversations_ds = spell_checked_conversations_ds
|
286 |
+
|
287 |
+
changed_contact_name_ds = reordered_conversations_ds.map(
|
288 |
+
rewrite_contact_name
|
289 |
+
) # , num_proc=os.cpu_count() - 1)
|
290 |
+
|
291 |
+
# Filter out conversations with only one contact
|
292 |
+
changed_contact_name_ds = changed_contact_name_ds.filter(
|
293 |
+
lambda x: len(set([msg["contact_name"] for msg in x["conversations"]])) > 1
|
294 |
+
)
|
295 |
+
|
296 |
+
return changed_contact_name_ds
|
297 |
+
|
298 |
+
|
299 |
+
def transform_conversations_dataset_into_training_examples(
|
300 |
+
conversations_ds, system_prompt
|
301 |
+
):
|
302 |
+
"""
|
303 |
+
Takes in a dataset with conversations and returns a dataset with training examples.
|
304 |
+
|
305 |
+
The input dataset contains a single column (conversations), with each row being a list of messages with this format:
|
306 |
+
```
|
307 |
+
[{'contact_name': 'Aldi', 'message': <message>, 'timestamp': <time>}, {'contact_name': 'Other', 'message': <message>, 'timestamp': <time>}, ... ]
|
308 |
+
```
|
309 |
+
|
310 |
+
Each row will be converted to fit the format of the training examples.
|
311 |
+
|
312 |
+
The training examples have the following format:
|
313 |
+
```
|
314 |
+
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
|
315 |
+
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "William Shakespeare"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
|
316 |
+
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "384,400 kilometers"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}
|
317 |
+
```
|
318 |
+
"""
|
319 |
+
|
320 |
+
def process_one_example(example):
|
321 |
+
messages = [{"role": "system", "content": [system_prompt]}]
|
322 |
+
for msg in example["conversations"]:
|
323 |
+
converted_role = "assistant" if msg["contact_name"] == "Aldi" else "user"
|
324 |
+
if converted_role == messages[-1]["role"]:
|
325 |
+
messages[-1]["content"] += [msg["message"]]
|
326 |
+
else:
|
327 |
+
messages.append({"role": converted_role, "content": [msg["message"]]})
|
328 |
+
return {
|
329 |
+
"messages": [
|
330 |
+
{
|
331 |
+
"role": m["role"],
|
332 |
+
"content": json.dumps(m["content"], ensure_ascii=False),
|
333 |
+
}
|
334 |
+
for m in messages
|
335 |
+
]
|
336 |
+
}
|
337 |
+
|
338 |
+
return conversations_ds.map(
|
339 |
+
process_one_example,
|
340 |
+
remove_columns=["conversations"],
|
341 |
+
num_proc=os.cpu_count() - 1,
|
342 |
+
)
|
validation.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from collections import defaultdict
|
3 |
+
import tiktoken
|
4 |
+
|
5 |
+
|
6 |
+
def check_format_errors(train_dataset):
|
7 |
+
"""
|
8 |
+
Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
|
9 |
+
"""
|
10 |
+
# Format error checks
|
11 |
+
format_errors = defaultdict(int)
|
12 |
+
|
13 |
+
for ex in train_dataset:
|
14 |
+
if not isinstance(ex, dict):
|
15 |
+
format_errors["data_type"] += 1
|
16 |
+
continue
|
17 |
+
|
18 |
+
messages = ex.get("messages", None)
|
19 |
+
if not messages:
|
20 |
+
format_errors["missing_messages_list"] += 1
|
21 |
+
continue
|
22 |
+
|
23 |
+
for message in messages:
|
24 |
+
if "role" not in message or "content" not in message:
|
25 |
+
format_errors["message_missing_key"] += 1
|
26 |
+
|
27 |
+
if any(k not in ("role", "content", "name", "function_call", "weight") for k in message):
|
28 |
+
format_errors["message_unrecognized_key"] += 1
|
29 |
+
|
30 |
+
if message.get("role", None) not in ("system", "user", "assistant", "function"):
|
31 |
+
format_errors["unrecognized_role"] += 1
|
32 |
+
|
33 |
+
content = message.get("content", None)
|
34 |
+
function_call = message.get("function_call", None)
|
35 |
+
|
36 |
+
if (not content and not function_call) or not isinstance(content, str):
|
37 |
+
format_errors["missing_content"] += 1
|
38 |
+
|
39 |
+
if not any(message.get("role", None) == "assistant" for message in messages):
|
40 |
+
format_errors["example_missing_assistant_message"] += 1
|
41 |
+
|
42 |
+
if format_errors:
|
43 |
+
print("Found errors:")
|
44 |
+
for k, v in format_errors.items():
|
45 |
+
print(f"{k}: {v}")
|
46 |
+
else:
|
47 |
+
print("No errors found")
|
48 |
+
|
49 |
+
return format_errors if format_errors else {}
|
50 |
+
|
51 |
+
def get_distributions(train_dataset):
|
52 |
+
"""
|
53 |
+
Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
|
54 |
+
|
55 |
+
Gets the distributions of the number of messages per example, the total number of tokens per example, and the number of assistant tokens per example.
|
56 |
+
"""
|
57 |
+
encoding = tiktoken.get_encoding("cl100k_base")
|
58 |
+
|
59 |
+
# not exact!
|
60 |
+
# simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
61 |
+
def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
|
62 |
+
num_tokens = 0
|
63 |
+
for message in messages:
|
64 |
+
num_tokens += tokens_per_message
|
65 |
+
for key, value in message.items():
|
66 |
+
num_tokens += len(encoding.encode(value))
|
67 |
+
if key == "name":
|
68 |
+
num_tokens += tokens_per_name
|
69 |
+
num_tokens += 3
|
70 |
+
return num_tokens
|
71 |
+
|
72 |
+
def num_assistant_tokens_from_messages(messages):
|
73 |
+
num_tokens = 0
|
74 |
+
for message in messages:
|
75 |
+
if message["role"] == "assistant":
|
76 |
+
num_tokens += len(encoding.encode(message["content"]))
|
77 |
+
return num_tokens
|
78 |
+
|
79 |
+
|
80 |
+
n_missing_system = 0
|
81 |
+
n_missing_user = 0
|
82 |
+
n_messages = []
|
83 |
+
convo_lens = []
|
84 |
+
assistant_message_lens = []
|
85 |
+
|
86 |
+
for ex in train_dataset:
|
87 |
+
messages = ex["messages"]
|
88 |
+
if not any(message["role"] == "system" for message in messages):
|
89 |
+
n_missing_system += 1
|
90 |
+
if not any(message["role"] == "user" for message in messages):
|
91 |
+
n_missing_user += 1
|
92 |
+
n_messages.append(len(messages))
|
93 |
+
convo_lens.append(num_tokens_from_messages(messages))
|
94 |
+
assistant_message_lens.append(num_assistant_tokens_from_messages(messages))
|
95 |
+
|
96 |
+
return {
|
97 |
+
"n_missing_system": n_missing_system,
|
98 |
+
"n_missing_user": n_missing_user,
|
99 |
+
"n_messages": n_messages,
|
100 |
+
"convo_lens": convo_lens,
|
101 |
+
"assistant_message_lens": assistant_message_lens
|
102 |
+
}
|
103 |
+
|
104 |
+
|
105 |
+
def check_token_counts(train_dataset):
|
106 |
+
"""
|
107 |
+
Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
|
108 |
+
"""
|
109 |
+
def print_distribution(values, name):
|
110 |
+
print(f"\n#### Distribution of {name}:")
|
111 |
+
print(f"min / max: {min(values)}, {max(values)}")
|
112 |
+
print(f"mean / median: {np.mean(values)}, {np.median(values)}")
|
113 |
+
print(f"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
# Warnings and tokens counts
|
118 |
+
distributions = get_distributions(train_dataset)
|
119 |
+
n_missing_system = distributions["n_missing_system"]
|
120 |
+
n_missing_user = distributions["n_missing_user"]
|
121 |
+
n_messages = distributions["n_messages"]
|
122 |
+
convo_lens = distributions["convo_lens"]
|
123 |
+
assistant_message_lens = distributions["assistant_message_lens"]
|
124 |
+
|
125 |
+
print("Num examples missing system message:", n_missing_system)
|
126 |
+
print("Num examples missing user message:", n_missing_user)
|
127 |
+
print_distribution(n_messages, "num_messages_per_example")
|
128 |
+
print_distribution(convo_lens, "num_total_tokens_per_example")
|
129 |
+
print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
|
130 |
+
n_too_long = sum(l > 4096 for l in convo_lens)
|
131 |
+
print(
|
132 |
+
f"\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning"
|
133 |
+
)
|
134 |
+
|
135 |
+
return
|
136 |
+
|
137 |
+
|
138 |
+
def estimate_cost(train_dataset):
|
139 |
+
"""
|
140 |
+
Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
|
141 |
+
"""
|
142 |
+
distributions = get_distributions(train_dataset)
|
143 |
+
n_missing_system = distributions["n_missing_system"]
|
144 |
+
n_missing_user = distributions["n_missing_user"]
|
145 |
+
n_messages = distributions["n_messages"]
|
146 |
+
convo_lens = distributions["convo_lens"]
|
147 |
+
assistant_message_lens = distributions["assistant_message_lens"]
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
# Pricing and default n_epochs estimate
|
152 |
+
MAX_TOKENS_PER_EXAMPLE = 4096
|
153 |
+
|
154 |
+
TARGET_EPOCHS = 3
|
155 |
+
MIN_TARGET_EXAMPLES = 100
|
156 |
+
MAX_TARGET_EXAMPLES = 25000
|
157 |
+
MIN_DEFAULT_EPOCHS = 1
|
158 |
+
MAX_DEFAULT_EPOCHS = 25
|
159 |
+
|
160 |
+
n_epochs = TARGET_EPOCHS
|
161 |
+
n_train_examples = len(train_dataset)
|
162 |
+
if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
|
163 |
+
n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
|
164 |
+
elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
|
165 |
+
n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)
|
166 |
+
|
167 |
+
n_billing_tokens_in_dataset = sum(
|
168 |
+
min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens
|
169 |
+
)
|
170 |
+
|
171 |
+
return {
|
172 |
+
"Estimated number of tokens in dataset": n_billing_tokens_in_dataset,
|
173 |
+
f"Estimated number of tokens that will be billed (assuming {n_epochs} training epochs)": n_epochs * n_billing_tokens_in_dataset
|
174 |
+
}
|