from uuid import uuid4

import gradio as gr
import datasets
import json
import io

from utils import (
    process_chat_file,
    transform_conversations_dataset_into_training_examples,
)
from validation import (
    check_format_errors,
    check_token_counts,
    estimate_cost,
    get_distributions,
)

import matplotlib.pyplot as plt
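
# `process_chat_file` and `transform_conversations_dataset_into_training_examples`
# (utils.py) are expected to parse an exported WhatsApp chat .txt file into a
# `datasets.Dataset` of conversations and to turn those conversations into chat-format
# training examples, respectively; the validation helpers compute format checks,
# length/token distributions, and a cost estimate over those examples. (These
# descriptions are inferred from how the helpers are used below, not from their
# implementations.)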


def convert_to_dataset(files, do_spelling_correction, progress):
    """Process every uploaded chat file and concatenate the results into a single dataset."""
    modified_dataset = None
    for file in progress.tqdm(files, desc="Processing files"):
        if modified_dataset is None:
            # First file: its processed conversations become the initial dataset.
            modified_dataset = process_chat_file(
                file, do_spelling_correction=do_spelling_correction
            )
        else:
            # Subsequent files are processed and appended to the accumulated dataset.
            # concatenate_datasets requires the per-file datasets to share the same features.
            this_file_dataset = process_chat_file(
                file, do_spelling_correction=do_spelling_correction
            )
            modified_dataset = datasets.concatenate_datasets(
                [modified_dataset, this_file_dataset]
            )
    return modified_dataset


def file_upload_callback(
    files, system_prompt, do_spelling_correction, validation_split, progress=gr.Progress()
):
    """Convert the uploaded chats into training/validation JSONL files plus statistics and plots."""
    print(f"Processing {files}")
    full_system_prompt = f"""You are a chatbot. Your goal is to simulate realistic, natural chat conversations as if you were me.
# Task
A participant can send multiple messages in a row, delimited by '\"', in the following schema:
{{string}}[]. Your answer always needs to be JSON compliant. Always start your answer with [\"
# Information about me
You should use the following information about me to answer:
{system_prompt}
# Example
[{{\"role\":\"user\",\"content\":\"[\"Hello!\",\"How are you?\"]\"}},{{\"role\":\"assistant\",\"content\":\"[\"Hi!\",\"I'm doing great.\",\"What about you?\"]\"}},{{\"role\":\"user\",\"content\":\"[\"I'm doing well.\",\"Have you been travelling?\"]\"}}]
Response:
[{{\"role\":\"assistant\",\"content\":\"[\"Yes, I've been to many places.\",\"I love travelling.\"]\"}}]"""

    # NOTE: the detailed prompt template above is currently overridden, and the raw
    # system prompt entered in the UI is passed to the transformation as-is.
    full_system_prompt = system_prompt

    dataset = convert_to_dataset(
        files=files, progress=progress, do_spelling_correction=do_spelling_correction
    )
    training_examples_ds = transform_conversations_dataset_into_training_examples(
        conversations_ds=dataset, system_prompt=full_system_prompt
    )

    # Split the examples into a training set and a validation set (fixed seed for reproducibility).
    training_examples_ds = training_examples_ds.train_test_split(test_size=validation_split, seed=42)
    training_examples_ds, validation_examples_ds = training_examples_ds["train"], training_examples_ds["test"]

    # Basic quality checks and statistics over the training examples.
    format_errors = check_format_errors(training_examples_ds)
    distributions = get_distributions(training_examples_ds)
    cost_stats = estimate_cost(training_examples_ds)

    stats = {
        "Format Errors": format_errors,
        "Number of examples missing system message": distributions["n_missing_system"],
        "Number of examples missing user message": distributions["n_missing_user"],
        "Cost Statistics": cost_stats,
    }

    # Histograms of message counts and token lengths, each drawn on its own figure
    # so it can be returned to a separate gr.Plot component.
    fig_num_messages_distribution_plot = plt.figure()
    plt.hist(distributions["n_messages"], bins=20)

    fig_num_total_tokens_per_example_plot = plt.figure()
    plt.hist(distributions["convo_lens"], bins=20)

    fig_num_assistant_tokens_per_example_plot = plt.figure()
    plt.hist(distributions["assistant_message_lens"], bins=20)

    # Write both splits to uniquely named JSONL files that the download buttons serve.
    uuid = str(uuid4())
    file_path = f"training_examples_{uuid}.jsonl"
    training_examples_ds.to_json(path_or_buf=file_path, force_ascii=False)

    file_path_validation = f"validation_examples_{uuid}.jsonl"
    validation_examples_ds.to_json(path_or_buf=file_path_validation, force_ascii=False)

    # The returned values map one-to-one onto the `outputs` list of `submit.click` below:
    # each DownloadButton receives its file path and a visibility update.
    return (
        file_path,
        gr.update(visible=True),
        file_path_validation,
        gr.update(visible=True),
        stats,
        fig_num_messages_distribution_plot,
        fig_num_total_tokens_per_example_plot,
        fig_num_assistant_tokens_per_example_plot,
    )
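
# Each line of the generated JSONL files is expected to hold one training example in
# OpenAI-style chat format, i.e. roughly (this shape is inferred from the validation
# helpers and the prompt example above, not from utils.py itself):
#
#   {"messages": [
#       {"role": "system", "content": "<system prompt>"},
#       {"role": "user", "content": "[\"Hello!\", \"How are you?\"]"},
#       {"role": "assistant", "content": "[\"Hi!\", \"I'm doing great.\"]"}
#   ]}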


def remove_file_and_hide_button(file_path):
    """Delete a generated file once it has been downloaded and hide its download button."""
    import os

    try:
        os.remove(file_path)
    except Exception as e:
        print(f"Error removing file {file_path}: {e}")

    return gr.update(visible=False)
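

# A minimal, illustrative sketch (not called anywhere in this app) of how the two
# generated JSONL files could be submitted to the OpenAI fine-tuning API. The client
# setup, the helper name, and the base model are assumptions for illustration only;
# they are not part of this project.
def _example_openai_finetune(training_path: str, validation_path: str):
    from openai import OpenAI  # assumed extra dependency, not required by this app

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    with open(training_path, "rb") as f:
        train_file = client.files.create(file=f, purpose="fine-tune")
    with open(validation_path, "rb") as f:
        val_file = client.files.create(file=f, purpose="fine-tune")
    # Start the fine-tuning job; "gpt-3.5-turbo" is only a placeholder base model.
    return client.fine_tuning.jobs.create(
        training_file=train_file.id,
        validation_file=val_file.id,
        model="gpt-3.5-turbo",
    )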


theme = gr.themes.Default(primary_hue="cyan", secondary_hue="fuchsia")

with gr.Blocks(theme=theme) as demo:
    gr.Markdown(
        """
# WhatsApp Chat to Dataset Converter
Upload your WhatsApp chat files and convert them into a Dataset.
"""
    )
    gr.Markdown(
        """
## Instructions
1. Click on the "Upload WhatsApp Chat Files" button.
2. Select the WhatsApp chat files you want to convert.
3. Write a prompt about yourself to give context to the training examples.
4. Click on the "Submit" button.
5. Wait for the process to finish.
6. Download the generated training and validation examples as JSONL files.
7. Use the training examples to train your own model.
"""
    )

    input_files = gr.File(
        label="Upload WhatsApp Chat Files",
        type="filepath",
        file_count="multiple",
        file_types=[".txt"],
    )

    system_prompt = gr.Textbox(
        label="System Prompt",
        placeholder="Background information about you.",
        lines=5,
        info="Background information about you that will be used as the system prompt for the generated training examples.",
        value="""Aldan is an AI researcher who loves to play around with AI systems, travelling and learning new things.""",
    )

    do_spelling_correction = gr.Checkbox(
        label="Do Spelling Correction (English)",
        info="Check this box to perform spelling correction on the chat messages before generating the training examples.",
    )

    validation_split = gr.Slider(
        minimum=0.0,
        maximum=0.5,
        value=0.2,
        interactive=True,
        label="Validation Split",
        info="Fraction of the dataset to reserve for validation. For example, 0.2 means 20% of the examples are used for validation and 80% for training.",
    )

    submit = gr.Button(value="Submit", variant="primary")

    output_file = gr.DownloadButton(
        label="Download Generated Training Examples", visible=False, variant="primary"
    )
    output_file_validation = gr.DownloadButton(
        label="Download Generated Validation Examples", visible=False, variant="secondary"
    )

    with gr.Group():
        gr.Markdown("## Statistics")
        written_stats = gr.JSON()
        num_messages_distribution_plot = gr.Plot(label="Number of Messages Distribution")
        num_total_tokens_per_example_plot = gr.Plot(label="Total Number of Tokens per Example")
        num_assistant_tokens_per_example_plot = gr.Plot(
            label="Number of Assistant Tokens per Example"
        )

    # Each DownloadButton appears twice in `outputs` because file_upload_callback
    # returns both its file path (the button value) and a visibility update for it.
    submit.click(
        file_upload_callback,
        inputs=[input_files, system_prompt, do_spelling_correction, validation_split],
        outputs=[
            output_file,
            output_file,
            output_file_validation,
            output_file_validation,
            written_stats,
            num_messages_distribution_plot,
            num_total_tokens_per_example_plot,
            num_assistant_tokens_per_example_plot,
        ],
    )

    # Clicking a download button also deletes the generated file and hides the button again.
    output_file.click(remove_file_and_hide_button, inputs=[output_file], outputs=[output_file])
    output_file_validation.click(
        remove_file_and_hide_button,
        inputs=[output_file_validation],
        outputs=[output_file_validation],
    )

if __name__ == "__main__":
    demo.launch()