ACMC committed
Commit a8dfddd
Parent(s): f1fc3d0

Bugfix
app.py
CHANGED
@@ -8,15 +8,26 @@ import datasets
 import gradio as gr
 import matplotlib.pyplot as plt
 
-from utils import (
-    […]
+from utils import (
+    process_chat_file,
+    transform_conversations_dataset_into_training_examples,
+)
 from validation import check_format_errors, estimate_cost, get_distributions
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
-def convert_to_dataset(files, do_spelling_correction, progress, whatsapp_name, d…
+def convert_to_dataset(
+    files,
+    do_spelling_correction,
+    progress,
+    whatsapp_name,
+    datetime_dayfirst,
+    message_line_format,
+    minutes_threshold,
+    min_messages_per_conversation,
+):
     modified_dataset = None
     for file in progress.tqdm(files, desc="Processing files"):
         try:
@@ -28,6 +39,8 @@ def convert_to_dataset(files, do_spelling_correction, progress, whatsapp_name, d…
                     whatsapp_name=whatsapp_name,
                     datetime_dayfirst=datetime_dayfirst,
                     message_line_format=message_line_format,
+                    minutes_threshold=minutes_threshold,
+                    min_messages_per_conversation=min_messages_per_conversation,
                 )
             else:
                 # Concatenate the datasets
@@ -37,6 +50,8 @@ def convert_to_dataset(files, do_spelling_correction, progress, whatsapp_name, d…
                     whatsapp_name=whatsapp_name,
                     datetime_dayfirst=datetime_dayfirst,
                     message_line_format=message_line_format,
+                    minutes_threshold=minutes_threshold,
+                    min_messages_per_conversation=min_messages_per_conversation,
                 )
                 modified_dataset = datasets.concatenate_datasets(
                     [modified_dataset, this_file_dataset]
@@ -57,6 +72,10 @@ def file_upload_callback(
     whatsapp_name,
     datetime_dayfirst,
     message_line_format,
+    minutes_threshold,
+    min_messages_per_conversation,
+    max_characters_per_message,
+    split_conversation_threshold,
     progress=gr.Progress(),
 ):
     logger.info(f"Processing {files}")
@@ -73,7 +92,7 @@ The {model_role} and the {user_role} can send multiple messages in a row, as a J…
     # Check if the user has not chosen any files
     if not files or len(files) == 0:
         raise gr.Error("Please upload at least one file.")
-
+
     # Check if the user has not entered their whatsapp name
     if not whatsapp_name or len(whatsapp_name) == 0:
         raise gr.Error("Please enter your WhatsApp name.")
@@ -87,26 +106,43 @@ The {model_role} and the {user_role} can send multiple messages in a row, as a J…
         whatsapp_name=whatsapp_name,
         datetime_dayfirst=datetime_dayfirst,
         message_line_format=message_line_format,
+        minutes_threshold=minutes_threshold,
+        min_messages_per_conversation=min_messages_per_conversation,
+    )
+    logger.info(
+        f"Number of conversations of dataset before being transformed: {len(dataset)}"
     )
-    logger.info(f"Number of conversations of dataset before being transformed: {len(dataset)}")
 
-    training_examples_ds = transform_conversations_dataset_into_training_examples(
+    full_examples_ds = transform_conversations_dataset_into_training_examples(
         conversations_ds=dataset,
         system_prompt=full_system_prompt,
         user_role=user_role,
        model_role=model_role,
         whatsapp_name=whatsapp_name,
+        minutes_threshold=minutes_threshold,
+        min_messages_per_conversation=min_messages_per_conversation,
+        split_conversation_threshold=split_conversation_threshold,
+        max_characters_per_message=max_characters_per_message,
+    )
+    total_number_of_generated_examples = len(full_examples_ds)
+    logger.info(
+        f"Total number of generated examples: {total_number_of_generated_examples}"
     )
-    logger.info(f"Number of training examples: {len(training_examples_ds)}")
 
     # Split into training and validation datasets (80% and 20%)
-    […old split implementation, truncated in the diff view…]
+    try:
+        split_examples_ds = full_examples_ds.train_test_split(
+            test_size=validation_split, seed=42
+        )
+        training_examples_ds, validation_examples_ds = (
+            split_examples_ds["train"],
+            split_examples_ds["test"],
+        )
+    except ValueError as e:
+        # This happens when there's not enough data to split into training and validation datasets
+        # In this case, we'll just use the whole dataset for training, nothing for validation
+        training_examples_ds = full_examples_ds
+        validation_examples_ds = datasets.Dataset.from_dict({})
     training_examples_ds = training_examples_ds  # .select(
     #     range(min(250, len(training_examples_ds)))
     # )
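
The try/except above matters because train_test_split refuses to produce an empty split. A minimal sketch of the fallback behavior, assuming a validation_split of 0.2 (the 80/20 ratio mentioned in the comment); the helper name and the toy one-example dataset are illustrative, not part of the commit:

import datasets

def split_with_fallback(full_examples_ds, validation_split=0.2):
    # Mirror the new app.py logic: try the split, fall back to
    # "everything is training data" when the dataset is too small.
    try:
        split = full_examples_ds.train_test_split(test_size=validation_split, seed=42)
        return split["train"], split["test"]
    except ValueError:
        # Raised when one side of the split would end up empty
        return full_examples_ds, datasets.Dataset.from_dict({})

tiny = datasets.Dataset.from_dict({"messages": [["hi"]]})  # a single example
train, val = split_with_fallback(tiny)
print(len(train), len(val))  # expected: 1 0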
@@ -125,9 +161,12 @@ The {model_role} and the {user_role} can send multiple messages in a row, as a J…
     )
 
     stats = {
-        "…
+        "Total number of training examples": total_number_of_generated_examples,
+        "Number of training examples": len(training_examples_ds),
+        "Number of validation examples": len(validation_examples_ds),
         "Number of examples missing system message": distributions["n_missing_system"],
         "Number of examples missing user message": distributions["n_missing_user"],
+        "Format Errors": format_errors,
         "Cost Statistics": cost_stats,
     }
 
@@ -156,9 +195,9 @@ The {model_role} and the {user_role} can send multiple messages in a row, as a J…
     # If there's less than 50 training examples, show a warning message
     if len(training_examples_ds) < 50:
         gr.Warning(
-            "…
+            "There are less than 50 training examples. The model may not perform well with such a small dataset. Consider adding more chat files to increase the number of training examples."
         )
-
+
     system_prompt_to_use = full_system_prompt
 
     return (
@@ -245,14 +284,38 @@ with gr.Blocks(theme=theme) as demo:
 
             model_role = gr.Textbox(
                 label="Role for Model",
-                info="This is a technical parameter. Usual values are 'model' or 'assistant'.",
+                info="This is a technical parameter. Usual values are 'model' (e.g. Vertex AI) or 'assistant' (e.g. OpenAI).",
                 value="model",
             )
 
+            minutes_threshold = gr.Number(
+                label="Minutes Threshold",
+                info="Threshold in minutes to consider that a new message is a new conversation. The default value should work for most cases.",
+                value=180,
+            )
+
+            min_messages_per_conversation = gr.Number(
+                label="Minimum Messages per Conversation",
+                info="Minimum number of messages per conversation to consider it as a valid conversation. The default value should work for most cases.",
+                value=5,
+            )
+
+            max_characters_per_message = gr.Number(
+                label="Max Characters per Message",
+                info="One token is around 3 characters. The default value should work for most cases. For example, on Vertex AI, the maximum number of tokens per example is [32,000](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare#sample-datasets), so keeping the default value will ensure that the examples are well within the limit.",
+                value=10000,
+            )
+
+            split_conversation_threshold = gr.Number(
+                label="Split Conversation Threshold",
+                info="Number of messages in a conversation to split it into multiple ones. The default value should work for most cases.",
+                value=40,
+            )
+
             message_line_format = gr.Textbox(
                 label="Message Line Format",
                 info="Format of each message line in the chat file, as a regular expression. The default value should work for most cases.",
-                value=r"\[?(?P<msg_datetime>\S+,\s\S+?(?:\s[APap][Mm])?)\]? (?:- )?(?P<contact_name…
+                value=r"\[?(?P<msg_datetime>\S+,\s\S+?(?:\s[APap][Mm])?)\]? (?:- )?(?P<contact_name>.+?): (?P<message>.+)",
             )
 
             datetime_dayfirst = gr.Checkbox(
@@ -287,7 +350,13 @@ with gr.Blocks(theme=theme) as demo:
                 variant="secondary",
             )
 
-            system_prompt_to_use = gr.Textbox(…
+            system_prompt_to_use = gr.Textbox(
+                label="System Prompt that you can use",
+                visible=False,
+                interactive=False,
+                show_copy_button=True,
+                info="When using the model, if you're asked for a system prompt, you can use this text.",
+            )
             # output_example = gr.JSON(label="Example Training Example")
 
             with gr.Group():
@@ -316,6 +385,10 @@ with gr.Blocks(theme=theme) as demo:
             whatsapp_name,
             datetime_dayfirst,
             message_line_format,
+            minutes_threshold,
+            min_messages_per_conversation,
+            max_characters_per_message,
+            split_conversation_threshold,
         ],
         outputs=[
             output_file,
@@ -327,7 +400,7 @@ with gr.Blocks(theme=theme) as demo:
             num_total_tokens_per_example_plot,
             num_assistant_tokens_per_example_plot,
             system_prompt_to_use,
-            system_prompt_to_use
+            system_prompt_to_use,
         ],
     )
 
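For reference, the new default value of the Message Line Format textbox can be exercised on its own. A quick sketch; the two sample lines are made up, covering the bracketed and the dash-separated WhatsApp export styles:

import re

# Default regex introduced in this commit (split across two strings for readability)
MESSAGE_LINE_FORMAT = (
    r"\[?(?P<msg_datetime>\S+,\s\S+?(?:\s[APap][Mm])?)\]? "
    r"(?:- )?(?P<contact_name>.+?): (?P<message>.+)"
)
exp = re.compile(MESSAGE_LINE_FORMAT)

for line in [
    "12/31/23, 11:59 PM - Alice: Happy new year!",
    "[31/12/23, 23:59] Bob: Same to you",
]:
    match = exp.match(line)
    print(match.groupdict() if match else f"no match: {line}")
# {'msg_datetime': '12/31/23, 11:59 PM', 'contact_name': 'Alice', 'message': 'Happy new year!'}
# {'msg_datetime': '31/12/23, 23:59', 'contact_name': 'Bob', 'message': 'Same to you'}
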
utils.py
CHANGED
@@ -9,22 +9,25 @@ import dateutil.parser
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-# %%
-# Now, create message groups ('conversations')
-# The idea is to group messages that are close in time
-# We'll use a 180 minute threshold
-MINUTES_THRESHOLD = 180
-MIN_MESSAGES_THRESHOLD = 5
-
 
-def group_messages(messages_iterable):
+# %%
+def group_messages(messages_iterable, minutes_threshold):
+    """
+    Groups messages in a conversation. If the difference between two consecutive messages is less than `minutes_threshold` minutes, they are grouped together.
+    """
     groups = []
-    current_group = [next(messages_iterable)]
+    current_group = []
+    try:
+        first_message = next(messages_iterable)
+        current_group.append(first_message)
+    except StopIteration:
+        logger.exception("No messages in the conversation")
+        return []
     for message in messages_iterable:
         assert len(current_group) > 0  # We should never have an empty group
         if (
             message["timestamp"] - current_group[-1]["timestamp"]
-            < MINUTES_THRESHOLD * 60
+            < minutes_threshold * 60
         ):
             current_group.append(message)
         else:
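
A small demonstration of the reworked grouping rule, on hand-made timestamps (seconds since the epoch) with the 180-minute default from the UI. The replica below re-implements the function for illustration; the final flush of current_group into groups is an assumption, since the hunk ends before the function's tail:

def group_messages_sketch(messages_iterable, minutes_threshold):
    # Same rule as utils.group_messages: consecutive messages closer than
    # minutes_threshold minutes belong to the same conversation.
    groups = []
    current_group = []
    try:
        current_group.append(next(messages_iterable))
    except StopIteration:
        return []
    for message in messages_iterable:
        if message["timestamp"] - current_group[-1]["timestamp"] < minutes_threshold * 60:
            current_group.append(message)
        else:
            groups.append(current_group)
            current_group = [message]
    groups.append(current_group)  # assumed final flush
    return groups

msgs = [
    {"contact_name": "Alice", "timestamp": 0.0},
    {"contact_name": "Bob", "timestamp": 600.0},  # 10 minutes later: same group
    {"contact_name": "Alice", "timestamp": 600.0 + 181 * 60},  # 181 minutes later: new group
]
print([len(g) for g in group_messages_sketch(iter(msgs), minutes_threshold=180)])  # [2, 1]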
@@ -214,7 +217,16 @@ import os
 
 
 # %%
-def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayf…
+def process_chat_file(
+    file,
+    do_spelling_correction,
+    whatsapp_name,
+    datetime_dayfirst,
+    message_line_format,
+    minutes_threshold,
+    min_messages_per_conversation,
+    do_reordering=False,
+):
     """
     Process a chat file and return a dataset with the conversations.
     """
@@ -224,50 +236,83 @@ def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayf…
         message_line_format
     )
 
-    def process_line(…
+    def process_line(examples):
         # The lines have this format: dd/mm/yy, hh:mm - <person>: <msg>
-        […old per-line implementation, truncated in the diff view…]
+        messages = []
+        contact_names = []
+        timestamps = []
+        for line_text in examples["text"]:
+            try:
+                groups = exp.match(line_text).groupdict()
+                # First, get the elements. If something fails here, it will raise an exception before actually adding the element to the list, so we'll be sure that the three lists contain the same # of elements.
+                timestamp = dateutil.parser.parse(
+                    groups["msg_datetime"], dayfirst=datetime_dayfirst
+                ).timestamp()
+                message = groups["message"]
+                contact_name = groups["contact_name"]
+                messages.append(message)
+                contact_names.append(contact_name)
+                timestamps.append(timestamp)
+            except Exception as e:
+                logger.exception(f"Error while processing line {line_text}")
+        return {
+            "message": messages,
+            "contact_name": contact_names,
+            "timestamp": timestamps,
+        }
 
     try:
         ds = datasets.load_dataset("text", data_files=[file])["train"]
     except Exception as e:
         logger.exception(f"Error while loading file {file}")
         raise Exception(f"Error while loading file {file}") from e
+
+    # try:
+    #     ds = ds.filter(
+    #         # Has to begin by date, time, contact name, and contain at least a ':' symbol
+    #         lambda x: re.match(
+    #             r"^\d{1,2}/\d{1,2}/\d{1,4},\s\d{2}:\d{2}\s-\s.+:", x["text"]
+    #         )
+    #     )
+    # except Exception as e:
+    #     logger.exception(f"Error filtering the lines in file {file} so they match the expected format")
+    #     raise Exception(f"Error filtering the lines in file {file} so they match the expected format") from e
+
     try:
-        ds = ds.filter(
-            # Has to begin by date, time, contact name, and contain at least a ':' symbol
-            lambda x: re.match(
-                r"^\d{1,2}/\d{1,2}/\d{1,2},\s\d{2}:\d{2}\s-\s.+:", x["text"]
-            )
-        )
-    except Exception as e:
-        logger.exception(f"Error filtering the lines in file {file} so they match the expected format")
-        raise Exception(f"Error filtering the lines in file {file} so they match the expected format") from e
-    try:
-        ds = ds.map(process_line, remove_columns=["text"])
+        ds = ds.map(process_line, remove_columns=["text"], batched=True, batch_size=10)
     except Exception as e:
-        logger.exception(
-            …
+        logger.exception(
+            f"Error mapping the lines in file {file} to the expected format"
+        )
+        raise Exception(
+            f"Error mapping the lines in file {file} to the expected format"
+        ) from e
+
+    # Check that the WhatsApp name is in at least one of the messages. If it's not, raise an exception
+    set_of_contact_names = ds.unique("contact_name")
+    if whatsapp_name not in set_of_contact_names:
+        raise Exception(
+            f"Your WhatsApp name ({whatsapp_name}) is not in the messages of at least one uploaded file. Please check that you wrote your name correctly. These were the participants found: {set_of_contact_names}"
+        )
+    # # Also check that the number of contact names is == 2 (i.e. we don't have group chats)
+    # if len(set_of_contact_names) > 2:
+    #     raise Exception(
+    #         f"There were more than 2 participants in at least one uploaded file. Please check that you're not using group chats. These were the participants found: {set_of_contact_names}"
+    #     )
 
     try:
         # Filter out messages that just say '<Media omitted>'
         ds = ds.filter(lambda x: x["message"] != "<Media omitted>")
     except Exception as e:
-        logger.exception(
-            …
+        logger.exception(
+            f"Error filtering out messages that say '<Media omitted>' in file {file}"
+        )
+        raise Exception(
+            f"Error filtering out messages that say '<Media omitted>' in file {file}"
+        ) from e
 
     try:
-        groups = group_messages(iter(ds))
+        groups = group_messages(iter(ds), minutes_threshold=minutes_threshold)
         # Generate the dataset
         conversations_ds = datasets.Dataset.from_dict({"conversations": groups})
     except Exception as e:
@@ -277,11 +322,15 @@ def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayf…
     try:
         # Filter out conversations with less than 5 messages
         conversations_ds = conversations_ds.filter(
-            lambda x: len(x["conversations"]) >= MIN_MESSAGES_THRESHOLD
+            lambda x: len(x["conversations"]) >= min_messages_per_conversation
         )
     except Exception as e:
-        logger.exception(
-            …
+        logger.exception(
+            f"Error filtering out conversations with less than {min_messages_per_conversation} messages in file {file}"
+        )
+        raise Exception(
+            f"Error filtering out conversations with less than {min_messages_per_conversation} messages in file {file}"
+        ) from e
 
     try:
         conversations_ds_without_whatsapp_annotations = conversations_ds.map(
@@ -295,11 +344,15 @@ def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayf…
     if do_spelling_correction:
         try:
             spell_checked_conversations_ds = (
-                conversations_ds_without_whatsapp_annotations.map(spell_check_conversation)
+                conversations_ds_without_whatsapp_annotations.map(
+                    spell_check_conversation
+                )
             )
         except Exception as e:
             logger.exception(f"Error spell checking the conversations in file {file}")
-            raise Exception(f"Error spell checking the conversations in file {file}") from e
+            raise Exception(
+                f"Error spell checking the conversations in file {file}"
+            ) from e
     else:
         spell_checked_conversations_ds = conversations_ds_without_whatsapp_annotations
 
@@ -327,7 +380,9 @@ def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayf…
         )  # , num_proc=os.cpu_count() - 1)
     except Exception as e:
         logger.exception(f"Error changing your other contact's names in file {file}")
-        raise Exception(f"Error changing your other contact's names in file {file}") from e
+        raise Exception(
+            f"Error changing your other contact's names in file {file}"
+        ) from e
 
     try:
         # Filter out conversations with only one contact
@@ -335,18 +390,26 @@ def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayf…
             lambda x: len(set([msg["contact_name"] for msg in x["conversations"]])) > 1
         )
     except Exception as e:
-        logger.exception(
-            …
+        logger.exception(
+            f"Error filtering out conversations with only one contact in file {file}"
+        )
+        raise Exception(
+            f"Error filtering out conversations with only one contact in file {file}"
+        ) from e
 
     return changed_contact_name_ds
 
 
-SPLIT_CONVERSATION_THRESHOLD = 40
-MAX_CHARACTERS_PER_MESSAGE = 10000  # Max is 8,192 tokens (https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-about#sample-datasets)
-
-
 def transform_conversations_dataset_into_training_examples(
-    conversations_ds, system_prompt, user_role, model_role, whatsapp_name
+    conversations_ds,
+    system_prompt,
+    user_role,
+    model_role,
+    whatsapp_name,
+    minutes_threshold,
+    min_messages_per_conversation,
+    split_conversation_threshold,
+    max_characters_per_message,
 ):
     """
     Takes in a dataset with conversations and returns a dataset with training examples.
@@ -376,7 +439,7 @@ def transform_conversations_dataset_into_training_examples(
                     model_role if msg["contact_name"] == whatsapp_name else user_role
                 )
                 if (
-                    counter > SPLIT_CONVERSATION_THRESHOLD
+                    counter > split_conversation_threshold
                     and converted_role == user_role
                 ):
                     processed_examples.append(
@@ -401,7 +464,7 @@ def transform_conversations_dataset_into_training_examples(
                     {"role": converted_role, "content": [msg["message"]]}
                 )
                 counter += 1
-            if len(messages) >= MIN_MESSAGES_THRESHOLD:
+            if len(messages) >= min_messages_per_conversation:
                 processed_examples.append(
                     {
                         "messages": [
@@ -415,8 +478,13 @@ def transform_conversations_dataset_into_training_examples(
                 )
             else:
                 logger.warning(
-                    f"Discarding conversation because the length is not at least {MIN_MESSAGES_THRESHOLD}: {messages}"
+                    f"Discarding conversation because the length is not at least {min_messages_per_conversation}: {messages}"
                 )
+        if len(processed_examples) == 0:
+            logger.warning(
+                f"Discarding all conversations because none of them have at least {min_messages_per_conversation} messages"
+            )
+            return {}
         # Before returning, flatten the list of dictionaries into a dictionary of lists
         flattened_examples = {}
         for key in processed_examples[0].keys():
@@ -431,17 +499,25 @@ def transform_conversations_dataset_into_training_examples(
             batched=True,
         )
     except Exception as e:
-        logger.exception(
-            …
+        logger.exception(
+            "Error transforming the conversations dataset into training examples"
+        )
+        raise Exception(
+            "Error transforming the conversations dataset into training examples"
+        ) from e
 
     try:
         examples_filtered_by_length = processed_examples.filter(
             lambda x: all(
-                [len(m["content"]) < MAX_CHARACTERS_PER_MESSAGE for m in x["messages"]]
+                [len(m["content"]) < max_characters_per_message for m in x["messages"]]
            )
         )
     except Exception as e:
-        logger.exception(
-            …
+        logger.exception(
+            "Error filtering out examples with messages longer than the maximum allowed"
+        )
+        raise Exception(
+            "Error filtering out examples with messages longer than the maximum allowed"
+        ) from e
 
     return examples_filtered_by_length
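
The move to batched=True is what lets process_line silently drop unparseable lines (the old up-front filter is now commented out): a batched map may return fewer rows than it receives, as long as the input columns are removed. A minimal sketch of that contract with a deliberately simplified line format; the sample data is made up:

import re
import datasets

exp = re.compile(r"(?P<contact_name>.+?): (?P<message>.+)")  # simplified format

def process_line(examples):
    # Batched map: receives a dict of lists, returns a dict of lists.
    # Lines that do not parse are skipped, so the output batch can be smaller.
    messages, contact_names = [], []
    for line_text in examples["text"]:
        match = exp.match(line_text)
        if match is None:
            continue  # the real code logs and skips via try/except
        messages.append(match["message"])
        contact_names.append(match["contact_name"])
    return {"message": messages, "contact_name": contact_names}

ds = datasets.Dataset.from_dict({"text": ["Alice: hi", "<malformed>", "Bob: hello"]})
ds = ds.map(process_line, remove_columns=["text"], batched=True, batch_size=10)
print(len(ds), ds["contact_name"])  # 2 ['Alice', 'Bob']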