""" This script is largely copied from the Vicuna repo: https://github.com/lm-sys/FastChat/blob/main/fastchat/data/split_long_conversation.py We fixed a bug in `split_one_sample`, which previously includes long conversations in the processed data. Now we skip these long conversations. """ import argparse from concurrent.futures import ProcessPoolExecutor import json import transformers from tqdm import tqdm def shareGPT_pipeline(tokenizer, raw_datasets, overwrite_cache): def preprocess_conversation(convo): key_mapping = {"role" : "from", "content" : "value"} value_mapping = {"user" : "user", "human" : "user", "gpt" : "assistant", 'system': 'assitant', 'bing': 'assitant', 'chatgpt': 'assitant', 'bard': 'assitant'} # mapping = {"human" : "user", "gpt" : "assitant"} if value_mapping[convo[0][key_mapping['role']]] != 'user': convo = convo[1:] preproc_convos_user = [{"role": 'user', "content": convo_elem[key_mapping['content']]} for i, convo_elem in enumerate(convo) if (i % 2 == 0 and value_mapping[convo_elem[key_mapping['role']]] == 'user')] preproc_convos_assistant = [{"role": 'assistant', "content": convo_elem[key_mapping['content']]} for i, convo_elem in enumerate(convo) if (i % 2 == 1 and value_mapping[convo_elem[key_mapping['role']]] == 'assistant')] if len(preproc_convos_user) != len(preproc_convos_assistant): return [] preproc_convos = [conv_elem for pair in zip(preproc_convos_user, preproc_convos_assistant) for conv_elem in pair] return preproc_convos def filter_incorrect_conversations(examples): convos = examples["conversations"] ids_to_remove = [True if preprocess_conversation(convo) == [] else False for convo in convos] return { "ids_to_remove" : ids_to_remove, } def formatting_prompts_func(examples): convos = examples["conversations"] # preproc_convos = [convo for convo in convos if (convo[0]['from'] == 'human' or convo[0]['from'] == 'user')] preproc_convos = [preprocess_conversation(convo) for convo in convos] # preproc_convos2 = [preproc_convo for preproc_convo in preproc_convos if preproc_convo[0]['role'] == 'user'] texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for i, convo in enumerate(preproc_convos)] return { "text" : texts,} filtered_datasets = raw_datasets.filter(lambda example: example['conversations'] != [], load_from_cache_file=not overwrite_cache,) dataset = filtered_datasets.map(filter_incorrect_conversations, batched = True, load_from_cache_file=not overwrite_cache,) filtered_datasets2 = dataset.filter(lambda example: example['ids_to_remove'] == False, load_from_cache_file=not overwrite_cache,) raw_datasets_preprocessed = filtered_datasets2.map(formatting_prompts_func, batched = True, load_from_cache_file=not overwrite_cache,) return raw_datasets_preprocessed