# ACMC
# Bugfix
# bd73a7b
# raw
# history blame
# 15.3 kB
import datetime
import json
import logging
import os
import re
import datasets
import dateutil.parser
# Module-level logger. INFO is enabled so that the spell-checking and
# message-reordering steps below can report the changes they make.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# %%
# Now, create message groups ('conversations').
# The idea is to group messages that are close in time: a gap of
# MINUTES_THRESHOLD minutes or more between two consecutive messages starts a
# new conversation.
MINUTES_THRESHOLD = 180
# Minimum number of entries a conversation / training example needs to be kept.
MIN_MESSAGES_THRESHOLD = 5


def group_messages(messages_iterable, minutes_threshold=None):
    """Split a stream of messages into conversations based on time gaps.

    Consecutive messages stay in the same group while the difference between
    their ``"timestamp"`` values (in seconds) is strictly below the threshold;
    a larger or equal gap closes the current group and starts a new one.

    Args:
        messages_iterable: Any iterable (list, iterator, dataset, ...) of
            message mappings that have a numeric ``"timestamp"`` key. The
            messages are consumed in the order given.
        minutes_threshold: Maximum gap, in minutes, between two consecutive
            messages of the same group. Defaults to ``MINUTES_THRESHOLD``.

    Returns:
        A list of groups, each a non-empty list of messages. An empty input
        yields an empty list.
    """
    if minutes_threshold is None:
        minutes_threshold = MINUTES_THRESHOLD
    max_gap_seconds = minutes_threshold * 60

    groups = []
    current_group = []
    for message in messages_iterable:
        # Only compare against the previous message once there is one.
        if current_group and (
            message["timestamp"] - current_group[-1]["timestamp"] >= max_gap_seconds
        ):
            groups.append(current_group)
            current_group = []
        current_group.append(message)
    # Flush the trailing group (absent only when the input was empty).
    if current_group:
        groups.append(current_group)
    return groups
def printable_conversation(conversation):
    """Render a conversation as text, one "<contact_name>: <message>" line per message."""
    lines = []
    for message in conversation:
        lines.append(f"{message['contact_name']}: {message['message']}")
    return "\n".join(lines)
# NOTE(review): contextualSpellCheck is not referenced by name in the visible
# code; it is presumably imported for its side effect of registering the
# "contextual spellchecker" spaCy component that
# spell_check_conversation_spacy adds to the pipeline -- confirm before removing.
import contextualSpellCheck
# %%
# Use spacy to spell check the messages
import spacy
from spellchecker import SpellChecker
# Shared SpellChecker instance used by spell_check_conversation.
spell = SpellChecker()
# nlp = spacy.load("es_core_news_sm")
# Shared spaCy pipeline used by spell_check_conversation_spacy.
nlp = spacy.load("en_core_web_sm")
def spell_check_conversation(conversation):
    """Spell check every message of a conversation with the shared ``spell`` checker.

    Each message is split into words with ``spell.split_words``; every word is
    replaced by ``spell.correction(word)`` when that returns a different,
    non-``None`` value; the resulting words are then joined with single
    spaces. Because the message is rebuilt from that word list, the original
    spacing (and anything ``split_words`` does not return) is not preserved.

    Args:
        conversation: Mapping with a ``"conversations"`` list of message
            dicts, each holding a ``"message"`` string. Modified in place.

    Returns:
        The same ``conversation`` object, with corrected message texts.
    """
    for message in conversation["conversations"]:
        words = spell.split_words(message["message"])
        logger.info("Words: %s", words)
        corrected_message = []
        for word in words:
            correction = spell.correction(word)
            # ``correction`` is None when the checker has no suggestion.
            if correction is not None and correction != word:
                logger.info("Spell check: %s -> %s", word, correction)
                corrected_message.append(correction)
            else:
                corrected_message.append(word)
        logger.info("Corrected message: %s", corrected_message)
        message["message"] = " ".join(corrected_message)
    return conversation
def spell_check_conversation_spacy(conversation):
    """Spell check a conversation with the "contextual spellchecker" spaCy component.

    The component is added to the shared ``nlp`` pipeline the first time this
    function runs. Every message is then passed through the pipeline, and
    messages for which the component reports ``performed_spellCheck`` are
    replaced by its ``outcome_spellCheck`` text.

    Args:
        conversation: Mapping with a ``"conversations"`` list of message
            dicts, each holding a ``"message"`` string. Modified in place.

    Returns:
        The same ``conversation`` object.
    """
    # add_pipe refuses to add a component whose name is already in the
    # pipeline, so register it only once instead of on every call.
    if "contextual spellchecker" not in nlp.pipe_names:
        nlp.add_pipe(
            "contextual spellchecker",
            config={
                "model_name": "bert-base-multilingual-uncased",
                "max_edit_dist": 2,
            },
        )
    messages = conversation["conversations"]
    docs = list(nlp.pipe([msg["message"] for msg in messages]))
    for message, doc in zip(messages, docs):
        if doc._.performed_spellCheck:
            logger.info("Spell checked: %s -> %s", doc.text, doc._.outcome_spellCheck)
            message["message"] = doc._.outcome_spellCheck
    return conversation
def remove_whatapp_annotations(conversation):
    """
    Strip WhatsApp-generated annotations from every message of a conversation.

    The following annotation is removed wherever it occurs in a message:
    - <This message was edited>

    The messages under conversation["conversations"] are updated in place and
    the same conversation object is returned.
    """
    annotation_pattern = r"<This message was edited>"
    for message in conversation["conversations"]:
        original_text = message["message"]
        message["message"] = re.sub(annotation_pattern, "", original_text)
    return conversation
# %%
"""
Sometimes, people write concurrently in the same conversation. We'll try to detect that and reorder the messages.
For example, if we have a conversation like this:
A: Hi
A: How are you?
B: Hi
B: I'm fine, thanks
A: I'm fine too
We'll reorder it to:
A: Hi
B: Hi
A: How are you?
B: I'm fine, thanks
A: I'm fine too
To do it, we'll use BERT (`bert-base-uncased`) with the next sentence prediction head. For each pair of adjacent messages from different contacts, we score both orderings: the first message followed by the second, and the second message followed by the first. If the model finds the reversed ordering clearly more likely, we'll swap the messages.
"""
import torch
from transformers import AutoModelForNextSentencePrediction, AutoTokenizer
# Tokenizer and next-sentence-prediction model shared by
# swap_messages_if_needed. Both are loaded once, when the module is imported.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
# Run the model on the GPU when CUDA is available.
if torch.cuda.is_available():
    model.cuda()
def swap_messages_if_needed(
    message1,
    message2,
    *,
    max_seconds_apart=2 * 60,
    min_words=3,
    swap_margin=0.2,
):
    """Return two adjacent messages in the order the NSP model finds more plausible.

    The pair is returned unchanged, without running the model, when:
    - both messages come from the same contact,
    - ``message2`` was sent more than ``max_seconds_apart`` seconds after
      ``message1``, or
    - either message has fewer than ``min_words`` whitespace-separated words.

    Otherwise both orderings (1 -> 2 and 2 -> 1) are scored with the shared
    next-sentence-prediction ``model``; the messages are swapped only when the
    "is next sentence" probability of the reversed ordering exceeds that of
    the original ordering by more than ``swap_margin``.

    Args:
        message1: First message; a mapping with ``"contact_name"``,
            ``"message"`` and ``"timestamp"`` (seconds) keys.
        message2: Second message, same shape as ``message1``.
        max_seconds_apart: Largest time difference for which a swap is
            considered.
        min_words: Minimum number of words each message needs for a swap to
            be considered.
        swap_margin: Probability margin by which the reversed ordering must
            win.

    Returns:
        A ``(first, second)`` tuple holding the two messages in the chosen
        order.
    """
    unchanged = (message1, message2)

    # If the messages have the same contact, we don't swap them
    if message1["contact_name"] == message2["contact_name"]:
        return unchanged

    # Timestamps are already in seconds, so subtract them directly; going
    # through naive local datetimes can distort the gap around DST changes.
    if message2["timestamp"] - message1["timestamp"] > max_seconds_apart:
        return unchanged

    text1 = message1["message"]
    text2 = message2["message"]
    # Very short messages give the model too little to work with.
    if len(text1.split()) < min_words or len(text2.split()) < min_words:
        return unchanged

    # Encode both orderings as one batch. The whole encoding is passed to the
    # model, including token_type_ids (which mark where sentence A ends and
    # sentence B starts) and the attention mask needed once padding or
    # truncation make the two rows differ in length.
    inputs = tokenizer(
        [text1, text2],
        [text2, text1],
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Column 0 is the "second sentence follows the first" class.
    probabilities = torch.softmax(logits, dim=1)
    forward_probability = probabilities[0, 0]
    reverse_probability = probabilities[1, 0]
    swap = (forward_probability - reverse_probability).item() < -swap_margin

    if swap:
        logger.info("Swapping messages: %s <-> %s", text1, text2)
        return message2, message1
    return unchanged
def swap_messages_if_needed_in_conversation(conversation):
    """Reorder adjacent messages of a conversation where the model suggests it.

    Starting from the second message, every message is compared with the one
    that follows it via ``swap_messages_if_needed`` and the pair is swapped
    when that function says so. The first message always keeps its position.

    Args:
        conversation: Either a list of message dicts, or a dataset row
            mapping with a ``"conversations"`` list (the shape the other
            per-conversation functions in this file receive from
            ``Dataset.map``).

    Returns:
        For a list: the input list itself when it has two messages or fewer,
        otherwise a new reordered list. For a mapping: a new dict with the
        same items and ``"conversations"`` replaced by the reordered list.
    """
    from collections.abc import Mapping

    # A dataset row is a mapping with a single "conversations" column; taking
    # len() of it would count columns, not messages, so unwrap it first.
    if isinstance(conversation, Mapping):
        reordered = swap_messages_if_needed_in_conversation(
            conversation["conversations"]
        )
        return {**conversation, "conversations": reordered}

    if len(conversation) <= 2:
        return conversation

    # Seeding with the first two messages keeps the first one fixed: swaps
    # can only ever involve positions 1 and later.
    new_conversation = list(conversation[:2])
    for next_message in conversation[2:]:
        previous_message, following_message = swap_messages_if_needed(
            new_conversation[-1], next_message
        )
        new_conversation[-1] = previous_message
        new_conversation.append(following_message)
    return new_conversation
# Small hand-written conversation for trying out the reordering manually (see
# the commented-out call below). B's reply is placed before A's question.
test_conversation = [
    {"message": "Hola!", "contact_name": "A", "timestamp": 1},
    {
        "message": "Está todo bien, gracias por preguntar!",
        "contact_name": "B",
        "timestamp": 2,
    },
    {
        "message": "Hola, qué tal estás? Espero que vaya todo bien por España.",
        "contact_name": "A",
        "timestamp": 3,
    },
]
# logger.info(swap_messages_if_needed_in_conversation(test_conversation))
# %%
# Now, we'll train an mT5 model to generate the next message in a conversation
import os
# %%
def process_chat_file(
    file,
    do_spelling_correction,
    whatsapp_name,
    datetime_dayfirst,
    message_line_format,
    do_reordering=False,
):
    """
    Process a chat file and return a dataset with the conversations.

    The file is read line by line. A line is treated as a message when it
    matches ``message_line_format`` and its date/time can be parsed; all other
    lines are ignored. Messages are then grouped into conversations, cleaned,
    optionally spell checked and reordered, and contact names other than
    ``whatsapp_name`` are rewritten to "Other".

    Args:
        file: Path of the exported chat text file.
        do_spelling_correction: Whether to run ``spell_check_conversation``.
        whatsapp_name: Contact name of the chat owner; kept as is.
        datetime_dayfirst: Passed to ``dateutil.parser.parse`` as ``dayfirst``.
        message_line_format: Regular expression for a message line. It must
            define the named groups ``msg_datetime``, ``contact_name`` and
            ``message``.
        do_reordering: Whether to run ``swap_messages_if_needed_in_conversation``.

    Returns:
        A dataset with a single "conversations" column. Only conversations
        with at least ``MIN_MESSAGES_THRESHOLD`` messages and more than one
        distinct contact are kept. The dataset is empty when the file holds
        no usable message line.
    """
    exp = re.compile(
        # r"(?P<msg_datetime>.+?) - (?P<contact_name>.+): (?P<message>.+)"
        # r"\[?(?P<msg_datetime>\S+,\s\S+?(?:\s[APap][Mm])?)\]? (?:- )?(?P<contact_name>.+): (?P<message>.+)"
        message_line_format
    )

    def parse_line(text):
        """Return the message dict parsed from ``text``, or None if it is not a message line."""
        match = exp.match(text)
        if match is None:
            return None
        groups = match.groupdict()
        try:
            timestamp = dateutil.parser.parse(
                groups["msg_datetime"], dayfirst=datetime_dayfirst
            ).timestamp()
        except (ValueError, OverflowError):
            # A permissive pattern can match text that is not a message
            # header (e.g. a continuation line); skip it instead of aborting.
            logger.debug("Skipping line with unparseable datetime: %s", text)
            return None
        return {
            "message": groups["message"],
            "contact_name": groups["contact_name"],
            "timestamp": timestamp,
        }

    # Select message lines with the caller-supplied format itself, so the
    # selection can never disagree with the parsing done right after.
    ds = datasets.load_dataset("text", data_files=[file])["train"].filter(
        lambda x: parse_line(x["text"]) is not None
    )
    if len(ds) > 0:
        ds = ds.map(lambda x: parse_line(x["text"]), remove_columns=["text"])
        # Filter out messages that just say '<Media omitted>'
        ds = ds.filter(lambda x: x["message"] != "<Media omitted>")
    if len(ds) == 0:
        logger.warning(
            "No message lines found in %s with the given message_line_format", file
        )
        return datasets.Dataset.from_dict({"conversations": []})

    groups = group_messages(iter(ds))
    # Generate the dataset
    conversations_ds = datasets.Dataset.from_dict({"conversations": groups})
    # Filter out conversations with less than MIN_MESSAGES_THRESHOLD messages
    conversations_ds = conversations_ds.filter(
        lambda x: len(x["conversations"]) >= MIN_MESSAGES_THRESHOLD
    )

    # Leave one core free, but never ask for fewer than one worker
    # (os.cpu_count() may return 1 or None).
    num_workers = max(1, (os.cpu_count() or 1) - 1)
    conversations_ds_without_whatsapp_annotations = conversations_ds.map(
        remove_whatapp_annotations,
        num_proc=num_workers,
    )

    if do_spelling_correction:
        spell_checked_conversations_ds = (
            conversations_ds_without_whatsapp_annotations.map(spell_check_conversation)
        )
    else:
        spell_checked_conversations_ds = conversations_ds_without_whatsapp_annotations

    if do_reordering:
        # map() hands over a row; the reordering works on the row's list of
        # messages, so unwrap it here and wrap the result again.
        reordered_conversations_ds = spell_checked_conversations_ds.map(
            lambda row: {
                "conversations": swap_messages_if_needed_in_conversation(
                    row["conversations"]
                )
            }
        )
    else:
        reordered_conversations_ds = spell_checked_conversations_ds

    # For the contact_name, rewrite everything that is not 'whatsapp_name' to 'Other'
    def rewrite_contact_name(conversation):
        for message in conversation["conversations"]:
            if message["contact_name"] != whatsapp_name:
                message["contact_name"] = "Other"
        return conversation

    changed_contact_name_ds = reordered_conversations_ds.map(
        rewrite_contact_name
    )  # , num_proc=os.cpu_count() - 1)

    # Filter out conversations with only one contact
    changed_contact_name_ds = changed_contact_name_ds.filter(
        lambda x: len({msg["contact_name"] for msg in x["conversations"]}) > 1
    )
    return changed_contact_name_ds
# Once more than this many chat messages have been added to a training
# example, the example is split at the next message written by the user role.
SPLIT_CONVERSATION_THRESHOLD = 40
# Training examples containing a turn whose JSON-encoded content reaches this
# many characters are dropped.
MAX_CHARACTERS_PER_MESSAGE = 10000 # Max is 8,192 tokens (https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-about#sample-datasets)
def transform_conversations_dataset_into_training_examples(
    conversations_ds, system_prompt, user_role, model_role, whatsapp_name
):
    """
    Takes in a dataset with conversations and returns a dataset with training examples.
    The input dataset contains a single column (conversations), with each row being a list of messages with this format:
    ```
    [{'contact_name': 'Aldi', 'message': <message>, 'timestamp': <time>}, {'contact_name': 'Other', 'message': <message>, 'timestamp': <time>}, ... ]
    ```
    Each row will be converted to fit the format of the training examples.
    The training examples have the following format:
    ```
    {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
    {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "William Shakespeare"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
    {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "384,400 kilometers"}, {"role": "user", "content": "Can you be more sarcastic?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}
    ```

    Details of the conversion:
    - Messages by ``whatsapp_name`` get ``model_role``; all others get ``user_role``.
    - Consecutive messages with the same role are merged into one turn. The
      "content" of every turn is the JSON encoding of the list of its message
      texts (the system turn holds ``[system_prompt]``).
    - After more than ``SPLIT_CONVERSATION_THRESHOLD`` messages, an example is
      closed at the next user-role message and a new one is started.
    - Examples with fewer than ``MIN_MESSAGES_THRESHOLD`` turns (system turn
      included) are discarded, as are examples containing a turn whose
      encoded content has ``MAX_CHARACTERS_PER_MESSAGE`` characters or more.

    Returns:
        A dataset with a single "messages" column.
    """

    def new_example():
        """Start a fresh example holding only the system turn."""
        return [{"role": "system", "content": [system_prompt]}]

    def serialize(messages):
        """Convert accumulated turns into the final example, JSON-encoding each content list."""
        return {
            "messages": [
                {
                    "role": m["role"],
                    "content": json.dumps(m["content"], ensure_ascii=False),
                }
                for m in messages
            ]
        }

    def process_examples(examples):
        processed_examples = []

        def emit(messages):
            # The same minimum length applies to every emitted example,
            # whether it ends a conversation or results from a split.
            if len(messages) >= MIN_MESSAGES_THRESHOLD:
                processed_examples.append(serialize(messages))
            else:
                logger.warning(
                    f"Discarding conversation because the length is not at least {MIN_MESSAGES_THRESHOLD}: {messages}"
                )

        for conversation in examples["conversations"]:
            messages = new_example()
            counter = 0
            for msg in conversation:
                converted_role = (
                    model_role if msg["contact_name"] == whatsapp_name else user_role
                )
                # Split only right before a user turn, so an example is never
                # cut in the middle of the model's messages.
                if (
                    counter > SPLIT_CONVERSATION_THRESHOLD
                    and converted_role == user_role
                ):
                    emit(messages)
                    messages = new_example()
                    counter = 0
                if converted_role == messages[-1]["role"]:
                    messages[-1]["content"].append(msg["message"])
                else:
                    messages.append(
                        {"role": converted_role, "content": [msg["message"]]}
                    )
                counter += 1
            emit(messages)

        # Return a dictionary of lists, as batched map() expects. Building it
        # from a fixed key keeps this valid when the batch yields no examples.
        # NOTE(review): if the very first batch yields no examples, Arrow type
        # inference may need an explicit `features=` on map() -- confirm.
        return {"messages": [example["messages"] for example in processed_examples]}

    processed_examples = conversations_ds.map(
        process_examples,
        remove_columns=["conversations"],
        # num_proc=os.cpu_count() - 1,
        batched=True,
    )
    examples_filtered_by_length = processed_examples.filter(
        lambda x: all(
            len(m["content"]) < MAX_CHARACTERS_PER_MESSAGE for m in x["messages"]
        )
    )
    return examples_filtered_by_length