"""
This script is largely copied from the Vicuna repo:
https://github.com/lm-sys/FastChat/blob/main/fastchat/data/split_long_conversation.py

We fixed a bug in `split_one_sample`, which previously included long
conversations in the processed data; these long conversations are now skipped.
"""

import argparse
import json
from concurrent.futures import ProcessPoolExecutor

import transformers
from tqdm import tqdm
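
# Illustrative shape of a ShareGPT-style record consumed by this pipeline
# (hypothetical sample; role names such as "human"/"gpt" are normalized in
# `preprocess_conversation` below):
#
#   {"conversations": [{"from": "human", "value": "Hello!"},
#                      {"from": "gpt", "value": "Hi there, how can I help?"}]}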


def shareGPT_pipeline(tokenizer, raw_datasets, overwrite_cache):

    def preprocess_conversation(convo):
        # ShareGPT stores the speaker under "from" and the message text under
        # "value"; map them onto the "role"/"content" schema expected by
        # `tokenizer.apply_chat_template`, collapsing the assistant aliases
        # ("gpt", "system", "bing", "chatgpt", "bard") into "assistant".
        key_mapping = {"role": "from", "content": "value"}
        value_mapping = {"user": "user", "human": "user", "gpt": "assistant", "system": "assistant", "bing": "assistant", "chatgpt": "assistant", "bard": "assistant"}

        # A conversation must open with a user turn; drop a leading assistant turn.
        if value_mapping[convo[0][key_mapping["role"]]] != "user":
            convo = convo[1:]

        # Keep only strictly alternating turns: user messages at even indices,
        # assistant messages at odd indices.
        preproc_convos_user = [{"role": "user", "content": convo_elem[key_mapping["content"]]} for i, convo_elem in enumerate(convo) if i % 2 == 0 and value_mapping[convo_elem[key_mapping["role"]]] == "user"]
        preproc_convos_assistant = [{"role": "assistant", "content": convo_elem[key_mapping["content"]]} for i, convo_elem in enumerate(convo) if i % 2 == 1 and value_mapping[convo_elem[key_mapping["role"]]] == "assistant"]

        # If the turns do not pair up cleanly, flag the conversation as malformed.
        if len(preproc_convos_user) != len(preproc_convos_assistant):
            return []

        # Re-interleave the turns as alternating user/assistant pairs.
        preproc_convos = [conv_elem for pair in zip(preproc_convos_user, preproc_convos_assistant) for conv_elem in pair]
        return preproc_convos
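
    # For example (hypothetical input), a conversation
    #     [{"from": "system", "value": "Be helpful."},
    #      {"from": "human", "value": "Hi"},
    #      {"from": "gpt", "value": "Hello!"}]
    # loses its leading non-user turn and comes back as
    #     [{"role": "user", "content": "Hi"},
    #      {"role": "assistant", "content": "Hello!"}].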

    def filter_incorrect_conversations(examples):
        # Mark conversations that `preprocess_conversation` rejects (i.e.
        # returns an empty list) so they can be filtered out below.
        convos = examples["conversations"]
        ids_to_remove = [preprocess_conversation(convo) == [] for convo in convos]
        return {"ids_to_remove": ids_to_remove}

    def formatting_prompts_func(examples):
        # Render each cleaned conversation into a single training string via
        # the tokenizer's chat template.
        convos = examples["conversations"]
        preproc_convos = [preprocess_conversation(convo) for convo in convos]
        texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in preproc_convos]
        return {"text": texts}

    # Drop empty conversations, flag and remove the malformed ones, then apply
    # the chat template to everything that remains.
    filtered_datasets = raw_datasets.filter(lambda example: example["conversations"] != [], load_from_cache_file=not overwrite_cache)
    dataset = filtered_datasets.map(filter_incorrect_conversations, batched=True, load_from_cache_file=not overwrite_cache)
    filtered_datasets2 = dataset.filter(lambda example: not example["ids_to_remove"], load_from_cache_file=not overwrite_cache)
    raw_datasets_preprocessed = filtered_datasets2.map(formatting_prompts_func, batched=True, load_from_cache_file=not overwrite_cache)

    return raw_datasets_preprocessed
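

# A minimal usage sketch, assuming a ShareGPT-style JSON file; the checkpoint
# name and file path below are illustrative, not part of this script:
#
#   from datasets import load_dataset
#   tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
#   raw_datasets = load_dataset("json", data_files="sharegpt.json")
#   processed = shareGPT_pipeline(tokenizer, raw_datasets, overwrite_cache=False)
#   print(processed["train"][0]["text"])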