|
import datetime |
|
import json |
|
import logging |
|
import os |
|
import re |
|
import datasets |
|
import dateutil.parser |
|
|
|
# Module-level logger. Its level is set to INFO explicitly so that the
# logger.info(...) diagnostics emitted by the processing functions in this
# module are not discarded by this logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
# Two consecutive messages stay in the same conversation group only when they
# are less than this many minutes apart (used by group_messages).
MINUTES_THRESHOLD = 180

# Minimum size for something to be kept: process_chat_file drops conversation
# groups with fewer messages than this, and
# transform_conversations_dataset_into_training_examples drops trailing
# examples with fewer role turns (system turn included) than this.
MIN_MESSAGES_THRESHOLD = 5
|
|
|
|
|
def group_messages(messages_iterable, minutes_threshold=None):
    """
    Split a stream of messages into conversation groups.

    A message joins the current group when its ``timestamp`` is less than
    ``minutes_threshold`` minutes after the timestamp of the last message of
    that group; otherwise it starts a new group. Messages are consumed in the
    order given and are never reordered.

    Args:
        messages_iterable: Any iterable (or iterator) of mappings exposing a
            numeric ``"timestamp"`` key expressed in seconds.
        minutes_threshold: Maximum gap, in minutes, between two consecutive
            messages of the same group. Defaults to the module-level
            ``MINUTES_THRESHOLD`` when omitted.

    Returns:
        A list of groups, each group being a non-empty list of messages.
        An empty input yields an empty list.
    """
    if minutes_threshold is None:
        minutes_threshold = MINUTES_THRESHOLD
    max_gap_seconds = minutes_threshold * 60

    groups = []
    current_group = []
    for message in messages_iterable:
        starts_new_group = (
            current_group
            and message["timestamp"] - current_group[-1]["timestamp"]
            >= max_gap_seconds
        )
        if starts_new_group:
            groups.append(current_group)
            current_group = [message]
        else:
            current_group.append(message)

    # Flush the trailing group; it is empty only when the input was empty.
    if current_group:
        groups.append(current_group)
    return groups
|
|
|
|
|
def printable_conversation(conversation):
    """
    Render a conversation as text, one ``<contact_name>: <message>`` line per message.

    Args:
        conversation: Iterable of mappings with ``"contact_name"`` and
            ``"message"`` keys.

    Returns:
        The rendered lines joined with newlines (an empty string for an empty
        conversation).
    """
    return "\n".join(
        f"{message['contact_name']}: {message['message']}" for message in conversation
    )
|
|
|
|
|
import contextualSpellCheck |
|
|
|
|
|
|
|
import spacy |
|
from spellchecker import SpellChecker |
|
|
|
# Shared spell checker used by spell_check_conversation; built once at import
# time with the library's default arguments.
spell = SpellChecker()

# Shared spaCy pipeline used by spell_check_conversation_spacy; loading the
# model is done once at import time rather than per conversation.
nlp = spacy.load("en_core_web_sm")
|
|
|
|
|
def spell_check_conversation(conversation):
    """
    Spell-check every message of a conversation row with the shared ``spell`` checker.

    Each message is split into words with ``spell.split_words``; every word is
    replaced by ``spell.correction(word)`` when a different, non-``None``
    correction is returned, and the words are re-joined with single spaces.
    Because the message is rebuilt from the split words, anything
    ``split_words`` does not return (and the original spacing) is not
    preserved.

    Args:
        conversation: Mapping with a ``"conversations"`` list of messages,
            each exposing a ``"message"`` string. Messages are updated in place.

    Returns:
        The same ``conversation`` object, after modification.
    """
    for message in conversation["conversations"]:
        words = spell.split_words(message["message"])
        logger.info("Words: %s", words)

        corrected_words = []
        for word in words:
            correction = spell.correction(word)
            # correction() may return None when it has no candidate; keep the
            # original word in that case.
            if correction is not None and correction != word:
                logger.info("Spell check: %s -> %s", word, correction)
                corrected_words.append(correction)
            else:
                corrected_words.append(word)

        logger.info("Corrected message: %s", corrected_words)
        message["message"] = " ".join(corrected_words)

    return conversation
|
|
|
|
|
def spell_check_conversation_spacy(conversation):
    """
    Spell-check every message of a conversation row with the contextual spaCy component.

    The ``"contextual spellchecker"`` component is added to the shared ``nlp``
    pipeline the first time this function runs. Messages for which the
    component reports ``performed_spellCheck`` are replaced by its
    ``outcome_spellCheck`` text.

    Args:
        conversation: Mapping with a ``"conversations"`` list of messages,
            each exposing a ``"message"`` string. Messages are updated in place.

    Returns:
        The same ``conversation`` object, after modification.
    """
    # Register the component only once: adding a component whose name is
    # already in the pipeline is rejected by spaCy, which made every call after
    # the first one fail.
    if "contextual spellchecker" not in nlp.pipe_names:
        nlp.add_pipe(
            "contextual spellchecker",
            config={
                "model_name": "bert-base-multilingual-uncased",
                "max_edit_dist": 2,
            },
        )

    messages = conversation["conversations"]
    texts = [msg["message"] for msg in messages]
    for message, doc in zip(messages, nlp.pipe(texts)):
        if doc._.performed_spellCheck:
            logger.info(
                "Spell checked: %s -> %s", doc.text, doc._.outcome_spellCheck
            )
            message["message"] = doc._.outcome_spellCheck

    return conversation
|
|
|
|
|
def remove_whatapp_annotations(conversation):
    """
    Removes the following annotations from the messages:
    - <This message was edited>

    Every occurrence of the annotation is deleted; surrounding whitespace is
    left untouched.

    Args:
        conversation: Mapping with a ``"conversations"`` list of messages,
            each exposing a ``"message"`` string. Messages are updated in place.

    Returns:
        The same ``conversation`` object, after modification.
    """
    for message in conversation["conversations"]:
        # The annotation is a fixed literal, so a plain substring replacement
        # is enough (no regular expression needed).
        message["message"] = message["message"].replace(
            "<This message was edited>", ""
        )
    return conversation
|
|
|
|
|
|
|
""" |
|
Sometimes people write concurrently in the same conversation. We try to detect that and reorder the messages.
For example, if we have a conversation like this:
A: Hi
A: How are you?
B: Hi
B: I'm fine, thanks
A: I'm fine too
we reorder it to:
A: Hi
B: Hi
A: How are you?
B: I'm fine, thanks
A: I'm fine too

To do it, we use a BERT model (bert-base-uncased) with the next sentence prediction head. Two consecutive messages are scored in both orders: the first message followed by the second, and the second followed by the first. If the model considers the reversed order clearly more likely to be a natural continuation than the original order, we swap the messages.
|
""" |
|
|
|
import torch |
|
from transformers import AutoModelForNextSentencePrediction, AutoTokenizer |
|
|
|
# Tokenizer and next-sentence-prediction model shared by
# swap_messages_if_needed; both are loaded once at import time.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")

# Run inference on the GPU when one is available.
if torch.cuda.is_available():
    model.cuda()
|
|
|
|
|
def swap_messages_if_needed(message1, message2):
    """
    Decide whether two consecutive messages should be swapped, and return them in order.

    The pair is returned unchanged, without calling the model, when:
    - both messages come from the same contact,
    - the second message was sent more than two minutes after the first, or
    - either message has fewer than three whitespace-separated words.

    Otherwise both orderings are scored with the shared next-sentence-prediction
    ``model``; the messages are swapped when the probability of "is the next
    sentence" for the original order is more than 0.2 below the one for the
    reversed order.

    Args:
        message1: First message, a mapping with ``"contact_name"``,
            ``"message"`` and ``"timestamp"`` (seconds) keys.
        message2: Message that currently follows ``message1``; same shape.

    Returns:
        A ``(first, second)`` tuple holding the same two message objects,
        possibly swapped.
    """
    if message1["contact_name"] == message2["contact_name"]:
        return message1, message2

    # Compare the raw timestamps: converting to naive local datetimes first
    # would give a wrong gap across daylight-saving changes.
    if message2["timestamp"] - message1["timestamp"] > 2 * 60:
        return message1, message2

    if len(message1["message"].split()) < 3 or len(message2["message"].split()) < 3:
        return message1, message2

    # Encode both orderings as one batch: row 0 is (message1, message2), row 1
    # is (message2, message1). The full encoding is forwarded so the model also
    # receives token_type_ids (which segment each token belongs to) and the
    # attention mask; truncation keeps long messages within the model limit.
    encoded = tokenizer(
        [message1["message"], message2["message"]],
        [message2["message"], message1["message"]],
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        logits = model(**encoded)[0]

    # Column 0 is the "second segment follows the first" class of the
    # next-sentence-prediction head.
    probabilities = torch.softmax(logits, dim=1)
    original_order_score = probabilities[0, 0].item()
    reversed_order_score = probabilities[1, 0].item()

    if original_order_score - reversed_order_score < -0.2:
        logger.info(
            "Swapping messages: %s <-> %s", message1["message"], message2["message"]
        )
        return message2, message1

    return message1, message2
|
|
|
|
|
def _reorder_message_list(messages):
    """Apply swap_messages_if_needed along a list of messages and return the reordered list."""
    if len(messages) <= 2:
        return messages

    # NOTE(review): the first two messages are taken as-is, so the pair
    # (messages[0], messages[1]) is never evaluated for a swap. This mirrors
    # the original behaviour - confirm it is intended.
    reordered = [messages[0], messages[1]]
    for candidate in messages[2:]:
        previous, candidate = swap_messages_if_needed(reordered[-1], candidate)
        reordered[-1] = previous
        reordered.append(candidate)
    return reordered


def swap_messages_if_needed_in_conversation(conversation):
    """
    Reorder messages that appear to have been written concurrently.

    Starting from the third message, each message is compared with the message
    currently preceding it using ``swap_messages_if_needed`` and the two are
    swapped when that function says so.

    Args:
        conversation: Either a plain list of messages, or a dataset row, i.e.
            a mapping whose ``"conversations"`` entry is that list. The row
            form is what ``Dataset.map`` passes to this function; it used to be
            measured with ``len()`` as if it were the message list, so rows
            were always returned untouched.

    Returns:
        For a list input, the reordered list (the input object itself when it
        has two messages or fewer). For a row input, a new dict with the same
        entries and the reordered list under ``"conversations"``.
    """
    from collections.abc import Mapping

    if isinstance(conversation, Mapping):
        reordered = _reorder_message_list(list(conversation["conversations"]))
        return {**conversation, "conversations": reordered}

    return _reorder_message_list(conversation)
|
|
|
|
|
# Small hand-written sample conversation (plain list of messages) in which the
# reply from "B" appears before the question from "A" that it answers.
test_conversation = [
    {"message": "Hola!", "contact_name": "A", "timestamp": 1},
    {
        "message": "Está todo bien, gracias por preguntar!",
        "contact_name": "B",
        "timestamp": 2,
    },
    {
        "message": "Hola, qué tal estás? Espero que vaya todo bien por España.",
        "contact_name": "A",
        "timestamp": 3,
    },
]
|
|
|
|
|
|
|
|
|
import os |
|
|
|
|
|
|
|
def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayfirst, message_line_format, do_reordering=False):
    """
    Process a chat file and return a dataset with the conversations.

    Steps, in order: load the file as text lines, keep the lines that look
    like the start of a message, parse them, drop ``<Media omitted>``
    messages, group messages into conversations with ``group_messages``, drop
    conversations with fewer than ``MIN_MESSAGES_THRESHOLD`` messages, strip
    WhatsApp annotations, optionally spell-check and reorder, rename every
    contact other than ``whatsapp_name`` to ``"Other"``, and finally keep only
    conversations in which more than one contact name appears.

    Args:
        file: Path of the exported chat text file.
        do_spelling_correction: When true, run ``spell_check_conversation`` on
            every conversation.
        whatsapp_name: Contact name to keep as-is; all other names become
            ``"Other"``.
        datetime_dayfirst: Forwarded as ``dayfirst`` to
            ``dateutil.parser.parse`` when parsing message dates.
        message_line_format: Regular expression matched against each kept
            line; it must define the named groups ``msg_datetime``,
            ``contact_name`` and ``message``.
        do_reordering: When true, run
            ``swap_messages_if_needed_in_conversation`` on every conversation.

    Returns:
        A dataset with a single ``"conversations"`` column; each row is a list
        of ``{"message", "contact_name", "timestamp"}`` messages.

    Raises:
        Exception: Any error raised while parsing a kept line is logged
            together with the offending line and re-raised.
    """
    exp = re.compile(message_line_format)

    # Pre-filter for lines that start a message ("<date>, <hh:mm> - <name>:").
    # NOTE(review): this pattern is fixed while message_line_format is
    # configurable; a kept line that does not match message_line_format makes
    # process_line fail - confirm the two are meant to stay in sync.
    message_start = re.compile(r"^\d{1,2}/\d{1,2}/\d{1,2},\s\d{2}:\d{2}\s-\s.+:")

    def process_line(example):
        try:
            groups = exp.match(example["text"]).groupdict()
            timestamp = dateutil.parser.parse(
                groups["msg_datetime"], dayfirst=datetime_dayfirst
            ).timestamp()
            return {
                "message": groups["message"],
                "contact_name": groups["contact_name"],
                "timestamp": timestamp,
            }
        except Exception:
            # Log the offending line, then propagate with the original traceback.
            logger.exception(example["text"])
            raise

    def rewrite_contact_name(conversation):
        for message in conversation["conversations"]:
            if message["contact_name"] != whatsapp_name:
                message["contact_name"] = "Other"
        return conversation

    lines_ds = datasets.load_dataset("text", data_files=[file])["train"]
    ds = lines_ds.filter(
        lambda x: message_start.match(x["text"]) is not None
    ).map(process_line, remove_columns=["text"])
    ds = ds.filter(lambda x: x["message"] != "<Media omitted>")

    groups = group_messages(iter(ds))
    conversations_ds = datasets.Dataset.from_dict({"conversations": groups})
    conversations_ds = conversations_ds.filter(
        lambda x: len(x["conversations"]) >= MIN_MESSAGES_THRESHOLD
    )

    # Leave one core free, but never ask for fewer than one worker:
    # os.cpu_count() may be 1 (which used to give num_proc=0) or None.
    worker_count = max(1, (os.cpu_count() or 1) - 1)
    result_ds = conversations_ds.map(
        remove_whatapp_annotations,
        num_proc=worker_count,
    )

    if do_spelling_correction:
        result_ds = result_ds.map(spell_check_conversation)

    if do_reordering:
        result_ds = result_ds.map(swap_messages_if_needed_in_conversation)

    changed_contact_name_ds = result_ds.map(rewrite_contact_name)

    # A conversation in which only one side ever speaks is not a dialogue.
    changed_contact_name_ds = changed_contact_name_ds.filter(
        lambda x: len({msg["contact_name"] for msg in x["conversations"]}) > 1
    )

    return changed_contact_name_ds
|
|
|
|
|
# Once more than this many messages of a conversation have been consumed, the
# next message attributed to the user role starts a new training example.
SPLIT_CONVERSATION_THRESHOLD = 40

# Training examples are dropped when the serialized content of any of their
# messages is this many characters long or longer.
MAX_CHARACTERS_PER_MESSAGE = 10000
|
|
|
|
|
def transform_conversations_dataset_into_training_examples(
    conversations_ds, system_prompt, user_role, model_role, whatsapp_name
):
    """
    Takes in a dataset with conversations and returns a dataset with training examples.

    The input dataset contains a single column (conversations), with each row being a list of messages with this format:
    ```
    [{'contact_name': 'Aldi', 'message': <message>, 'timestamp': <time>}, {'contact_name': 'Other', 'message': <message>, 'timestamp': <time>}, ... ]
    ```

    Each row is converted into one or more training examples. Messages sent
    by ``whatsapp_name`` get ``model_role``, all others get ``user_role``.
    Consecutive messages with the same role are merged into a single turn, and
    every example starts with a system turn holding ``system_prompt``. The
    ``content`` of each turn is a JSON-encoded list of the merged message
    strings:
    ```
    {'messages': [{'role': 'system', 'content': '["<system_prompt>"]'}, {'role': <user_role>, 'content': '["Hi", "How are you?"]'}, {'role': <model_role>, 'content': '["Fine, thanks"]'}, ...]}
    ```

    A conversation is split into a new example when more than
    ``SPLIT_CONVERSATION_THRESHOLD`` messages have been consumed and the next
    message belongs to ``user_role``. The last (or only) example of a
    conversation is discarded, with a warning, when it has fewer than
    ``MIN_MESSAGES_THRESHOLD`` turns (system turn included). Finally, examples
    in which any turn's serialized content is ``MAX_CHARACTERS_PER_MESSAGE``
    characters or longer are dropped.

    Args:
        conversations_ds: Dataset with a ``"conversations"`` column as
            described above.
        system_prompt: Text of the system turn placed first in every example.
        user_role: Role name given to messages from other contacts.
        model_role: Role name given to messages from ``whatsapp_name``.
        whatsapp_name: Contact name whose messages the model should learn to
            produce.

    Returns:
        A dataset with a single ``"messages"`` column.
    """

    def new_transcript():
        # Every example starts from a fresh system turn.
        return [{"role": "system", "content": [system_prompt]}]

    def serialize(transcript):
        # Turn contents are lists of strings; store them JSON-encoded.
        return {
            "messages": [
                {
                    "role": turn["role"],
                    "content": json.dumps(turn["content"], ensure_ascii=False),
                }
                for turn in transcript
            ]
        }

    def process_examples(examples):
        processed_examples = []
        for conversation in examples["conversations"]:
            messages = new_transcript()
            counter = 0
            for msg in conversation:
                converted_role = (
                    model_role if msg["contact_name"] == whatsapp_name else user_role
                )
                # NOTE(review): the example emitted at a split point is not
                # checked against MIN_MESSAGES_THRESHOLD, unlike the trailing
                # one below - confirm this is intended.
                if (
                    counter > SPLIT_CONVERSATION_THRESHOLD
                    and converted_role == user_role
                ):
                    processed_examples.append(serialize(messages))
                    messages = new_transcript()
                    counter = 0

                if converted_role == messages[-1]["role"]:
                    # Same speaker as the previous turn: merge into that turn.
                    messages[-1]["content"].append(msg["message"])
                else:
                    messages.append(
                        {"role": converted_role, "content": [msg["message"]]}
                    )
                counter += 1

            if len(messages) >= MIN_MESSAGES_THRESHOLD:
                processed_examples.append(serialize(messages))
            else:
                logger.warning(
                    "Discarding conversation because the length is not at least %s: %s",
                    MIN_MESSAGES_THRESHOLD,
                    messages,
                )

        # Column-oriented output for a batched map. Building it directly (rather
        # than from the keys of the first example) also covers batches in which
        # every conversation was discarded, which used to raise IndexError.
        return {"messages": [example["messages"] for example in processed_examples]}

    processed_examples = conversations_ds.map(
        process_examples,
        remove_columns=["conversations"],
        batched=True,
    )

    examples_filtered_by_length = processed_examples.filter(
        lambda x: all(
            len(m["content"]) < MAX_CHARACTERS_PER_MESSAGE for m in x["messages"]
        )
    )

    return examples_filtered_by_length
|
|