|
|
|
import io |
|
import json |
|
import logging |
|
from uuid import uuid4 |
|
|
|
import datasets |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
|
|
from utils import ( |
|
process_chat_file, |
|
transform_conversations_dataset_into_training_examples, |
|
) |
|
from validation import check_format_errors, estimate_cost, get_distributions |
|
|
|
logger = logging.getLogger(__name__) |
|
logger.setLevel(logging.INFO) |
|
|
|
|
|
def convert_to_dataset(
    files,
    do_spelling_correction,
    progress,
    whatsapp_name,
    datetime_dayfirst,
    message_line_format,
    minutes_threshold,
    min_messages_per_conversation,
):
    """Parse every uploaded chat file and merge the results into one dataset.

    Each file is parsed with ``process_chat_file`` using the same options, and
    the per-file datasets are concatenated in the order the files were given.

    Args:
        files: The uploaded chat files, passed one at a time to
            ``process_chat_file``.
        do_spelling_correction: Forwarded to ``process_chat_file``.
        progress: Gradio progress tracker; its ``tqdm`` wrapper reports
            per-file progress in the UI.
        whatsapp_name: Forwarded to ``process_chat_file``.
        datetime_dayfirst: Forwarded to ``process_chat_file``.
        message_line_format: Forwarded to ``process_chat_file``.
        minutes_threshold: Forwarded to ``process_chat_file``.
        min_messages_per_conversation: Forwarded to ``process_chat_file``.

    Returns:
        The combined dataset, or ``None`` when ``files`` is empty.

    Raises:
        gr.Error: If a file cannot be processed, or if the per-file datasets
            cannot be concatenated.
    """
    # The same parsing options apply to every file, so build them once.
    parse_options = dict(
        do_spelling_correction=do_spelling_correction,
        whatsapp_name=whatsapp_name,
        datetime_dayfirst=datetime_dayfirst,
        message_line_format=message_line_format,
        minutes_threshold=minutes_threshold,
        min_messages_per_conversation=min_messages_per_conversation,
    )

    per_file_datasets = []
    for file in progress.tqdm(files, desc="Processing files"):
        try:
            per_file_datasets.append(process_chat_file(file, **parse_options))
        except Exception as e:
            # Log with traceback for the server, surface a readable message in the UI.
            logger.exception(f"Error processing file {file}: {e}")
            raise gr.Error(f"Error processing file {file}: {e}") from e

    if not per_file_datasets:
        return None
    if len(per_file_datasets) == 1:
        return per_file_datasets[0]

    # Concatenate once at the end rather than re-wrapping a growing dataset
    # on every iteration.
    try:
        return datasets.concatenate_datasets(per_file_datasets)
    except Exception as e:
        logger.exception(f"Error combining processed files: {e}")
        raise gr.Error(f"Error combining processed files: {e}") from e
|
|
|
|
|
def _split_train_validation(examples_ds, validation_split, max_validation_examples=200):
    """Split ``examples_ds`` into ``(training_ds, validation_ds)``.

    Uses ``train_test_split`` with a fixed seed (42) so the split is
    reproducible. If the split is rejected with ``ValueError`` (presumably a
    ``validation_split`` of 0.0, or too few examples to produce a non-empty
    test split; the exact conditions are defined by the ``datasets`` library),
    all examples go to training and validation is an empty dataset.

    The validation set is truncated to at most ``max_validation_examples`` rows.
    """
    try:
        split_ds = examples_ds.train_test_split(test_size=validation_split, seed=42)
    except ValueError:
        training_ds = examples_ds
        validation_ds = datasets.Dataset.from_dict({})
    else:
        training_ds, validation_ds = split_ds["train"], split_ds["test"]

    validation_ds = validation_ds.select(
        range(min(max_validation_examples, len(validation_ds)))
    )
    return training_ds, validation_ds


def _histogram_figure(values, bins=20):
    """Return a new Matplotlib figure containing a histogram of ``values``.

    Uses the object-oriented ``Figure`` API instead of ``pyplot``. Figures
    created with ``plt.figure()`` are registered in pyplot's global figure
    manager and stay alive until explicitly closed, which leaks memory on
    every request in a long-running server and shares global state across
    concurrent requests.
    """
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    ax.hist(values, bins=bins)
    return fig


def file_upload_callback(
    files,
    system_prompt,
    do_spelling_correction,
    validation_split,
    user_role,
    model_role,
    whatsapp_name,
    datetime_dayfirst,
    message_line_format,
    minutes_threshold,
    min_messages_per_conversation,
    max_characters_per_message,
    split_conversation_threshold,
    progress=gr.Progress(),
):
    """Turn uploaded WhatsApp chat exports into training/validation JSONL files.

    Args:
        files: The uploaded chat files.
        system_prompt: Free-text background about the user; appended to the
            generated system prompt under "# Information about me".
        do_spelling_correction: Forwarded to ``convert_to_dataset``.
        validation_split: Fraction of the generated examples held out for
            validation.
        user_role: Role name for the other participants' messages.
        model_role: Role name for the user's own messages.
        whatsapp_name: The user's name exactly as it appears in the exports.
        datetime_dayfirst: Forwarded to ``convert_to_dataset``.
        message_line_format: Forwarded to ``convert_to_dataset``.
        minutes_threshold: Forwarded to parsing and example generation.
        min_messages_per_conversation: Forwarded to parsing and example
            generation.
        max_characters_per_message: Forwarded to example generation.
        split_conversation_threshold: Forwarded to example generation.
        progress: Gradio progress tracker.

    Returns:
        A 10-tuple: training file path, visibility update, validation file
        path, visibility update, statistics dict, three histogram figures,
        the full system prompt, and a visibility update.

    Raises:
        gr.Error: If no files are uploaded, the WhatsApp name is blank, or a
            file fails to process.
    """
    logger.info(f"Processing {files}")

    full_system_prompt = f"""# Task

You are a chatbot. Your goal is to simulate realistic, natural chat conversations as if you were me.

The {model_role} and the {user_role} can send multiple messages in a row, as a JSON list of strings. Your answer always needs to be JSON compliant. The strings are delimited by double quotes ("). The strings are separated by a comma (,). The list is delimited by square brackets ([, ]). Always start your answer with [", and close it with "]. Do not write anything else in your answer after "].

# Information about me

{system_prompt}"""

    if not files:
        raise gr.Error("Please upload at least one file.")

    # A whitespace-only name can never match a chat participant.
    if not whatsapp_name or not whatsapp_name.strip():
        raise gr.Error("Please enter your WhatsApp name.")

    dataset = convert_to_dataset(
        files=files,
        progress=progress,
        do_spelling_correction=do_spelling_correction,
        whatsapp_name=whatsapp_name,
        datetime_dayfirst=datetime_dayfirst,
        message_line_format=message_line_format,
        minutes_threshold=minutes_threshold,
        min_messages_per_conversation=min_messages_per_conversation,
    )
    logger.info(
        f"Number of conversations of dataset before being transformed: {len(dataset)}"
    )

    full_examples_ds = transform_conversations_dataset_into_training_examples(
        conversations_ds=dataset,
        system_prompt=full_system_prompt,
        user_role=user_role,
        model_role=model_role,
        whatsapp_name=whatsapp_name,
        minutes_threshold=minutes_threshold,
        min_messages_per_conversation=min_messages_per_conversation,
        split_conversation_threshold=split_conversation_threshold,
        max_characters_per_message=max_characters_per_message,
    )
    total_number_of_generated_examples = len(full_examples_ds)
    logger.info(
        f"Total number of generated examples: {total_number_of_generated_examples}"
    )

    training_examples_ds, validation_examples_ds = _split_train_validation(
        full_examples_ds, validation_split
    )

    # Statistics are computed on the training split only.
    format_errors = check_format_errors(
        training_examples_ds, user_role=user_role, model_role=model_role
    )
    distributions = get_distributions(
        training_examples_ds, user_role=user_role, model_role=model_role
    )
    cost_stats = estimate_cost(
        training_examples_ds, user_role=user_role, model_role=model_role
    )

    stats = {
        "Total number of training examples": total_number_of_generated_examples,
        "Number of training examples": len(training_examples_ds),
        "Number of validation examples": len(validation_examples_ds),
        "Number of examples missing system message": distributions["n_missing_system"],
        "Number of examples missing user message": distributions["n_missing_user"],
        "Format Errors": format_errors,
        "Cost Statistics": cost_stats,
    }

    fig_num_messages_distribution_plot = _histogram_figure(distributions["n_messages"])
    fig_num_total_tokens_per_example_plot = _histogram_figure(
        distributions["convo_lens"]
    )
    fig_num_assistant_tokens_per_example_plot = _histogram_figure(
        distributions["assistant_message_lens"]
    )

    # A shared random id keeps the training/validation files of one run
    # paired and avoids collisions between concurrent runs.
    run_id = str(uuid4())
    file_path = f"training_examples_{run_id}.jsonl"
    training_examples_ds.to_json(path_or_buf=file_path, force_ascii=False)

    file_path_validation = f"validation_examples_{run_id}.jsonl"
    validation_examples_ds.to_json(path_or_buf=file_path_validation, force_ascii=False)

    if len(training_examples_ds) < 50:
        gr.Warning(
            "There are less than 50 training examples. The model may not perform well with such a small dataset. Consider adding more chat files to increase the number of training examples."
        )

    return (
        file_path,
        gr.update(visible=True),
        file_path_validation,
        gr.update(visible=True),
        stats,
        fig_num_messages_distribution_plot,
        fig_num_total_tokens_per_example_plot,
        fig_num_assistant_tokens_per_example_plot,
        full_system_prompt,
        gr.update(visible=True),
    )
|
|
|
|
|
def remove_file_and_hide_button(file_path):
    """Hide the download button that was just clicked.

    Despite its name, this function does not delete ``file_path``. The
    argument is accepted because the click handler passes the button's
    current value, but it is not used, and the generated JSONL file is left
    on disk.

    NOTE(review): deleting the file here could race with the browser download
    that the same click starts — confirm before adding removal.

    Args:
        file_path: The download button's current value (unused).

    Returns:
        A ``gr.update`` that makes the button invisible.
    """
    return gr.update(visible=False)
|
|
|
|
|
# --- Gradio user interface ---------------------------------------------------
# Builds the single-page app:
#   - inputs for the chat exports and generation parameters;
#   - a Submit button wired to ``file_upload_callback``;
#   - download buttons for the generated JSONL files;
#   - a statistics panel.

theme = gr.themes.Default(primary_hue="cyan", secondary_hue="fuchsia")

with gr.Blocks(theme=theme) as demo:
    gr.Markdown(
        """

        # WhatsApp Chat to Dataset Converter

        Upload your WhatsApp chat files and convert them into a Dataset.

        """
    )
    gr.Markdown(
        """

        ## Instructions

        1. Click on the "Upload WhatsApp Chat Files" button.

        2. Select the WhatsApp chat files you want to convert.

        3. Write a prompt about you to give context to the training examples.

        4. Click on the "Submit" button.

        5. Wait for the process to finish.

        6. Download the generated training examples as a JSONL file.

        7. Use the training examples to train your own model.

        """
    )

    # Gradio's ``file_types`` expects extensions with a leading dot. A bare
    # "txt" is not a valid extension filter and can make Gradio reject
    # .txt uploads.
    input_files = gr.File(
        label="Upload WhatsApp Chat Files",
        type="filepath",
        file_count="multiple",
        file_types=[".txt"],
    )

    system_prompt = gr.Textbox(
        label="System Prompt",
        placeholder="Background information about you.",
        lines=5,
        info="Enter the system prompt to be used for the training examples generation. This is the background information about you that will be used to generate the training examples.",
        value="""Aldan is an AI researcher who loves to play around with AI systems, travelling and learning new things.""",
    )

    whatsapp_name = gr.Textbox(
        label="Your WhatsApp Name",
        placeholder="Your WhatsApp Name",
        info="Enter your WhatsApp name as it appears in your profile. It needs to match exactly your name. If you're unsure, you can check the chat messages to see it.",
    )

    with gr.Accordion(label="Advanced Parameters", open=False):
        gr.Markdown(
            """

            These are advanced parameters that you can change if you know what you're doing. If you're unsure, you can leave them as they are.

            """
        )

        user_role = gr.Textbox(
            label="Role for User",
            info="This is a technical parameter. If you don't know what to write, just type 'user'.",
            value="user",
        )

        model_role = gr.Textbox(
            label="Role for Model",
            info="This is a technical parameter. Usual values are 'model' (e.g. Vertex AI) or 'assistant' (e.g. OpenAI).",
            value="model",
        )

        minutes_threshold = gr.Number(
            label="Minutes Threshold",
            info="Threshold in minutes to consider that a new message is a new conversation. The default value should work for most cases.",
            value=180,
        )

        # The next three values are counts. precision=0 makes Gradio return
        # an int instead of a float.
        min_messages_per_conversation = gr.Number(
            label="Minimum Messages per Conversation",
            info="Minimum number of messages per conversation to consider it as a valid conversation. The default value should work for most cases.",
            value=5,
            precision=0,
        )

        max_characters_per_message = gr.Number(
            label="Max Characters per Message",
            info="One token is around 3 characters. The default value should work for most cases. For example, on Vertex AI, the maximum number of tokens per example is [32,000](https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-prepare#sample-datasets), so keeping the default value will ensure that the examples are well within the limit.",
            value=10000,
            precision=0,
        )

        split_conversation_threshold = gr.Number(
            label="Split Conversation Threshold",
            info="Number of messages in a conversation to split it into multiple ones. The default value should work for most cases.",
            value=40,
            precision=0,
        )

        message_line_format = gr.Textbox(
            label="Message Line Format",
            info="Format of each message line in the chat file, as a regular expression. The default value should work for most cases.",
            value=r"\[?(?P<msg_datetime>\S+,\s\S+?(?:\s[APap][Mm])?)\]? (?:- )?(?P<contact_name>.+?): (?P<message>.+)",
        )

        datetime_dayfirst = gr.Checkbox(
            label="Date format: Day first",
            info="Check this box if the date time format in the chat messages is in the format 'DD/MM/YYYY'. You can check your phone settings to see the date format. Otherwise, it will be assumed that the date time format is 'MM/DD/YYYY'.",
            value=True,
        )

        do_spelling_correction = gr.Checkbox(
            label="Do Spelling Correction (English)",
            info="Check this box if you want to perform spelling correction on the chat messages before generating the training examples.",
        )

        validation_split = gr.Slider(
            minimum=0.0,
            maximum=0.5,
            value=0.2,
            interactive=True,
            label="Validation Split",
            info="Choose the percentage of the dataset to be used for validation. For example, if you choose 0.2, 20% of the dataset will be used for validation and 80% for training.",
        )

    submit = gr.Button(value="Submit", variant="primary")

    # The download buttons and the system-prompt box start hidden and are
    # revealed by file_upload_callback once files have been generated.
    output_file = gr.DownloadButton(
        label="Download Generated Training Examples", visible=False, variant="primary"
    )
    output_file_validation = gr.DownloadButton(
        label="Download Generated Validation Examples",
        visible=False,
        variant="secondary",
    )

    system_prompt_to_use = gr.Textbox(
        label="System Prompt that you can use",
        visible=False,
        interactive=False,
        show_copy_button=True,
        info="When using the model, if you're asked for a system prompt, you can use this text.",
    )

    with gr.Group():
        gr.Markdown("## Statistics")
        written_stats = gr.JSON()
        num_messages_distribution_plot = gr.Plot(
            label="Number of Messages Distribution"
        )
        num_total_tokens_per_example_plot = gr.Plot(
            label="Total Number of Tokens per Example"
        )
        num_assistant_tokens_per_example_plot = gr.Plot(
            label="Number of Assistant Tokens per Example"
        )

    # The order of ``inputs`` must match file_upload_callback's positional
    # parameters.
    submit.click(
        file_upload_callback,
        inputs=[
            input_files,
            system_prompt,
            do_spelling_correction,
            validation_split,
            user_role,
            model_role,
            whatsapp_name,
            datetime_dayfirst,
            message_line_format,
            minutes_threshold,
            min_messages_per_conversation,
            max_characters_per_message,
            split_conversation_threshold,
        ],
        # Components listed twice get two return values: first their new
        # value, then a visibility update.
        outputs=[
            output_file,
            output_file,
            output_file_validation,
            output_file_validation,
            written_stats,
            num_messages_distribution_plot,
            num_total_tokens_per_example_plot,
            num_assistant_tokens_per_example_plot,
            system_prompt_to_use,
            system_prompt_to_use,
        ],
    )

    # Clicking a download button hides it afterwards.
    output_file.click(
        remove_file_and_hide_button, inputs=[output_file], outputs=[output_file]
    )
    output_file_validation.click(
        remove_file_and_hide_button,
        inputs=[output_file_validation],
        outputs=[output_file_validation],
    )


if __name__ == "__main__":
    demo.launch()
|
|