ACMC committed
Commit a8dfddd
1 Parent(s): f1fc3d0
Files changed (2)
  1. app.py +94 -21
  2. utils.py +136 -60
app.py CHANGED
@@ -8,15 +8,26 @@ import datasets
 import gradio as gr
 import matplotlib.pyplot as plt
 
-from utils import (process_chat_file,
-                   transform_conversations_dataset_into_training_examples)
+from utils import (
+    process_chat_file,
+    transform_conversations_dataset_into_training_examples,
+)
 from validation import check_format_errors, estimate_cost, get_distributions
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
-def convert_to_dataset(files, do_spelling_correction, progress, whatsapp_name, datetime_dayfirst, message_line_format):
+def convert_to_dataset(
+    files,
+    do_spelling_correction,
+    progress,
+    whatsapp_name,
+    datetime_dayfirst,
+    message_line_format,
+    minutes_threshold,
+    min_messages_per_conversation,
+):
     modified_dataset = None
     for file in progress.tqdm(files, desc="Processing files"):
         try:
@@ -28,6 +39,8 @@ def convert_to_dataset(files, do_spelling_correction, progress, whatsapp_name, d
                     whatsapp_name=whatsapp_name,
                     datetime_dayfirst=datetime_dayfirst,
                     message_line_format=message_line_format,
+                    minutes_threshold=minutes_threshold,
+                    min_messages_per_conversation=min_messages_per_conversation,
                 )
             else:
                 # Concatenate the datasets
@@ -37,6 +50,8 @@ def convert_to_dataset(files, do_spelling_correction, progress, whatsapp_name, d
                     whatsapp_name=whatsapp_name,
                     datetime_dayfirst=datetime_dayfirst,
                     message_line_format=message_line_format,
+                    minutes_threshold=minutes_threshold,
+                    min_messages_per_conversation=min_messages_per_conversation,
                 )
                 modified_dataset = datasets.concatenate_datasets(
                     [modified_dataset, this_file_dataset]
@@ -57,6 +72,10 @@ def file_upload_callback(
     whatsapp_name,
     datetime_dayfirst,
     message_line_format,
+    minutes_threshold,
+    min_messages_per_conversation,
+    max_characters_per_message,
+    split_conversation_threshold,
     progress=gr.Progress(),
 ):
     logger.info(f"Processing {files}")
@@ -73,7 +92,7 @@ The {model_role} and the {user_role} can send multiple messages in a row, as a J
     # Check if the user has not chosen any files
     if not files or len(files) == 0:
         raise gr.Error("Please upload at least one file.")
-
+
     # Check if the user has not entered their whatsapp name
     if not whatsapp_name or len(whatsapp_name) == 0:
         raise gr.Error("Please enter your WhatsApp name.")
@@ -87,26 +106,43 @@ The {model_role} and the {user_role} can send multiple messages in a row, as a J
         whatsapp_name=whatsapp_name,
         datetime_dayfirst=datetime_dayfirst,
         message_line_format=message_line_format,
+        minutes_threshold=minutes_threshold,
+        min_messages_per_conversation=min_messages_per_conversation,
+    )
+    logger.info(
+        f"Number of conversations of dataset before being transformed: {len(dataset)}"
     )
-    logger.info(f"Number of conversations of dataset before being transformed: {len(dataset)}")
 
-    training_examples_ds = transform_conversations_dataset_into_training_examples(
+    full_examples_ds = transform_conversations_dataset_into_training_examples(
         conversations_ds=dataset,
         system_prompt=full_system_prompt,
         user_role=user_role,
        model_role=model_role,
        whatsapp_name=whatsapp_name,
+        minutes_threshold=minutes_threshold,
+        min_messages_per_conversation=min_messages_per_conversation,
+        split_conversation_threshold=split_conversation_threshold,
+        max_characters_per_message=max_characters_per_message,
+    )
+    total_number_of_generated_examples = len(full_examples_ds)
+    logger.info(
+        f"Total number of generated examples: {total_number_of_generated_examples}"
    )
-    logger.info(f"Number of training examples: {len(training_examples_ds)}")
 
     # Split into training and validation datasets (80% and 20%)
-    training_examples_ds = training_examples_ds.train_test_split(
-        test_size=validation_split, seed=42
-    )
-    training_examples_ds, validation_examples_ds = (
-        training_examples_ds["train"],
-        training_examples_ds["test"],
-    )
+    try:
+        split_examples_ds = full_examples_ds.train_test_split(
+            test_size=validation_split, seed=42
+        )
+        training_examples_ds, validation_examples_ds = (
+            split_examples_ds["train"],
+            split_examples_ds["test"],
+        )
+    except ValueError as e:
+        # This happens when there's not enough data to split into training and validation datasets
+        # In this case, we'll just use the whole dataset for training, nothing for validation
+        training_examples_ds = full_examples_ds
+        validation_examples_ds = datasets.Dataset.from_dict({})
     training_examples_ds = training_examples_ds  # .select(
     #     range(min(250, len(training_examples_ds)))
     # )
@@ -125,9 +161,12 @@ The {model_role} and the {user_role} can send multiple messages in a row, as a J
     )
 
     stats = {
-        "Format Errors": format_errors,
+        "Total number of training examples": total_number_of_generated_examples,
+        "Number of training examples": len(training_examples_ds),
+        "Number of validation examples": len(validation_examples_ds),
         "Number of examples missing system message": distributions["n_missing_system"],
         "Number of examples missing user message": distributions["n_missing_user"],
+        "Format Errors": format_errors,
         "Cost Statistics": cost_stats,
     }
 
@@ -156,9 +195,9 @@ The {model_role} and the {user_role} can send multiple messages in a row, as a J
     # If there's less than 50 training examples, show a warning message
     if len(training_examples_ds) < 50:
         gr.Warning(
-            "Warning: There are less than 50 training examples. The model may not perform well with such a small dataset. Consider adding more chat files to increase the number of training examples."
+            "There are less than 50 training examples. The model may not perform well with such a small dataset. Consider adding more chat files to increase the number of training examples."
         )
-
+
     system_prompt_to_use = full_system_prompt
 
     return (
@@ -245,14 +284,38 @@ with gr.Blocks(theme=theme) as demo:
 
     model_role = gr.Textbox(
         label="Role for Model",
-        info="This is a technical parameter. Usual values are 'model' or 'assistant'.",
+        info="This is a technical parameter. Usual values are 'model' (e.g. Vertex AI) or 'assistant' (e.g. OpenAI).",
         value="model",
     )
 
+    minutes_threshold = gr.Number(
+        label="Minutes Threshold",
+        info="Threshold in minutes to consider that a new message is a new conversation. The default value should work for most cases.",
+        value=180,
+    )
+
+    min_messages_per_conversation = gr.Number(
+        label="Minimum Messages per Conversation",
+        info="Minimum number of messages per conversation to consider it as a valid conversation. The default value should work for most cases.",
+        value=5,
+    )
+
+    max_characters_per_message = gr.Number(
+        label="Max Characters per Message",
+        info="One token is around 3 characters. The default value should work for most cases. For example, on Vertex AI, the maximum number of tokens per example is [32,000](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare#sample-datasets), so keeping the default value will ensure that the examples are well within the limit.",
+        value=10000,
+    )
+
+    split_conversation_threshold = gr.Number(
+        label="Split Conversation Threshold",
+        info="Number of messages in a conversation to split it into multiple ones. The default value should work for most cases.",
+        value=40,
+    )
+
     message_line_format = gr.Textbox(
         label="Message Line Format",
         info="Format of each message line in the chat file, as a regular expression. The default value should work for most cases.",
-        value=r"\[?(?P<msg_datetime>\S+,\s\S+?(?:\s[APap][Mm])?)\]? (?:- )?(?P<contact_name>.+): (?P<message>.+)",
+        value=r"\[?(?P<msg_datetime>\S+,\s\S+?(?:\s[APap][Mm])?)\]? (?:- )?(?P<contact_name>.+?): (?P<message>.+)",
     )
 
     datetime_dayfirst = gr.Checkbox(
@@ -287,7 +350,13 @@ with gr.Blocks(theme=theme) as demo:
         variant="secondary",
     )
 
-    system_prompt_to_use = gr.Textbox(label="System Prompt that you can use", visible=False, interactive=False, show_copy_button=True, info="When using the model, if you're asked for a system prompt, you can use this text.")
+    system_prompt_to_use = gr.Textbox(
+        label="System Prompt that you can use",
+        visible=False,
+        interactive=False,
+        show_copy_button=True,
+        info="When using the model, if you're asked for a system prompt, you can use this text.",
+    )
     # output_example = gr.JSON(label="Example Training Example")
 
     with gr.Group():
@@ -316,6 +385,10 @@ with gr.Blocks(theme=theme) as demo:
             whatsapp_name,
             datetime_dayfirst,
            message_line_format,
+            minutes_threshold,
+            min_messages_per_conversation,
+            max_characters_per_message,
+            split_conversation_threshold,
         ],
         outputs=[
             output_file,
@@ -327,7 +400,7 @@ with gr.Blocks(theme=theme) as demo:
             num_total_tokens_per_example_plot,
             num_assistant_tokens_per_example_plot,
             system_prompt_to_use,
-            system_prompt_to_use
+            system_prompt_to_use,
         ],
     )
 
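
The try/except added around `train_test_split` above handles tiny datasets: with very few examples, an 80/20 split can leave the train or test side empty, and `datasets` raises a ValueError. A minimal standalone sketch of the pattern (not part of the commit; the one-row toy dataset is invented for illustration):

import datasets

# One example cannot be split 80/20: the train side would be empty.
full_examples_ds = datasets.Dataset.from_dict({"messages": [["hello"]]})

try:
    split = full_examples_ds.train_test_split(test_size=0.2, seed=42)
    training_examples_ds, validation_examples_ds = split["train"], split["test"]
except ValueError:
    # Fall back to training on everything, with an empty validation set,
    # mirroring the fallback introduced in this commit.
    training_examples_ds = full_examples_ds
    validation_examples_ds = datasets.Dataset.from_dict({})

print(len(training_examples_ds), len(validation_examples_ds))  # 1 0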
utils.py CHANGED
@@ -9,22 +9,25 @@ import dateutil.parser
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-# %%
-# Now, create message groups ('conversations')
-# The idea is to group messages that are close in time
-# We'll use a 180 minute threshold
-MINUTES_THRESHOLD = 180
-MIN_MESSAGES_THRESHOLD = 5
-
 
-def group_messages(messages_iterable):
+# %%
+def group_messages(messages_iterable, minutes_threshold):
+    """
+    Groups messages in a conversation. If the difference between two consecutive messages is less than `minutes_threshold` minutes, they are grouped together.
+    """
     groups = []
-    current_group = [next(messages_iterable)]
+    current_group = []
+    try:
+        first_message = next(messages_iterable)
+        current_group.append(first_message)
+    except StopIteration:
+        logger.exception("No messages in the conversation")
+        return []
     for message in messages_iterable:
         assert len(current_group) > 0  # We should never have an empty group
         if (
             message["timestamp"] - current_group[-1]["timestamp"]
-            < MINUTES_THRESHOLD * 60
+            < minutes_threshold * 60
         ):
             current_group.append(message)
         else:
@@ -214,7 +217,16 @@ import os
 
 
 # %%
-def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayfirst, message_line_format, do_reordering=False):
+def process_chat_file(
+    file,
+    do_spelling_correction,
+    whatsapp_name,
+    datetime_dayfirst,
+    message_line_format,
+    minutes_threshold,
+    min_messages_per_conversation,
+    do_reordering=False,
+):
     """
     Process a chat file and return a dataset with the conversations.
     """
@@ -224,50 +236,83 @@ def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayf
         message_line_format
     )
 
-    def process_line(example):
+    def process_line(examples):
         # The lines have this format: dd/mm/yy, hh:mm - <person>: <msg>
-        try:
-            groups = exp.match(example["text"]).groupdict()
-            timestamp = dateutil.parser.parse(groups['msg_datetime'], dayfirst=datetime_dayfirst).timestamp()
-            return {
-                "message": groups["message"],
-                "contact_name": groups["contact_name"],
-                "timestamp": timestamp,
-            }
-        except Exception as e:
-            logger.exception(example["text"])
-            raise e
+        messages = []
+        contact_names = []
+        timestamps = []
+        for line_text in examples["text"]:
+            try:
+                groups = exp.match(line_text).groupdict()
+                # First, get the elements. If something fails here, it will raise an exception before actually adding the element to the list, so we'll be sure that the three lists contain the same # of elements.
+                timestamp = dateutil.parser.parse(
+                    groups["msg_datetime"], dayfirst=datetime_dayfirst
+                ).timestamp()
+                message = groups["message"]
+                contact_name = groups["contact_name"]
+                messages.append(message)
+                contact_names.append(contact_name)
+                timestamps.append(timestamp)
+            except Exception as e:
+                logger.exception(f"Error while processing line {line_text}")
+        return {
+            "message": messages,
+            "contact_name": contact_names,
+            "timestamp": timestamps,
+        }
 
     try:
         ds = datasets.load_dataset("text", data_files=[file])["train"]
     except Exception as e:
         logger.exception(f"Error while loading file {file}")
         raise Exception(f"Error while loading file {file}") from e
+
+    # try:
+    #     ds = ds.filter(
+    #         # Has to begin by date, time, contact name, and contain at least a ':' symbol
+    #         lambda x: re.match(
+    #             r"^\d{1,2}/\d{1,2}/\d{1,4},\s\d{2}:\d{2}\s-\s.+:", x["text"]
+    #         )
+    #     )
+    # except Exception as e:
+    #     logger.exception(f"Error filtering the lines in file {file} so they match the expected format")
+    #     raise Exception(f"Error filtering the lines in file {file} so they match the expected format") from e
+
     try:
-        ds = ds.filter(
-            # Has to begin by date, time, contact name, and contain at least a ':' symbol
-            lambda x: re.match(
-                r"^\d{1,2}/\d{1,2}/\d{1,2},\s\d{2}:\d{2}\s-\s.+:", x["text"]
-            )
-        )
-    except Exception as e:
-        logger.exception(f"Error filtering the lines in file {file} so they match the expected format")
-        raise Exception(f"Error filtering the lines in file {file} so they match the expected format") from e
-    try:
-        ds = ds.map(process_line, remove_columns=["text"])
+        ds = ds.map(process_line, remove_columns=["text"], batched=True, batch_size=10)
     except Exception as e:
-        logger.exception(f"Error mapping the lines in file {file} to the expected format")
-        raise Exception(f"Error mapping the lines in file {file} to the expected format") from e
+        logger.exception(
+            f"Error mapping the lines in file {file} to the expected format"
+        )
+        raise Exception(
+            f"Error mapping the lines in file {file} to the expected format"
+        ) from e
+
+    # Check that the WhatsApp name is in at least one of the messages. If it's not, raise an exception
+    set_of_contact_names = ds.unique("contact_name")
+    if whatsapp_name not in set_of_contact_names:
+        raise Exception(
+            f"Your WhatsApp name ({whatsapp_name}) is not in the messages of at least one uploaded file. Please check that you wrote your name correctly. These were the participants found: {set_of_contact_names}"
+        )
+    # # Also check that the number of contact names is == 2 (i.e. we don't have group chats)
+    # if len(set_of_contact_names) > 2:
+    #     raise Exception(
+    #         f"There were more than 2 participants in at least one uploaded file. Please check that you're not using group chats. These were the participants found: {set_of_contact_names}"
+    #     )
 
     try:
         # Filter out messages that just say '<Media omitted>'
         ds = ds.filter(lambda x: x["message"] != "<Media omitted>")
     except Exception as e:
-        logger.exception(f"Error filtering out messages that say '<Media omitted>' in file {file}")
-        raise Exception(f"Error filtering out messages that say '<Media omitted>' in file {file}") from e
+        logger.exception(
+            f"Error filtering out messages that say '<Media omitted>' in file {file}"
+        )
+        raise Exception(
+            f"Error filtering out messages that say '<Media omitted>' in file {file}"
+        ) from e
 
     try:
-        groups = group_messages(iter(ds))
+        groups = group_messages(iter(ds), minutes_threshold=minutes_threshold)
         # Generate the dataset
         conversations_ds = datasets.Dataset.from_dict({"conversations": groups})
     except Exception as e:
@@ -277,11 +322,15 @@ def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayf
     try:
         # Filter out conversations with less than 5 messages
         conversations_ds = conversations_ds.filter(
-            lambda x: len(x["conversations"]) >= MIN_MESSAGES_THRESHOLD
+            lambda x: len(x["conversations"]) >= min_messages_per_conversation
         )
     except Exception as e:
-        logger.exception(f"Error filtering out conversations with less than {MIN_MESSAGES_THRESHOLD} messages in file {file}")
-        raise Exception(f"Error filtering out conversations with less than {MIN_MESSAGES_THRESHOLD} messages in file {file}") from e
+        logger.exception(
+            f"Error filtering out conversations with less than {min_messages_per_conversation} messages in file {file}"
+        )
+        raise Exception(
+            f"Error filtering out conversations with less than {min_messages_per_conversation} messages in file {file}"
+        ) from e
 
     try:
         conversations_ds_without_whatsapp_annotations = conversations_ds.map(
@@ -295,11 +344,15 @@ def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayf
     if do_spelling_correction:
         try:
             spell_checked_conversations_ds = (
-                conversations_ds_without_whatsapp_annotations.map(spell_check_conversation)
+                conversations_ds_without_whatsapp_annotations.map(
+                    spell_check_conversation
+                )
             )
         except Exception as e:
             logger.exception(f"Error spell checking the conversations in file {file}")
-            raise Exception(f"Error spell checking the conversations in file {file}") from e
+            raise Exception(
+                f"Error spell checking the conversations in file {file}"
+            ) from e
     else:
         spell_checked_conversations_ds = conversations_ds_without_whatsapp_annotations
 
@@ -327,7 +380,9 @@ def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayf
         )  # , num_proc=os.cpu_count() - 1)
     except Exception as e:
         logger.exception(f"Error changing your other contact's names in file {file}")
-        raise Exception(f"Error changing your other contact's names in file {file}") from e
+        raise Exception(
+            f"Error changing your other contact's names in file {file}"
+        ) from e
 
     try:
         # Filter out conversations with only one contact
@@ -335,18 +390,26 @@ def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayf
             lambda x: len(set([msg["contact_name"] for msg in x["conversations"]])) > 1
         )
     except Exception as e:
-        logger.exception(f"Error filtering out conversations with only one contact in file {file}")
-        raise Exception(f"Error filtering out conversations with only one contact in file {file}") from e
+        logger.exception(
+            f"Error filtering out conversations with only one contact in file {file}"
+        )
+        raise Exception(
+            f"Error filtering out conversations with only one contact in file {file}"
+        ) from e
 
     return changed_contact_name_ds
 
 
-SPLIT_CONVERSATION_THRESHOLD = 40
-MAX_CHARACTERS_PER_MESSAGE = 10000  # Max is 8,192 tokens (https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-about#sample-datasets)
-
-
 def transform_conversations_dataset_into_training_examples(
-    conversations_ds, system_prompt, user_role, model_role, whatsapp_name
+    conversations_ds,
+    system_prompt,
+    user_role,
+    model_role,
+    whatsapp_name,
+    minutes_threshold,
+    min_messages_per_conversation,
+    split_conversation_threshold,
+    max_characters_per_message,
 ):
     """
     Takes in a dataset with conversations and returns a dataset with training examples.
@@ -376,7 +439,7 @@ def transform_conversations_dataset_into_training_examples(
                 model_role if msg["contact_name"] == whatsapp_name else user_role
             )
             if (
-                counter > SPLIT_CONVERSATION_THRESHOLD
+                counter > split_conversation_threshold
                 and converted_role == user_role
             ):
                 processed_examples.append(
@@ -401,7 +464,7 @@ def transform_conversations_dataset_into_training_examples(
                 {"role": converted_role, "content": [msg["message"]]}
             )
             counter += 1
-        if len(messages) >= MIN_MESSAGES_THRESHOLD:
+        if len(messages) >= min_messages_per_conversation:
             processed_examples.append(
                 {
                     "messages": [
@@ -415,8 +478,13 @@ def transform_conversations_dataset_into_training_examples(
             )
         else:
            logger.warning(
-                f"Discarding conversation because the length is not at least {MIN_MESSAGES_THRESHOLD}: {messages}"
+                f"Discarding conversation because the length is not at least {min_messages_per_conversation}: {messages}"
            )
+    if len(processed_examples) == 0:
+        logger.warning(
+            f"Discarding all conversations because none of them have at least {min_messages_per_conversation} messages"
+        )
+        return {}
     # Before returning, flatten the list of dictionaries into a dictionary of lists
     flattened_examples = {}
     for key in processed_examples[0].keys():
@@ -431,17 +499,25 @@ def transform_conversations_dataset_into_training_examples(
             batched=True,
         )
     except Exception as e:
-        logger.exception("Error transforming the conversations dataset into training examples")
-        raise Exception("Error transforming the conversations dataset into training examples") from e
+        logger.exception(
+            "Error transforming the conversations dataset into training examples"
+        )
+        raise Exception(
+            "Error transforming the conversations dataset into training examples"
+        ) from e
 
     try:
         examples_filtered_by_length = processed_examples.filter(
             lambda x: all(
-                [len(m["content"]) < MAX_CHARACTERS_PER_MESSAGE for m in x["messages"]]
+                [len(m["content"]) < max_characters_per_message for m in x["messages"]]
            )
        )
     except Exception as e:
-        logger.exception("Error filtering out examples with messages longer than the maximum allowed")
-        raise Exception("Error filtering out examples with messages longer than the maximum allowed") from e
+        logger.exception(
+            "Error filtering out examples with messages longer than the maximum allowed"
+        )
+        raise Exception(
+            "Error filtering out examples with messages longer than the maximum allowed"
+        ) from e
 
     return examples_filtered_by_length
9
  logger = logging.getLogger(__name__)
10
  logger.setLevel(logging.INFO)
11
 
 
 
 
 
 
 
 
12
 
13
+ # %%
14
+ def group_messages(messages_iterable, minutes_threshold):
15
+ """
16
+ Groups messages in a conversation. If the difference between two consecutive messages is less than `minutes_threshold` minutes, they are grouped together.
17
+ """
18
  groups = []
19
+ current_group = []
20
+ try:
21
+ first_message = next(messages_iterable)
22
+ current_group.append(first_message)
23
+ except StopIteration:
24
+ logger.exception("No messages in the conversation")
25
+ return []
26
  for message in messages_iterable:
27
  assert len(current_group) > 0 # We should never have an empty group
28
  if (
29
  message["timestamp"] - current_group[-1]["timestamp"]
30
+ < minutes_threshold * 60
31
  ):
32
  current_group.append(message)
33
  else:
 
217
 
218
 
219
  # %%
220
+ def process_chat_file(
221
+ file,
222
+ do_spelling_correction,
223
+ whatsapp_name,
224
+ datetime_dayfirst,
225
+ message_line_format,
226
+ minutes_threshold,
227
+ min_messages_per_conversation,
228
+ do_reordering=False,
229
+ ):
230
  """
231
  Process a chat file and return a dataset with the conversations.
232
  """
 
236
  message_line_format
237
  )
238
 
239
+ def process_line(examples):
240
  # The lines have this format: dd/mm/yy, hh:mm - <person>: <msg>
241
+ messages = []
242
+ contact_names = []
243
+ timestamps = []
244
+ for line_text in examples["text"]:
245
+ try:
246
+ groups = exp.match(line_text).groupdict()
247
+ # First, get the elements. If something fails here, it will raise an exception before actually adding the element to the list, so we'll be sure that the three lists contain the same # of elements.
248
+ timestamp = dateutil.parser.parse(
249
+ groups["msg_datetime"], dayfirst=datetime_dayfirst
250
+ ).timestamp()
251
+ message = groups["message"]
252
+ contact_name = groups["contact_name"]
253
+ messages.append(message)
254
+ contact_names.append(contact_name)
255
+ timestamps.append(timestamp)
256
+ except Exception as e:
257
+ logger.exception(f"Error while processing line {line_text}")
258
+ return {
259
+ "message": messages,
260
+ "contact_name": contact_names,
261
+ "timestamp": timestamps,
262
+ }
263
 
264
  try:
265
  ds = datasets.load_dataset("text", data_files=[file])["train"]
266
  except Exception as e:
267
  logger.exception(f"Error while loading file {file}")
268
  raise Exception(f"Error while loading file {file}") from e
269
+
270
+ # try:
271
+ # ds = ds.filter(
272
+ # # Has to begin by date, time, contact name, and contain at least a ':' symbol
273
+ # lambda x: re.match(
274
+ # r"^\d{1,2}/\d{1,2}/\d{1,4},\s\d{2}:\d{2}\s-\s.+:", x["text"]
275
+ # )
276
+ # )
277
+ # except Exception as e:
278
+ # logger.exception(f"Error filtering the lines in file {file} so they match the expected format")
279
+ # raise Exception(f"Error filtering the lines in file {file} so they match the expected format") from e
280
+
281
  try:
282
+ ds = ds.map(process_line, remove_columns=["text"], batched=True, batch_size=10)
 
 
 
 
 
 
 
 
 
 
283
  except Exception as e:
284
+ logger.exception(
285
+ f"Error mapping the lines in file {file} to the expected format"
286
+ )
287
+ raise Exception(
288
+ f"Error mapping the lines in file {file} to the expected format"
289
+ ) from e
290
+
291
+ # Check that the WhatsApp name is in at least one of the messages. If it's not, raise an exception
292
+ set_of_contact_names = ds.unique("contact_name")
293
+ if whatsapp_name not in set_of_contact_names:
294
+ raise Exception(
295
+ f"Your WhatsApp name ({whatsapp_name}) is not in the messages of at least one uploaded file. Please check that you wrote your name correctly. These were the participants found: {set_of_contact_names}"
296
+ )
297
+ # # Also check that the number of contact names is == 2 (i.e. we don't have group chats)
298
+ # if len(set_of_contact_names) > 2:
299
+ # raise Exception(
300
+ # f"There were more than 2 participants in at least one uploaded file. Please check that you're not using group chats. These were the participants found: {set_of_contact_names}"
301
+ # )
302
 
303
  try:
304
  # Filter out messages that just say '<Media omitted>'
305
  ds = ds.filter(lambda x: x["message"] != "<Media omitted>")
306
  except Exception as e:
307
+ logger.exception(
308
+ f"Error filtering out messages that say '<Media omitted>' in file {file}"
309
+ )
310
+ raise Exception(
311
+ f"Error filtering out messages that say '<Media omitted>' in file {file}"
312
+ ) from e
313
 
314
  try:
315
+ groups = group_messages(iter(ds), minutes_threshold=minutes_threshold)
316
  # Generate the dataset
317
  conversations_ds = datasets.Dataset.from_dict({"conversations": groups})
318
  except Exception as e:
 
322
  try:
323
  # Filter out conversations with less than 5 messages
324
  conversations_ds = conversations_ds.filter(
325
+ lambda x: len(x["conversations"]) >= min_messages_per_conversation
326
  )
327
  except Exception as e:
328
+ logger.exception(
329
+ f"Error filtering out conversations with less than {min_messages_per_conversation} messages in file {file}"
330
+ )
331
+ raise Exception(
332
+ f"Error filtering out conversations with less than {min_messages_per_conversation} messages in file {file}"
333
+ ) from e
334
 
335
  try:
336
  conversations_ds_without_whatsapp_annotations = conversations_ds.map(
 
344
  if do_spelling_correction:
345
  try:
346
  spell_checked_conversations_ds = (
347
+ conversations_ds_without_whatsapp_annotations.map(
348
+ spell_check_conversation
349
+ )
350
  )
351
  except Exception as e:
352
  logger.exception(f"Error spell checking the conversations in file {file}")
353
+ raise Exception(
354
+ f"Error spell checking the conversations in file {file}"
355
+ ) from e
356
  else:
357
  spell_checked_conversations_ds = conversations_ds_without_whatsapp_annotations
358
 
 
380
  ) # , num_proc=os.cpu_count() - 1)
381
  except Exception as e:
382
  logger.exception(f"Error changing your other contact's names in file {file}")
383
+ raise Exception(
384
+ f"Error changing your other contact's names in file {file}"
385
+ ) from e
386
 
387
  try:
388
  # Filter out conversations with only one contact
 
390
  lambda x: len(set([msg["contact_name"] for msg in x["conversations"]])) > 1
391
  )
392
  except Exception as e:
393
+ logger.exception(
394
+ f"Error filtering out conversations with only one contact in file {file}"
395
+ )
396
+ raise Exception(
397
+ f"Error filtering out conversations with only one contact in file {file}"
398
+ ) from e
399
 
400
  return changed_contact_name_ds
401
 
402
 
 
 
 
 
403
  def transform_conversations_dataset_into_training_examples(
404
+ conversations_ds,
405
+ system_prompt,
406
+ user_role,
407
+ model_role,
408
+ whatsapp_name,
409
+ minutes_threshold,
410
+ min_messages_per_conversation,
411
+ split_conversation_threshold,
412
+ max_characters_per_message,
413
  ):
414
  """
415
  Takes in a dataset with conversations and returns a dataset with training examples.
 
439
  model_role if msg["contact_name"] == whatsapp_name else user_role
440
  )
441
  if (
442
+ counter > split_conversation_threshold
443
  and converted_role == user_role
444
  ):
445
  processed_examples.append(
 
464
  {"role": converted_role, "content": [msg["message"]]}
465
  )
466
  counter += 1
467
+ if len(messages) >= min_messages_per_conversation:
468
  processed_examples.append(
469
  {
470
  "messages": [
 
478
  )
479
  else:
480
  logger.warning(
481
+ f"Discarding conversation because the length is not at least {min_messages_per_conversation}: {messages}"
482
  )
483
+ if len(processed_examples) == 0:
484
+ logger.warning(
485
+ f"Discarding all conversations because none of them have at least {min_messages_per_conversation} messages"
486
+ )
487
+ return {}
488
  # Before returning, flatten the list of dictionaries into a dictionary of lists
489
  flattened_examples = {}
490
  for key in processed_examples[0].keys():
 
499
  batched=True,
500
  )
501
  except Exception as e:
502
+ logger.exception(
503
+ "Error transforming the conversations dataset into training examples"
504
+ )
505
+ raise Exception(
506
+ "Error transforming the conversations dataset into training examples"
507
+ ) from e
508
 
509
  try:
510
  examples_filtered_by_length = processed_examples.filter(
511
  lambda x: all(
512
+ [len(m["content"]) < max_characters_per_message for m in x["messages"]]
513
  )
514
  )
515
  except Exception as e:
516
+ logger.exception(
517
+ "Error filtering out examples with messages longer than the maximum allowed"
518
+ )
519
+ raise Exception(
520
+ "Error filtering out examples with messages longer than the maximum allowed"
521
+ ) from e
522
 
523
  return examples_filtered_by_length
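
For reference, the grouping rule that `minutes_threshold` now parameterizes: consecutive messages closer together than the threshold belong to the same conversation, and a larger gap starts a new one. A minimal standalone sketch (not part of the commit; the final flush of the last group and the toy timestamps are assumptions, since the hunk does not show the function's tail):

def group_messages(messages_iterable, minutes_threshold):
    groups = []
    try:
        current_group = [next(messages_iterable)]
    except StopIteration:
        return []  # no messages at all
    for message in messages_iterable:
        if message["timestamp"] - current_group[-1]["timestamp"] < minutes_threshold * 60:
            current_group.append(message)  # small gap: same conversation
        else:
            groups.append(current_group)  # big gap: start a new conversation
            current_group = [message]
    groups.append(current_group)  # assumed: flush the final group
    return groups

msgs = [
    {"timestamp": 0},
    {"timestamp": 60},           # 1 minute later: same conversation
    {"timestamp": 4 * 60 * 60},  # 4 hours later: new conversation
]
print([len(g) for g in group_messages(iter(msgs), minutes_threshold=180)])  # [2, 1]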