File size: 5,274 Bytes
35378f6
 
 
 
 
69c36b6
35378f6
 
 
 
 
 
 
b345ff4
69c36b6
35378f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b345ff4
 
 
35378f6
 
 
 
4f18cc8
35378f6
 
 
 
 
69c36b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35378f6
 
 
 
69c36b6
35378f6
 
 
 
 
 
 
b345ff4
 
69c36b6
35378f6
 
 
 
 
 
69c36b6
35378f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import pandas as pd
import plotly.express as px

from src.assets.text_content import SHORT_NAMES

def plotly_plot(df:pd.DataFrame, LIST:list, ALL:list, NAMES:list, LEGEND:list, MOBILE:list ):
    '''
    Build a scatter plot of % Played vs Quality Score for a selection of models.

    Columns of ``df`` are read by position: column 0 holds the model name,
    column 2 is plotted on the x-axis (% Played) and column 3 on the y-axis
    (Quality Score).

    Args:
        df: A dummy dataframe of latest version. The caller's dataframe is not
            modified; the "Short" label column is added to a copy.
        LIST: List of models to plot
        ALL: Either [] or ["Show All Models"] - toggle view to plot all models
        NAMES: Either [] or ["Show Names"] - toggle view to show model names on plot
        LEGEND: Either [] or ["Show Legend"] - toggle view to show legend on plot
        MOBILE: Either [] or ["Mobile View"] - toggle view for smaller screens
    Returns:
        Fig: plotly figure
    '''
    list_columns = list(df.columns)
    model_col, x_col, y_col = list_columns[0], list_columns[2], list_columns[3]

    # Attach a short label to every row. Mapping row-by-row keeps each label
    # aligned with its model even when a model occupies several rows (assigning
    # the list of unique short names positionally would misalign or raise), and
    # ``assign`` returns a new frame instead of mutating the caller's ``df``.
    all_models = list(df[model_col].unique())
    short_names = label_map(all_models)
    df = df.assign(Short=df[model_col].map(short_names))

    if ALL:
        LIST = all_models
    # Keep only the selected models
    df = df[df[model_col].isin(LIST)]

    # Colour, marker symbol and hover title are all keyed on the model name.
    scatter_kwargs = dict(
        x=x_col,
        y=y_col,
        color=model_col,
        symbol=model_col,
        hover_name=model_col,
        template="plotly_white",
    )
    if NAMES:
        # Print the short label next to each point
        scatter_kwargs["text"] = "Short"
    fig = px.scatter(df, **scatter_kwargs)
    if NAMES:
        fig.update_traces(textposition='top center')

    if not LEGEND:
        fig.update_layout(showlegend=False)

    fig.update_layout(
        xaxis_title='% Played',
        yaxis_title='Quality Score',
        title='Overview of benchmark results',
        height=1000
    )

    # Both axes are fixed to 0-100 with 5 units of padding on each side
    fig.update_xaxes(range=[-5, 105])
    fig.update_yaxes(range=[-5, 105])

    if MOBILE:
        fig.update_layout(height=300)
        if LEGEND:
            # A visible legend needs more height on small screens; it is placed
            # below the plot area and the axis titles are folded into the title.
            fig.update_layout(height=450)
            fig.update_layout(legend=dict(
                yanchor="bottom",
                y=-5.52,
                xanchor="left",
                x=0.01
            ))
            fig.update_layout(
                xaxis_title="",
                yaxis_title="",
                title="% Played v/s Quality Score"
            )

    return fig


# Expected column layout of ``df`` (plotly_plot reads the columns by position):
# ['Model', 'Clemscore', 'All(Played)', 'All(Quality Score)']
def compare_plots(df: pd.DataFrame, LIST1: list, LIST2: list, ALL:list, NAMES:list, LEGEND: list, MOBILE: list):
    '''
    Plot Quality Score against % Played for the models picked in the frontend.

    The open-source and commercial selections are concatenated into a single
    model list and passed to plotly_plot together with the view toggles.

    Args:
        df: A dummy dataframe of latest version
        LIST1: The open source models selected in the frontend
        LIST2: The commercial models selected in the frontend
        ALL: Either [] or ["Show All Models"] - toggle view to plot all models
        NAMES: Either [] or ["Show Names"] - toggle view to show model names on plot
        LEGEND: Either [] or ["Show Legend"] - toggle view to show legend on plot
        MOBILE: Either [] or ["Mobile View"] - toggle view for smaller screens
    Returns:
        fig: The plot
    '''
    selected_models = LIST1 + LIST2
    return plotly_plot(df, selected_models, ALL, NAMES, LEGEND, MOBILE)
    
def shorten_model_name(full_name: str) -> str:
    '''
    Derive a compact plot label from a hyphen-separated model name.

    A name without hyphens is truncated to its first three characters.
    Otherwise the label is the first character of the name followed by every
    hyphen-separated part that contains a digit (sizes / versions), joined by
    hyphens, e.g. "llama-2-70b-chat" -> "l-2-70b". If no part contains a digit
    the label is just the first character, so no dangling hyphen is produced
    ("claude-instant" -> "c").

    Args:
        full_name: The long model name.
    Returns:
        The short label.
    '''
    parts = full_name.split('-')

    if len(parts) == 1:
        return full_name[:3]

    # Keep only the parts carrying digits (model sizes and versions)
    digit_parts = [part for part in parts if any(char.isdigit() for char in part)]

    # Joining a list avoids a trailing "-" when digit_parts is empty
    return '-'.join([full_name[0], *digit_parts])

def label_map(model_list: list) -> dict:
    '''
    Build a lookup from each long model name to the short label used in the plot.

    A hand-written short name from SHORT_NAMES (src/assets/text_content.py)
    takes precedence; any model not listed there is abbreviated automatically
    with shorten_model_name.

    Args:
        model_list: A list of long model names
    Returns:
        A dict mapping each long model name to its short name
    '''
    return {
        name: SHORT_NAMES[name] if name in SHORT_NAMES else shorten_model_name(name)
        for name in model_list
    }
    
def split_models(MODEL_LIST: list):
    '''
    Partition model names into open-source and commercial groups.

    A name is treated as commercial when it starts with "gpt-", "claude-" or
    "command"; everything else counts as open source. Each group is sorted
    case-insensitively.

    Returns:
        (open_models, comm_models): two sorted lists of model names
    '''
    commercial_prefixes = ('gpt-', 'claude-', 'command')
    open_source, commercial = [], []

    for name in MODEL_LIST:
        bucket = commercial if name.startswith(commercial_prefixes) else open_source
        bucket.append(name)

    def case_insensitive(name):
        return name.upper()

    return sorted(open_source, key=case_insensitive), sorted(commercial, key=case_insensitive)