// NOTE: the "Spaces: Running" text that preceded the imports was hosting-page residue, not part of this module.
import { getColor, COLORS } from "./colors.mjs";
import Plotly from "plotly.js-basic-dist-min";
import _ from "lodash";
/**
 * Base path under which every plot keeps its data: the plot's `index.json`
 * and the per-metric files it references are fetched from
 * `${DATA_FOLDER}/<plotName>/...`.
 */
const DATA_FOLDER = "assets/data/plots";
/**
 * Trace attributes merged into every trace when the plot type is not "bar".
 *
 * NOTE(review): `width` sits at the top level of the trace rather than under
 * `line`; confirm that Plotly actually applies it to scatter lines.
 */
const LINE_SETTINGS = {
    width: 2.5,
    type: "scatter",
    mode: "lines",
};
/**
 * Trace attributes merged into every trace when the plot type is "bar".
 * `marker.line.width` is merged into the per-trace `marker` object, so the
 * trace colour set by the caller is kept alongside it.
 */
const BAR_SETTINGS = {
    width: 0.5,
    type: "bar",
    opacity: 0.9,
    marker: {
        line: {
            width: 1.0,
        },
    },
};
/**
 * Sort priority of each metric in the metric dropdown (lower comes first).
 * Metrics missing from this table are treated as priority 0 by the sorter.
 *
 * The ablation and stats groups both start at 0; presumably the two groups
 * never appear in the same plot (TODO confirm).
 */
const METRIC_ID_TO_PRIORITY = {
    // Ablations
    "agg_score": 0,
    "hellaswag/acc_norm": 1,
    "arc/acc_norm": 2,
    "mmlu/acc_norm": 3,
    "openbookqa/acc_norm": 4,
    "commonsense_qa/acc_norm": 5,
    "piqa/acc_norm": 6,
    "siqa/acc_norm": 7,
    "winogrande/acc_norm": 8,
    // Stats
    "lines_ended_with_punct": 0,
    "lines_chars": 1,
    "short_lines": 2,
};
/**
 * Human-readable name of each metric id.
 *
 * Besides labelling the dropdown options and the y-axis, this table acts as
 * an allow-list: only metrics listed here are offered in the dropdown.
 */
const TASK_ID_TO_NAME = {
    // Ablations
    "agg_score": "Aggregate Score",
    "commonsense_qa/acc_norm": "Commonsense QA",
    "hellaswag/acc_norm": "HellaSwag",
    "openbookqa/acc_norm": "OpenBook QA",
    "piqa/acc_norm": "PIQA",
    "siqa/acc_norm": "Social IQA",
    "winogrande/acc_norm": "WinoGrande",
    "arc/acc_norm": "ARC",
    "mmlu/acc_norm": "MMLU",
    // Stats
    "lines_ended_with_punct": "Lines Ended With Punctuation",
    "lines_chars": "Lines Chars",
    "short_lines": "Short Lines",
};
/**
 * Display name of each dataset id. Used as the legend name of a trace when
 * the trace data carries no explicit `label`; ids missing from this table
 * are shown verbatim.
 *
 * The two `CC-MAIN-2013-48` entries used to be listed twice with identical
 * values; the redundant copies were removed.
 */
const DATASET_ID_TO_NAME = {
    "pii_removed": "Fineweb",
    "allenai_c4_en": "C4",
    "tiiuae_falcon-refinedweb_data": "RefinedWeb",
    "red-pajama-v2_jsonl-deduplicated-extract": "RedPajamaV2",
    "dolma-sample": "Dolma1.6",
    "dedup_minhash_independent_output": "Independent Dedup MinHash",
    "dedup_minhash_CC-MAIN-2013-48_output": "Full MinHash CC-MAIN-2013-48",
    "dedup_minhash_independent_output_CC-MAIN-2013-48": "Independent MinHash CC-MAIN-2013-48",
    "ind_minhash-CC-MAIN-2019-18": "Independent MinHash CC-MAIN-2019-18",
    "wet-extraction-2019-18": "WET Extraction 2019-18",
};
/**
 * Per-plot settings used when a plot's `index.json` does not override them.
 *
 * - `slider`: bounds and initial value of the "Rolling window" range input;
 *   an index file can set it to `null` to omit the slider.
 * - `defaultMetric`: metric selected initially (and the one plotted when no
 *   dropdown is rendered).
 * - `type`: "bar" selects the bar trace settings, anything else the line ones.
 */
const DEFAULT_SETTINGS = {
    slider: {
        max: 30,
        min: 0,
        default: 0,
    },
    defaultMetric: "agg_score",
    type: "line",
};
/**
 * Base Plotly layout shared by every plot. On each update it is deep-merged
 * with the current container width, the selected metric's y-axis title and
 * any `layout` object shipped with the metric data, so every value here can
 * be overridden per metric.
 *
 * NOTE(review): the CSS keyword for the Apple system font is `-apple-system`;
 * `apple-system` as written presumably never matches and falls through to
 * Arial. Confirm before changing, since it affects rendering.
 */
const DEFAULT_LAYOUT = {
    font: {
        family: "apple-system, Arial, sans-serif",
    },
    title: {
        text: "Plot Title",
        font: {
            size: 19,
            family: "apple-system, Arial, sans-serif",
        },
    },
    xaxis: {
        title: {
            text: "Training tokens (billions)",
            font: {
                size: 15,
                family: "apple-system, Arial, sans-serif",
            },
        },
        tickfont: {
            size: 14,
            family: "apple-system, Arial, sans-serif",
        },
        showgrid: false,
        mirror: true,
        ticks: "outside",
        showline: true,
    },
    yaxis: {
        title: {
            // Replaced with the selected metric's display name on every update.
            text: "Agg Score",
            font: {
                size: 15,
                family: "apple-system, Arial, sans-serif",
            },
            standoff: 10,
        },
        showgrid: false,
        mirror: true,
        ticks: "outside",
        showline: true,
        tickfont: {
            size: 14,
            family: "apple-system, Arial, sans-serif",
        },
    },
    // Secondary axis, used only by traces that set `yaxis: "y2"`.
    yaxis2: {
        title: {
            text: "Words Contamination",
            font: {
                size: 15,
                family: "apple-system, Arial, sans-serif",
            },
            standoff: 10,
        },
        tickfont: {
            size: 14,
            family: "apple-system, Arial, sans-serif",
        },
        showgrid: false,
        ticks: "outside",
    },
    // Legend anchored to the bottom-right corner with a transparent background.
    legend: {
        orientation: "v",
        xanchor: "right",
        yanchor: "bottom",
        x: 1,
        y: 0,
        font: {
            size: 14,
            family: "apple-system, Arial, sans-serif",
        },
        bgcolor: "rgba(0,0,0,0)",
    },
    margin: {
        t: 30,
        b: 50,
    },
    height: 400,
};
/**
 * Computes an x-axis range that spans every trace with 5% of padding on
 * each side.
 *
 * The bounds are found in a single pass instead of spreading the flattened
 * values into `Math.min`/`Math.max`, which can throw a RangeError once the
 * traces hold more values than the engine accepts as call arguments.
 *
 * @param {Array<{x?: Array<number>}>} traces - Traces whose `x` values are scanned.
 * @returns {[number, number]} `[lower, upper]` padded bounds. When no finite
 *   comparison is possible (no traces or no x values) the result is
 *   `[Infinity, -Infinity]`, exactly as before.
 */
const getAutoRange = (traces) => {
    let minX = Infinity;
    let maxX = -Infinity;
    for (const trace of traces) {
        // A trace without x values contributes nothing instead of poisoning the range.
        for (const rawValue of trace.x ?? []) {
            const value = Number(rawValue);
            if (Number.isNaN(value)) {
                continue;
            }
            if (value < minX) {
                minX = value;
            }
            if (value > maxX) {
                maxX = value;
            }
        }
    }
    // Always pad away from the data: a negative bound has to grow in
    // magnitude on the lower side and shrink on the upper side.
    const lower = minX >= 0 ? minX * 0.95 : minX * 1.05;
    const upper = maxX >= 0 ? maxX * 1.05 : maxX * 0.95;
    return [lower, upper];
};
/**
 * Returns the colour of a trace, keeping it stable across metric changes.
 *
 * A trace that already has a colour in `colorsMapping` keeps it. Otherwise
 * the palette is probed starting at `index`, skipping colours that are
 * already assigned to another trace, and the pick is recorded in the mapping.
 *
 * The previous implementation tested `colorsMapping.has(color)`, which looks
 * at the map's keys (trace names) rather than its values, so an already-used
 * colour was never detected.
 *
 * @param {string} traceName - Key identifying the trace.
 * @param {Map<string, *>} colorsMapping - Trace name to colour; mutated.
 * @param {number} index - Palette index to try first.
 * @returns {*} The colour assigned to the trace.
 */
const getColorForTrace = (traceName, colorsMapping, index) => {
    // Reuse the colour this trace was given earlier, if any.
    const reusedColor = colorsMapping.get(traceName);
    if (reusedColor) {
        return reusedColor;
    }

    const usedColors = new Set(colorsMapping.values());
    let candidateIndex = index;
    let color = getColor(candidateIndex);
    // Stop probing at the end of the palette; past that point a repeated
    // colour is accepted rather than looping forever.
    while (usedColors.has(color) && candidateIndex < COLORS.length) {
        candidateIndex += 1;
        color = getColor(candidateIndex);
    }
    colorsMapping.set(traceName, color);
    return color;
};
/**
 * Builds the DOM for one plot inside `plotElement`: the `<figure>` that hosts
 * the Plotly graph, a controls container with an optional metric dropdown and
 * an optional "Rolling window" slider, and an optional caption.
 *
 * @param {HTMLElement} plotElement - Container the elements are appended to.
 * @param {Object} indexMapping - `files` object of the plot's index; its keys
 *   are metric ids. Only ids known to `TASK_ID_TO_NAME` are offered.
 * @param {Object} settings - Plot settings; reads `defaultMetric`, `slider`
 *   (`null`/`undefined` disables the slider) and `caption`.
 * @returns {{dropdown: (HTMLSelectElement|undefined),
 *            slider: (HTMLInputElement|undefined),
 *            plot: HTMLElement,
 *            caption: (HTMLElement|undefined)}} The created elements; controls
 *   that were not created are `undefined`.
 */
const createAblationPlottingElements = (
    plotElement,
    indexMapping,
    settings
) => {
    const plot = document.createElement("figure");
    const controls = document.createElement("div");
    plot.classList.add("plotly");
    controls.classList.add("plotly_controls");
    plotElement.appendChild(plot);
    plotElement.appendChild(controls);

    // Own-property check: with `in`, inherited keys such as "constructor"
    // would be accepted as metrics.
    const metricOptions = Object.keys(indexMapping)
        .filter((metric) =>
            Object.prototype.hasOwnProperty.call(TASK_ID_TO_NAME, metric)
        )
        .sort(
            (a, b) =>
                (METRIC_ID_TO_PRIORITY[a] ?? 0) - (METRIC_ID_TO_PRIORITY[b] ?? 0)
        );

    // Dropdown (only worth showing when there is a choice to make)
    let dropdown = undefined;
    if (metricOptions.length > 1) {
        const dropdownLabel = document.createElement("label");
        dropdownLabel.textContent = "Metric:";
        dropdown = document.createElement("select");
        // Build the options as DOM nodes so ids and names coming from the
        // index file are never parsed as HTML.
        for (const metric of metricOptions) {
            const option = document.createElement("option");
            option.value = metric;
            option.textContent = TASK_ID_TO_NAME[metric];
            dropdown.appendChild(option);
        }
        // Selecting a value that is not among the options would leave the
        // select empty; in that case keep the first option selected.
        if (metricOptions.includes(settings.defaultMetric)) {
            dropdown.value = settings.defaultMetric;
        }
        const dropdownContainer = document.createElement("div");
        dropdownContainer.classList.add("plotly_input_container");
        dropdownContainer.appendChild(dropdownLabel);
        dropdownContainer.appendChild(dropdown);
        controls.appendChild(dropdownContainer);
    }

    // Slider
    let slider = undefined;
    if (settings.slider != null) {
        const sliderLabel = document.createElement("label");
        sliderLabel.textContent = "Rolling window:";
        slider = document.createElement("input");
        slider.type = "range";
        slider.min = settings.slider.min;
        slider.max = settings.slider.max;
        slider.value = settings.slider.default;
        // Live read-out of the current slider value
        const sliderValue = document.createElement("span");
        sliderValue.textContent = slider.value;
        slider.addEventListener("input", () => {
            sliderValue.textContent = slider.value;
        });
        const sliderInputContainer = document.createElement("div");
        sliderInputContainer.classList.add("plotly_slider");
        sliderInputContainer.appendChild(slider);
        sliderInputContainer.appendChild(sliderValue);
        const sliderContainer = document.createElement("div");
        sliderContainer.classList.add("plotly_input_container");
        sliderContainer.appendChild(sliderLabel);
        sliderContainer.appendChild(sliderInputContainer);
        controls.appendChild(sliderContainer);
    }

    // Caption
    let caption = undefined;
    if (settings.caption) {
        caption = document.createElement("figcaption");
        caption.classList.add("plotly_caption");
        caption.textContent = settings.caption;
        plotElement.appendChild(caption);
    }

    return { dropdown, slider, plot, caption };
};
/**
 * Smooths a series with a moving average over `windowSize` consecutive values.
 *
 * Output element `k` is the mean of `data[k] .. data[k + windowSize - 1]`, so
 * the result holds `data.length - windowSize + 1` values (none when the
 * window is longer than the data). The previous loop stopped one window
 * early, so the last data point was never used; a window of size 1 silently
 * dropped the final point.
 *
 * @param {Array<number>} data - Values to smooth.
 * @param {number} windowSize - Number of values per window. Zero, negative or
 *   NaN means "no smoothing".
 * @returns {Array<number>} The averaged series, or `data` itself (same
 *   reference) when no smoothing is requested.
 */
const rollingWindow = function (data, windowSize) {
    // `!(x > 0)` also catches NaN, which would otherwise yield an empty series.
    if (!(windowSize > 0)) {
        return data;
    }
    const rollingData = [];
    // `windowEnd` is exclusive, so it must be allowed to reach data.length
    // for the final window to be included.
    for (let windowEnd = windowSize; windowEnd <= data.length; windowEnd++) {
        const windowData = data.slice(windowEnd - windowSize, windowEnd);
        const windowSum = windowData.reduce((acc, value) => acc + value, 0);
        rollingData.push(windowSum / windowData.length);
    }
    return rollingData;
};
/**
 * Converts the `data` object of a metric file into Plotly traces.
 *
 * Each entry of `data` becomes one trace: `y` is smoothed with
 * `rollingWindow`, `x` is truncated to the smoothed length, and the bar or
 * line settings (chosen by `settings.type`) are merged on top.
 *
 * @param {Object<string, Object>|null|undefined} data - Trace id to trace
 *   data. Each value provides `x` and `y` arrays and may provide `label`,
 *   `color` and `yaxis`.
 * @param {Object} settings - Plot settings; only `type` is read.
 * @param {Map<string, *>} colorsMapping - Trace id to colour, shared between
 *   updates so a trace keeps its colour; may be mutated.
 * @param {number} sliderValue - Rolling-window size.
 * @returns {Array<Object>} One Plotly trace per entry, or `[]` without data.
 */
const createTraces = (data, settings, colorsMapping, sliderValue) => {
    if (!data) {
        return [];
    }
    // The same for every trace, so look it up once.
    const plotSettings = settings?.type === "bar" ? BAR_SETTINGS : LINE_SETTINGS;
    return Object.entries(data).map(([key, traceData], index) => {
        const y = rollingWindow(traceData.y, sliderValue);
        // Smoothing shortens y; keep only as many x values as there are y values.
        const x = traceData.x.slice(0, y.length);
        // An explicit colour in the data wins over the shared palette.
        const traceColor =
            traceData.color ?? getColorForTrace(key, colorsMapping, index);
        return _.merge(
            {},
            {
                x: x,
                y: y,
                name: traceData.label ?? DATASET_ID_TO_NAME[key] ?? key,
                marker: {
                    color: traceColor,
                },
                line: {
                    color: traceColor,
                },
                yaxis: traceData.yaxis ?? "y1",
            },
            plotSettings
        );
    });
};
/**
 * Initialises every ablation plot on the page.
 *
 * Each element whose id starts with "plot-" is set up independently: its
 * `index.json` is fetched from `${DATA_FOLDER}/<plotName>/`, the controls are
 * created, and the plot is redrawn whenever the metric dropdown changes or
 * (debounced by 500 ms) the rolling-window slider moves.
 *
 * Changes from the previous version:
 * - the window "resize" listener is registered once per plot rather than
 *   once per redraw, which leaked a listener on every update;
 * - failed or non-2xx requests are reported with `console.error` instead of
 *   surfacing as unhandled promise rejections;
 * - a response that arrives after a newer request was issued is discarded,
 *   so a slow fetch cannot overwrite a more recent selection.
 *
 * @returns {void} Plots load asynchronously; nothing is returned.
 */
export const init_ablation_plot = function () {
    // Below this viewport width the plots are not shown, so resizing is skipped.
    const MIN_RESIZE_WIDTH = 768;
    const SLIDER_DEBOUNCE_MS = 500;

    // Fetches and parses JSON, rejecting on HTTP error statuses.
    const fetchJson = async (url) => {
        const response = await fetch(url);
        if (!response.ok) {
            throw new Error(`Request for ${url} failed with status ${response.status}`);
        }
        return response.json();
    };

    // Sets up a single plot container.
    const initPlot = async (plotElement) => {
        const plotName = plotElement.id.replace("plot-", "");
        const indexData = await fetchJson(`${DATA_FOLDER}/${plotName}/index.json`);
        const settings = _.merge({}, DEFAULT_SETTINGS, indexData.settings);
        const indexMapping = indexData.files;
        const { dropdown, slider, plot } = createAblationPlottingElements(
            plotElement,
            indexMapping,
            settings
        );
        plot.id = `graph-${plotName}`;

        // Shared between updates so colours stay consistent across metrics.
        const colorsMapping = new Map();
        // Sequence number of the most recent update request.
        let latestRequestId = 0;

        const updatePlot = async () => {
            latestRequestId += 1;
            const requestId = latestRequestId;
            // `||` rather than `??`: a select with no matching option reports "".
            const metricName = dropdown?.value || settings.defaultMetric;
            const sliderValue = Number.parseInt(slider?.value ?? "0", 10) || 0;
            const metricEntry = indexMapping[metricName];
            if (!metricEntry) {
                throw new Error(
                    `Metric "${metricName}" is not listed in the index of plot "${plotName}"`
                );
            }
            const metricData = await fetchJson(
                `${DATA_FOLDER}/${plotName}/${metricEntry["file"]}`
            );
            if (requestId !== latestRequestId) {
                // A newer update was requested while this one was loading.
                return;
            }
            const traces = [
                ...(metricData.traces ?? []),
                ...createTraces(metricData.data, settings, colorsMapping, sliderValue),
            ];
            const layout = _.merge(
                {},
                DEFAULT_LAYOUT,
                {
                    width: plot.parentElement.offsetWidth,
                    yaxis: { title: { text: TASK_ID_TO_NAME[metricName] } },
                    xaxis: {
                        range: null,
                    },
                },
                metricData.layout
            );
            await Plotly.react(plot, traces, layout);
        };

        // Event handlers cannot await, so failures are reported here.
        const refreshPlot = () => {
            updatePlot().catch((error) => {
                console.error(`Failed to update plot "${plotName}"`, error);
            });
        };

        if (dropdown !== undefined) {
            dropdown.addEventListener("change", refreshPlot);
        }

        // Debounce the slider
        if (slider !== undefined) {
            let timeoutId;
            slider.addEventListener("input", () => {
                clearTimeout(timeoutId);
                timeoutId = setTimeout(refreshPlot, SLIDER_DEBOUNCE_MS);
            });
        }

        // Keep the graph as wide as its container; registered exactly once.
        window.addEventListener("resize", () => {
            if (window.innerWidth < MIN_RESIZE_WIDTH) {
                return;
            }
            Plotly.relayout(plot, {
                width: plot.parentElement.offsetWidth,
            });
        });

        // Shared plot
        await Plotly.newPlot(plot, []);
        // Initial plot
        await updatePlot();
    };

    const plotElements = document.querySelectorAll('[id^="plot-"]');
    plotElements.forEach((plotElement) => {
        // Plots are independent: one failing must not prevent the others.
        initPlot(plotElement).catch((error) => {
            console.error(`Failed to initialise plot "${plotElement.id}"`, error);
        });
    });
};