|
from functools import partial |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
import numpy as np |
|
import gradio as gr |
|
from typing import Dict, List |
|
from .data_processing import prepare_for_non_grouped_plotting, prepare_for_group_plotting |
|
from .utils import set_alpha |
|
|
|
def plot_scatter(
    data: Dict[str, Dict[float, float]],
    metric_name: str,
    log_scale_x: bool,
    log_scale_y: bool,
    normalization: bool,
    rounding: int,
    cumsum: bool,
    perc: bool,
    progress: gr.Progress,
):
    """Draw one line trace per entry of ``data`` and return the Plotly figure.

    Traces are drawn in order of their name. Each trace's histogram is first
    passed through ``prepare_for_non_grouped_plotting`` and then plotted with
    its keys (sorted ascending) on the x axis.

    Args:
        data: Mapping of trace name -> histogram mapping (x value -> y value).
        metric_name: Used in the figure title and as the x-axis title.
        log_scale_x: Request a log x axis. Honoured only if some trace has
            more than one point.
        log_scale_y: Request a log y axis. Same condition as ``log_scale_x``.
        normalization: Forwarded to ``prepare_for_non_grouped_plotting``.
            Also selects the y-axis title ("Frequency" vs "Total").
        rounding: Forwarded to ``prepare_for_non_grouped_plotting``.
        cumsum: Plot the running sum of the y values instead of raw values.
        perc: Multiply the y values by 100, presumably to show percentages.
        progress: Gradio progress tracker used to wrap the trace loop.

    Returns:
        A ``go.Figure`` with one ``go.Scatter`` (lines) trace per entry.
        The figure has no traces when ``data`` is empty.
    """
    fig = go.Figure()
    palette = px.colors.qualitative.Plotly
    # Largest number of points seen in any trace. The log-axis guard below
    # uses it, so the decision no longer depends on whichever trace was drawn
    # last and no longer fails on empty ``data``.
    max_points = 0

    ordered = sorted(data.items())
    for i, (name, histogram) in enumerate(progress.tqdm(ordered, total=len(data), desc="Plotting...")):
        histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
        x = sorted(histogram_prepared.keys())
        y = [histogram_prepared[k] for k in x]
        if cumsum:
            y = np.cumsum(y).tolist()
        if perc:
            y = (np.array(y) * 100).tolist()

        max_points = max(max_points, len(x))

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
            )
        )

    yaxis_title = "Frequency" if normalization else "Total"
    # Only switch to a log axis when at least one trace has more than one point.
    allow_log = max_points > 1

    fig.update_layout(
        title=f"Line Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title=yaxis_title,
        xaxis_type="log" if log_scale_x and allow_log else None,
        yaxis_type="log" if log_scale_y and allow_log else None,
        width=1200,
        height=600,
        showlegend=True,
    )

    return fig
|
|
|
def plot_bars(
    data: Dict[str, List[Dict[str, float]]],
    metric_name: str,
    top_k: int,
    direction: str,
    regex: str | None,
    rounding: int,
    log_scale_x: bool,
    log_scale_y: bool,
    progress: gr.Progress,
):
    """Draw one bar trace (with error bars) per entry of ``data``.

    Traces are drawn in the iteration order of ``data``. Each entry is turned
    into ``(x, y, stds)`` by ``prepare_for_group_plotting``. ``stds`` is drawn
    as symmetric y error bars.

    Args:
        data: Mapping of trace name -> grouped values, as accepted by
            ``prepare_for_group_plotting``.
        metric_name: Used in the figure title and as the x-axis title.
        top_k: Forwarded to ``prepare_for_group_plotting``.
        direction: Forwarded to ``prepare_for_group_plotting``.
        regex: Forwarded to ``prepare_for_group_plotting``. May be ``None``.
        rounding: Forwarded to ``prepare_for_group_plotting``.
        log_scale_x: Request a log x axis. Honoured only if some trace has
            more than one x value.
        log_scale_y: Request a log y axis. Honoured only if some trace has
            more than one y value.
        progress: Gradio progress tracker used to wrap the trace loop.

    Returns:
        A ``go.Figure`` with one ``go.Bar`` trace per entry, each named
        ``"<name> Mean"``.
    """
    fig = go.Figure()
    palette = px.colors.qualitative.Plotly
    # Largest x / y lengths across all traces. Previously only the last
    # trace's ``x``/``y`` decided whether a log axis was allowed.
    max_x = 0
    max_y = 0

    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        x, y, stds = prepare_for_group_plotting(histogram, top_k, direction, regex, rounding)
        max_x = max(max_x, len(x))
        max_y = max(max_y, len(y))

        fig.add_trace(go.Bar(
            x=x,
            y=y,
            name=f"{name} Mean",
            marker=dict(color=set_alpha(palette[i % len(palette)], 0.5)),
            error_y=dict(type='data', array=stds, visible=True)
        ))

    fig.update_layout(
        title=f"Bar Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title="Avg. value",
        xaxis_type="log" if log_scale_x and max_x > 1 else None,
        yaxis_type="log" if log_scale_y and max_y > 1 else None,
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )

    return fig
|
|
|
def plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y,
              cumsum, perc, progress=gr.Progress()):
    """Route the request to the line-plot or the bar-plot renderer.

    ``grouping == "histogram"`` selects ``plot_scatter``. Every other value
    selects ``plot_bars``. The ``gr.Progress()`` default follows Gradio's
    progress-tracking convention and is handed through to the renderer.

    Returns:
        The figure produced by the chosen renderer. Returns ``None`` without
        plotting when ``rounding`` or ``top_k`` is ``None``. ``top_k`` is
        required even on the histogram path.
    """
    if any(value is None for value in (rounding, top_k)):
        return None

    # Arguments both renderers accept.
    shared_kwargs = dict(
        data=data,
        metric_name=metric_name,
        progress=progress,
        log_scale_x=log_scale_x,
        log_scale_y=log_scale_y,
    )

    if grouping == "histogram":
        renderer = partial(plot_scatter, normalization=normalization, rounding=rounding, cumsum=cumsum, perc=perc)
    else:
        renderer = partial(plot_bars, top_k=top_k, direction=direction, regex=regex, rounding=rounding)

    return renderer(**shared_kwargs)