hynky's picture
hynky HF staff
UI overhaul + separation of concerns
40e38d3
raw
history blame
3.6 kB
from functools import partial
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import gradio as gr
from typing import Dict, List
from .data_processing import prepare_for_non_grouped_plotting, prepare_for_group_plotting
from .utils import set_alpha
def plot_scatter(
    data: Dict[str, Dict[float, float]],
    metric_name: str,
    log_scale_x: bool,
    log_scale_y: bool,
    normalization: bool,
    rounding: int,
    cumsum: bool,
    perc: bool,
    progress: gr.Progress,
):
    """Build a line plot with one trace per named histogram.

    Args:
        data: Mapping from trace name to a histogram. Each histogram is passed
            to ``prepare_for_non_grouped_plotting`` and the result is read as
            an ``{x: y}`` mapping.
        metric_name: Used in the plot title and as the x-axis title.
        log_scale_x: Request a logarithmic x axis. It is applied only when at
            least one trace has more than one point.
        log_scale_y: Request a logarithmic y axis (same condition).
        normalization: Forwarded to ``prepare_for_non_grouped_plotting``. It
            also selects the y-axis title ("Frequency" vs "Total").
        rounding: Forwarded to ``prepare_for_non_grouped_plotting``.
        cumsum: Plot the running (cumulative) sum of the y values.
        perc: Multiply the y values by 100 (applied after ``cumsum``).
        progress: Gradio progress tracker wrapped around the trace loop.

    Returns:
        The assembled ``go.Figure``. An empty ``data`` yields a figure with
        no traces and linear axes.
    """
    fig = go.Figure()
    palette = px.colors.qualitative.Plotly
    # Largest number of points in any trace. It decides whether log axes are
    # used, independent of trace order. Starting at 0 also means an empty
    # ``data`` no longer hits an unbound ``x``/``y`` after the loop.
    max_points = 0

    # Sort by name so trace order and colour assignment are deterministic.
    data = dict(sorted(data.items()))
    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
        x = sorted(histogram_prepared.keys())
        y = [histogram_prepared[k] for k in x]
        if cumsum:
            y = np.cumsum(y).tolist()
        if perc:
            y = (np.array(y) * 100).tolist()
        max_points = max(max_points, len(x))
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                # Plotly uses a scalar marker colour as the default line colour.
                marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
            )
        )

    yaxis_title = "Frequency" if normalization else "Total"
    # A log axis is only requested when some trace has more than one point.
    use_multi_point_axes = max_points > 1
    fig.update_layout(
        title=f"Line Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title=yaxis_title,
        xaxis_type="log" if log_scale_x and use_multi_point_axes else None,
        yaxis_type="log" if log_scale_y and use_multi_point_axes else None,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig
def plot_bars(
    data: Dict[str, List[Dict[str, float]]],
    metric_name: str,
    top_k: int,
    direction: str,
    regex: str | None,
    rounding: int,
    log_scale_x: bool,
    log_scale_y: bool,
    progress: gr.Progress,
):
    """Build a bar chart with one bar trace, including error bars, per name.

    Args:
        data: Mapping from trace name to the per-name input that is passed to
            ``prepare_for_group_plotting``.
        metric_name: Used in the plot title and as the x-axis title.
        top_k, direction, regex, rounding: Forwarded unchanged to
            ``prepare_for_group_plotting``.
        log_scale_x, log_scale_y: Request logarithmic axes. Each is applied
            only when at least one trace has more than one value on that axis.
        progress: Gradio progress tracker wrapped around the trace loop.

    Returns:
        The assembled ``go.Figure``. An empty ``data`` yields a figure with
        no traces and linear axes.
    """
    fig = go.Figure()
    palette = px.colors.qualitative.Plotly
    # Longest x / y sequence seen across all traces. The log-axis guard uses
    # these, so the decision does not depend on which trace happened to be last.
    max_x_len = 0
    max_y_len = 0

    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        # NOTE(review): presumably (labels, means, standard deviations) — confirm
        # against prepare_for_group_plotting.
        x, y, stds = prepare_for_group_plotting(histogram, top_k, direction, regex, rounding)
        max_x_len = max(max_x_len, len(x))
        max_y_len = max(max_y_len, len(y))
        fig.add_trace(
            go.Bar(
                x=x,
                y=y,
                name=f"{name} Mean",
                marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
                # One error bar per bar, with lengths taken from ``stds``.
                error_y=dict(type="data", array=stds, visible=True),
            )
        )

    fig.update_layout(
        title=f"Bar Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title="Avg. value",
        xaxis_type="log" if log_scale_x and max_x_len > 1 else None,
        yaxis_type="log" if log_scale_y and max_y_len > 1 else None,
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig
def plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y,
              cumsum, perc, progress=gr.Progress()):
    """Render ``data`` with the plot builder selected by ``grouping``.

    When ``grouping`` is ``"histogram"``, line traces are drawn via
    ``plot_scatter``. Any other value draws bars via ``plot_bars``.

    Returns ``None`` without plotting when either ``rounding`` or ``top_k`` is
    ``None``. Otherwise returns the figure produced by the selected builder.
    """
    # Both numeric controls must be set before anything can be drawn.
    if any(setting is None for setting in (rounding, top_k)):
        return None

    # Keyword arguments that both plot builders accept.
    common_kwargs = dict(
        data=data,
        metric_name=metric_name,
        progress=progress,
        log_scale_x=log_scale_x,
        log_scale_y=log_scale_y,
    )

    if grouping == "histogram":
        return plot_scatter(
            normalization=normalization,
            rounding=rounding,
            cumsum=cumsum,
            perc=perc,
            **common_kwargs,
        )
    return plot_bars(
        top_k=top_k,
        direction=direction,
        regex=regex,
        rounding=rounding,
        **common_kwargs,
    )