|
""" |
|
Get stats of a dataset. |
|
|
|
Usage: python3 -m fastchat.data.get_stats --in-file sharegpt.json
|
""" |
|
|
|
import argparse |
|
from concurrent.futures import ProcessPoolExecutor |
|
import json |
|
|
|
import numpy as np |
|
from tqdm import tqdm |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
# Divisors used when printing counts in thousands (K) and millions (M).
K = 1_000.0
M = 1_000_000.0
|
|
|
|
|
def tokenize_one_sample(c):
    """Replace the text of every message in ``c["conversations"]`` with its tokens.

    Each message's ``"value"`` is passed through ``tokenizer.tokenize``, where
    ``tokenizer`` is a module-level global. The sample dict is modified in
    place and also returned.
    """
    for message in c["conversations"]:
        message["value"] = tokenizer.tokenize(message["value"])
    return c
|
|
|
|
|
def _install_worker_tokenizer(tok):
    """Bind ``tok`` to the module-global ``tokenizer`` inside a worker process."""
    global tokenizer
    tokenizer = tok


def tokenize_dataset(content):
    """Tokenize every sample of ``content`` in parallel worker processes.

    Each sample is processed by ``tokenize_one_sample``. Results are returned
    as a new list in the same order as ``content``, with a tqdm progress bar
    shown while they arrive.
    """
    import os

    # ``tokenizer`` is a module global assigned in the ``__main__`` block.
    # Under the "spawn"/"forkserver" start methods, workers re-import this
    # module without running that block, so hand the tokenizer to each worker
    # explicitly. (With non-fork start methods it is pickled once per worker.)
    with ProcessPoolExecutor(
        initializer=_install_worker_tokenizer, initargs=(tokenizer,)
    ) as executor:
        # The default chunksize of 1 pickles every sample as its own task;
        # batching amortizes that overhead on large datasets.
        n_workers = os.cpu_count() or 1
        chunksize = max(1, len(content) // (16 * n_workers))
        results = executor.map(tokenize_one_sample, content, chunksize=chunksize)
        processed = list(tqdm(results, total=len(content)))

    return processed
|
|
|
|
|
def compute_stats(content):
    """Collect length statistics from a list of conversation samples.

    ``sample["conversations"]`` is read as alternating (prompt, response)
    messages; a trailing unpaired message is not counted. Every length is
    ``len()`` of a message's ``"value"`` (token count once the dataset has
    been tokenized).

    Returns a 4-tuple of lists:
        sample_lens  -- total prompt+response length per sample
        sample_turns -- number of (prompt, response) pairs per sample
        prompt_lens  -- length of every prompt, over all samples
        res_lens     -- length of every response, over all samples
    """
    sample_lens, sample_turns, prompt_lens, res_lens = [], [], [], []

    for sample in content:
        messages = sample["conversations"]
        # zip truncates to the shorter slice, i.e. len(messages) // 2 pairs.
        pairs = list(zip(messages[0::2], messages[1::2]))
        sample_turns.append(len(pairs))

        total = 0
        for prompt_msg, reply_msg in pairs:
            p_len = len(prompt_msg["value"])
            r_len = len(reply_msg["value"])
            prompt_lens.append(p_len)
            res_lens.append(r_len)
            total += p_len + r_len
        sample_lens.append(total)

    return sample_lens, sample_turns, prompt_lens, res_lens
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--in-file",
        type=str,
        required=True,
        help="Path to the input JSON dataset.",
    )
    parser.add_argument(
        "--model-name-or-path",
        type=str,
        default="meta-llama/Llama-2-7b-chat-hf",
        help="Tokenizer used to count tokens.",
    )
    args = parser.parse_args()

    # JSON text is UTF-8; the context manager closes the file once parsed.
    with open(args.in_file, "r", encoding="utf-8") as f:
        content = json.load(f)

    # Assigned at module scope on purpose: tokenize_one_sample() reads the
    # global name `tokenizer`.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False)
    content = tokenize_dataset(content)

    sample_lens, sample_turns, prompt_lens, res_lens = compute_stats(content)
    print(f"#sequence: {len(content)/K:.2f} K")
    print(f"#tokens: {np.sum(sample_lens)/M:.2f} M")
    print(f"avg. turns: {np.mean(sample_turns):.2f}")
    print(f"avg. prompt length: {np.mean(prompt_lens):.2f}")
    print(f"avg. response length: {np.mean(res_lens):.2f}")

    print("\n- Histogram -")
    bin_edges = [0, 1024, 2048, 4096, 8192, 16384, 32768]
    hist = np.histogram(sample_lens, bins=bin_edges)[0]
    for lo, hi, count in zip(bin_edges[:-1], bin_edges[1:], hist):
        print(f"L{lo} - {hi}: {count}")
    # np.histogram ignores values beyond the last edge; report those samples
    # too so the bucket counts add up to the number of sequences.
    n_longer = int(np.sum(np.asarray(sample_lens) > bin_edges[-1]))
    print(f"L{bin_edges[-1]} - inf: {n_longer}")
|
|