File size: 3,500 Bytes
17ab7e0
 
 
 
 
 
efcf813
17ab7e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import pandas as pd
import plotly.express as px
import gradio as gr
import urllib.parse
import plotly.graph_objects as go
import numpy as np



def read_google_sheet(sheet_id, sheet_name):
    """Download one worksheet (tab) of a Google Sheet as a DataFrame.

    The data is fetched from the spreadsheet's ``gviz`` CSV export endpoint.
    Presumably the sheet must be publicly readable, since no credentials are
    sent — TODO confirm.

    Args:
        sheet_id: Spreadsheet ID, as it appears in the sheet's URL.
        sheet_name: Worksheet name; it is percent-encoded with
            ``urllib.parse.quote`` before going into the query string.

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` if reading/parsing fails.
        The error is printed rather than raised.
    """
    csv_url = (
        f"https://docs.google.com/spreadsheets/d/{sheet_id}"
        f"/gviz/tq?tqx=out:csv&sheet={urllib.parse.quote(sheet_name)}"
    )
    try:
        return pd.read_csv(csv_url)
    except Exception as err:  # best-effort: callers receive None on any failure
        print(f"An error occurred: {err}")
        return None

# Function to generate tick values and labels
def log2_ticks(values):
    """Build integer tick positions and power-of-two labels for a log2 axis.

    Args:
        values: log2-transformed data (a pandas Series or array-like).
            Non-finite entries (NaN, +/-inf such as ``log2(0)``) are ignored.

    Returns:
        ``(tick_vals, tick_text)``:
        - ``tick_vals`` holds every integer from ``floor(min)`` to
          ``ceil(max)`` inclusive, as floats.
        - ``tick_text`` holds the matching labels ``2**tick`` formatted with
          no decimals.
        If there are no finite values, it returns an empty array and an empty
        list instead of raising.
    """
    arr = np.asarray(values, dtype=float)
    # Drop NaN/inf so the range is computed from real data only;
    # np.arange cannot handle non-finite bounds.
    finite = arr[np.isfinite(arr)]
    if finite.size == 0:
        return np.array([]), []
    min_val, max_val = np.floor(finite.min()), np.ceil(finite.max())
    tick_vals = np.arange(min_val, max_val + 1)
    tick_text = [f"{2**val:.0f}" for val in tick_vals]
    return tick_vals, tick_text

# Load the benchmark results: one worksheet per cluster size ("1 node", "8 nodes").
sheet_id = "1g07tdGf9ocOZ8XZgLGepI5Q4u6ZH961J0T9O9P64rYw"
sheet_names = [f"{count} {'node' if count == 1 else 'nodes'}" for count in [1, 8]]

df = pd.concat(
    [read_google_sheet(sheet_id, name) for name in sheet_names]
).rename(
    columns={"micro_batch_size": "mbs", "batch_accumulation_per_replica": "gradacc"}
)
# A -1 in "tok/s/gpu" is mapped to 0; presumably -1 marks a run that failed.
# NOTE(review): confirm against the sheet's conventions.
df["tok/s/gpu"] = df["tok/s/gpu"].replace(-1, 0)
# Aggregate tokens/s for the whole job; the factor 8 presumably means
# 8 GPUs per node — TODO confirm.
df["throughput"] = df["tok/s/gpu"] * df["nnodes"] * 8



def get_figure(nodes, hide_nans):
    """Build the parallel-coordinates throughput plot for one cluster size.

    Args:
        nodes: Node count to plot. Rows of the module-level ``df`` whose
            ``nnodes`` column equals this value are kept.
        hide_nans: If truthy, every row with a NaN in *any* column is dropped
            before plotting.

    Returns:
        A ``plotly.graph_objects.Figure`` containing one ``Parcoords`` trace.
        It has a log2-scaled axis for each of ``dp``, ``tp``, ``pp``, ``mbs``
        and ``gradacc``, with ticks labelled by the un-logged powers of two.
        These are followed by a linear ``throughput`` axis. Lines are
        coloured by throughput.
    """
    # Keep only the runs for the requested number of nodes.
    df_tmp = df[df["nnodes"] == nodes].reset_index(drop=True)

    if hide_nans:
        df_tmp = df_tmp.dropna()

    # The parallelism/batching settings are shown on a log2 scale. The log
    # values live in local Series instead of new columns: after ``dropna``
    # pandas tracks ``df_tmp`` as a possible slice of another frame, so adding
    # columns to it raises SettingWithCopyWarning.
    log_columns = ['dp', 'tp', 'pp', 'mbs', 'gradacc']
    dimensions = []
    for col in log_columns:
        log_values = np.log2(df_tmp[col])
        tick_vals, tick_text = log2_ticks(log_values)
        dimensions.append(
            dict(range=[log_values.min(), log_values.max()],
                 label=col,
                 values=log_values,
                 tickvals=tick_vals,
                 ticktext=tick_text)
        )

    # Throughput is plotted on its natural (linear) scale.
    throughput = df_tmp['throughput']
    dimensions.append(
        dict(range=[throughput.min(), throughput.max()],
             label='throughput',
             values=throughput)
    )

    fig = go.Figure(data=
        go.Parcoords(
            line=dict(color=throughput,
                      colorscale='GnBu',
                      showscale=True,
                      cmin=throughput.min(),
                      cmax=throughput.max()),
            dimensions=dimensions,
        )
    )

    fig.update_layout(
        title="3D parallel setup throughput ",
        plot_bgcolor='white',
        paper_bgcolor='white',
    )
    return fig


with gr.Blocks() as demo:
    title = gr.Markdown("# 3D parallel benchmark")
    with gr.Row():
        nnodes = gr.Dropdown(label="Number of nodes", choices=[1, 8], value=8)
        hide_nan = gr.Dropdown(label="Hide NaNs", choices=[False, True], value=False)

    plot = gr.Plot()

    # Draw the figure once when the page loads, then redraw it whenever either
    # control changes. Both events use the same inputs and output.
    figure_inputs = [nnodes, hide_nan]
    demo.load(get_figure, figure_inputs, [plot])
    for control in (nnodes, hide_nan):
        control.change(get_figure, figure_inputs, [plot])

demo.launch(show_api=False)