"""
Split long conversations based on certain max length.
Usage: python3 -m fastchat.data.split_long_conversation \
--in sharegpt_clean.json \
--out sharegpt_split.json \
--model-name-or-path $<model-name>
"""
import argparse
from concurrent.futures import ProcessPoolExecutor
import json
import transformers
from tqdm import tqdm


def make_sample(sample, start_idx, end_idx):
    """Build a new sample from conversations[start_idx:end_idx]."""
    # Chunks must contain whole human/gpt pairs, so the slice must cover an
    # even number of turns.
    assert (end_idx - start_idx) % 2 == 0
    return {
        "id": sample["id"] + "_" + str(start_idx),
        "model": sample.get("model", ""),
        "conversations": sample["conversations"][start_idx:end_idx],
    }
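
# E.g. a chunk taken from the sample with id "abc123" starting at index 4
# gets the id "abc123_4" (values are illustrative).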

# Globals used by worker processes. split_all() sets them in the parent
# before the process pool is created, so fork-based workers inherit them.
tokenizer = max_length = None


def split_one_sample(sample):
    """Split one conversation into chunks that fit within max_length tokens."""
    tokenized_lens = []
    conversations = sample["conversations"]
    # Truncate to an even number of turns so every chunk is whole pairs.
    conversations = conversations[: len(conversations) // 2 * 2]
    for c in conversations:
        # +6 is a rough allowance for role and separator tokens added later.
        length = len(tokenizer(c["value"]).input_ids) + 6
        tokenized_lens.append(length)

    start_idx = 0
    cur_len = 0

    if len(conversations) % 2 != 0 or len(conversations) < 2:
        return []

    new_samples = []
    # Greedily pack consecutive human/gpt pairs into chunks.
    for i in range(0, len(conversations), 2):
        tmp_len = tokenized_lens[i] + tokenized_lens[i + 1]
        if cur_len + tmp_len > max_length:
            # Close the current chunk before this pair and start a new one.
            # Note: if this branch fires on the final pair, that pair is
            # dropped rather than emitted as an over-long chunk.
            new_samples.append(make_sample(sample, start_idx, i))
            start_idx = i
            cur_len = 0
        elif i == len(conversations) - 2:
            # Last pair still fits: emit the final chunk.
            new_samples.append(make_sample(sample, start_idx, i + 2))
        cur_len += tmp_len
    return new_samples
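
# Worked example (hypothetical numbers): with max_length = 10 and per-message
# token lengths [3, 3, 4, 2, 1, 2] (three human/gpt pairs of lengths 6, 6, 3),
# the loop above emits two chunks: pair 0 alone (adding pair 1 would reach 12),
# then pairs 1 and 2 together (6 + 3 = 9 <= 10).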


def worker(input_data):
    """Split every sample in one batch; runs inside a pool process."""
    result = []
    for sample in input_data:
        result.extend(split_one_sample(sample))
    return result


def split_all(content, begin, end, tokenizer_, max_length_):
    """
    Split every conversation in content[begin:end] into chunks that keep as
    many rounds as possible within the max token length constraint.
    """
    global tokenizer, max_length
    tokenizer = tokenizer_
    max_length = max_length_

    content = content[begin:end]
    new_content = []

    # Split content into batches of 1000 samples so each worker call gets a
    # sizable amount of work.
    chunks = [content[i : i + 1000] for i in range(0, len(content), 1000)]
    with ProcessPoolExecutor() as executor:
        for result in tqdm(executor.map(worker, chunks), total=len(chunks)):
            new_content.extend(result)

    return new_content
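
# Minimal direct-call sketch, bypassing the CLI (paths are placeholders):
#
#   tok = transformers.AutoTokenizer.from_pretrained(
#       "path/to/model", use_fast=False
#   )
#   with open("sharegpt_clean.json") as f:
#       data = json.load(f)
#   short = split_all(data, 0, None, tok, 2048)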


def filter_invalid_roles(content):
    """Drop empty conversations and those whose turns do not strictly
    alternate human -> gpt."""
    roles = ["human", "gpt"]
    new_content = []
    for c in content:
        if len(c["conversations"]) <= 0:
            continue

        valid = True
        for j, s in enumerate(c["conversations"]):
            if s["from"] != roles[j % 2]:
                valid = False
                break

        if valid:
            new_content.append(c)

    return new_content
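
# For example, a record whose conversation starts with a "gpt" turn, or one
# whose "conversations" list ended up empty after splitting, is dropped here.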


def main(args):
    with open(args.in_file, "r") as f:
        content = json.load(f)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.max_length,
        padding_side="right",
        use_fast=False,
    )
    new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length)
    new_content = filter_invalid_roles(new_content)

    print(f"#in: {len(content)}, #out: {len(new_content)}")
    with open(args.out_file, "w") as f:
        json.dump(new_content, f, indent=2, ensure_ascii=False)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--in-file", type=str, required=True)
parser.add_argument("--out-file", type=str, default="sharegpt_split.json")
parser.add_argument("--begin", type=int)
parser.add_argument("--end", type=int)
parser.add_argument("--model-name-or-path", type=str, required=True)
parser.add_argument("--max-length", type=int, default=2048)
args = parser.parse_args()
main(args)