from functools import partial
from typing import Dict, List

import gradio as gr
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

from .data_processing import prepare_for_non_grouped_plotting, prepare_for_group_plotting
from .utils import set_alpha


def _trace_color(index: int, alpha: float = 0.5):
    """Return the Plotly qualitative palette colour for trace *index*, passed through ``set_alpha``.

    The palette is cycled (``index % len(palette)``) so any number of traces gets a colour.
    """
    palette = px.colors.qualitative.Plotly
    return set_alpha(palette[index % len(palette)], alpha)


def plot_scatter(
    data: Dict[str, Dict[float, float]],
    metric_name: str,
    log_scale_x: bool,
    log_scale_y: bool,
    normalization: bool,
    rounding: int,
    cumsum: bool,
    perc: bool,
    progress: gr.Progress,
):
    """Draw one line trace per histogram in *data* and return the figure.

    Args:
        data: Mapping of trace name -> histogram. Traces are drawn in sorted name order.
        metric_name: Used in the title and as the x-axis label.
        log_scale_x: Request a log x-axis (only applied if some trace has more than one point).
        log_scale_y: Request a log y-axis (only applied if some trace has more than one point).
        normalization: Forwarded to ``prepare_for_non_grouped_plotting``; also picks the
            y-axis label ("Frequency" when true, "Total" otherwise).
        rounding: Forwarded to ``prepare_for_non_grouped_plotting``.
        cumsum: Plot the cumulative sum of the y-values instead of the raw values.
        perc: Multiply the (possibly cumulative) y-values by 100.
        progress: Gradio progress tracker used to report per-trace progress.

    Returns:
        A ``plotly.graph_objects.Figure``. With empty *data* it is an empty figure.
    """
    fig = go.Figure()
    # Longest trace drawn so far. The log-axis guard considers every trace,
    # not just whichever one happened to be drawn last.
    max_points = 0
    items = sorted(data.items())
    for i, (name, histogram) in enumerate(progress.tqdm(items, total=len(items), desc="Plotting...")):
        histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
        x = sorted(histogram_prepared.keys())
        y = [histogram_prepared[k] for k in x]
        if cumsum:
            y = np.cumsum(y).tolist()
        if perc:
            y = (np.array(y) * 100).tolist()
        max_points = max(max_points, len(x))
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                marker=dict(color=_trace_color(i)),
            )
        )

    yaxis_title = "Frequency" if normalization else "Total"
    fig.update_layout(
        title=f"Line Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title=yaxis_title,
        xaxis_type="log" if log_scale_x and max_points > 1 else None,
        yaxis_type="log" if log_scale_y and max_points > 1 else None,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig


def plot_bars(
    data: Dict[str, List[Dict[str, float]]],
    metric_name: str,
    top_k: int,
    direction: str,
    regex: str | None,
    rounding: int,
    log_scale_x: bool,
    log_scale_y: bool,
    progress: gr.Progress,
):
    """Draw one bar trace (with error bars) per entry in *data* and return the figure.

    Args:
        data: Mapping of trace name -> grouped histogram, consumed by ``prepare_for_group_plotting``.
        metric_name: Used in the title and as the x-axis label.
        top_k, direction, regex, rounding: Forwarded to ``prepare_for_group_plotting``.
        log_scale_x: Request a log x-axis (only applied if some trace has more than one x value).
        log_scale_y: Request a log y-axis (only applied if some trace has more than one y value).
        progress: Gradio progress tracker used to report per-trace progress.

    Returns:
        A ``plotly.graph_objects.Figure``. With empty *data* it is an empty figure.
    """
    fig = go.Figure()
    # Longest x / y sequences across all traces, used by the log-axis guard.
    max_x = 0
    max_y = 0
    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        x, y, stds = prepare_for_group_plotting(histogram, top_k, direction, regex, rounding)
        max_x = max(max_x, len(x))
        max_y = max(max_y, len(y))
        fig.add_trace(
            go.Bar(
                x=x,
                y=y,
                name=f"{name} Mean",
                marker=dict(color=_trace_color(i)),
                # ``stds`` supplies the per-bar error-bar lengths.
                error_y=dict(type='data', array=stds, visible=True),
            )
        )

    fig.update_layout(
        title=f"Bar Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title="Avg. value",
        xaxis_type="log" if log_scale_x and max_x > 1 else None,
        yaxis_type="log" if log_scale_y and max_y > 1 else None,
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig


def plot_data(
    data,
    metric_name,
    normalization,
    rounding,
    grouping,
    top_k,
    direction,
    regex,
    log_scale_x,
    log_scale_y,
    cumsum,
    perc,
    progress=gr.Progress(),
):
    """Dispatch to the line plot or the bar plot depending on *grouping*.

    When *grouping* is ``"histogram"`` the data goes to ``plot_scatter``, which uses
    *normalization*, *rounding*, *cumsum* and *perc*. Any other value goes to
    ``plot_bars``, which uses *top_k*, *direction*, *regex* and *rounding*.

    Returns:
        The Plotly figure, or ``None`` when *rounding* or *top_k* is ``None``
        (for example, an input that is not set yet).
    """
    if rounding is None or top_k is None:
        return None

    # Bind the plot-type-specific options up front, so both plotters can be
    # called with the same shared keyword arguments below.
    if grouping == "histogram":
        graph_fc = partial(plot_scatter, normalization=normalization, rounding=rounding, cumsum=cumsum, perc=perc)
    else:
        graph_fc = partial(plot_bars, top_k=top_k, direction=direction, regex=regex, rounding=rounding)
    return graph_fc(
        data=data,
        metric_name=metric_name,
        progress=progress,
        log_scale_x=log_scale_x,
        log_scale_y=log_scale_y,
    )