omwdataset / eval_result_figures.py
hunterhector's picture
add figure to early sections too
a4dc57a
raw
history blame
No virus
2.71 kB
import os
from plotly import graph_objects as go
import pandas as pd
## Evaluation Graphs
# Load the data.
# Every "*.csv" file in data/txt360_eval holds one evaluation metric; the metric
# name is the file name with the "CKPT Eval - " prefix and ".csv" suffix removed.
# Columns 1-4 hold the scores of the four compared training datasets; column 0 is
# not read. The first two rows are skipped (presumably header/label rows in the
# exported sheet -- TODO confirm against the CSV export).
#
# Result shape: all_eval_results[metric_name][dataset_key] -> pandas Series of
# float scores, plus all_eval_results[metric_name]["token"] -> list of x positions.
all_eval_results = {}
_eval_dir = "data/txt360_eval"
# Result key -> CSV column index holding that dataset's scores (insertion order
# is kept so the per-metric dicts have the same key order as before).
_dataset_columns = {
    "fineweb": 1,
    "txt360-dedup-only": 2,
    "txt360-web-only-upsampled": 3,
    "txt360-all-upsampled + stackv2": 4,
}
# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting
# makes the metric (and therefore figure) order reproducible across machines.
for fname in sorted(os.listdir(_eval_dir)):
    if not fname.endswith(".csv"):
        continue
    metric_name = fname.replace("CKPT Eval - ", "").replace(".csv", "")
    df = pd.read_csv(os.path.join(_eval_dir, fname))
    metric_results = {}
    for dataset_key, col in _dataset_columns.items():
        # Backward-fill gaps: a missing checkpoint score takes the next available
        # one. Series.bfill() replaces fillna(method="bfill"), whose `method`
        # argument is deprecated since pandas 2.1.
        metric_results[dataset_key] = df.iloc[2:, col].astype(float).bfill()
    # Each row is one checkpoint, 20B tokens apart; x positions are in billions.
    n_rows = len(metric_results["fineweb"])
    metric_results["token"] = [20 * i for i in range(n_rows)]
    all_eval_results[metric_name] = metric_results
# Eval Result Plots
# Build one plotly line chart per metric in all_eval_results. The x-axis is the
# "token" list (billions of tokens) and there is one line per compared dataset.
# The figures are stored in all_eval_res_figs, keyed by metric name.

# (key in all_eval_results[metric], legend label), listed in trace draw order.
_eval_trace_specs = [
    ("fineweb", "FineWeb"),
    ("txt360-web-only-upsampled", "TxT360 - CC Data Upsampled"),
    ("txt360-dedup-only", "TxT360 - CC Data Dedup"),
    ("txt360-all-upsampled + stackv2", "TxT360 - Full Upsampled + Stack V2"),
]
all_eval_res_figs = {}
for metric_name, res in all_eval_results.items():
    fig_res = go.Figure()
    # Add lines
    for result_key, label in _eval_trace_specs:
        fig_res.add_trace(go.Scatter(
            x=res["token"],
            y=res[result_key],
            mode="lines",
            name=label,
        ))
    # Update layout
    fig_res.update_layout(
        title=f"{metric_name} Performance",
        title_x=0.5,  # Centers the title
        xaxis_title="Billion Tokens",
        yaxis_title=metric_name,
        legend_title="Dataset",
    )
    all_eval_res_figs[metric_name] = fig_res