rwightman HF staff commited on
Commit
eaadf1d
1 Parent(s): 0bcf830

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -4
app.py CHANGED
@@ -109,7 +109,9 @@ def filter_leaderboard(df, model_name, sort_by):
109
  return filtered_df
110
 
111
 
112
- def create_scatter_plot(df, x_axis, y_axis):
 
 
113
  fig = px.scatter(
114
  df,
115
  x=x_axis,
@@ -120,10 +122,27 @@ def create_scatter_plot(df, x_axis, y_axis):
120
  trendline='ols',
121
  trendline_options=dict(log_x=True, log_y=True),
122
  color='highlighted',
123
- color_discrete_map={True: 'orange', False: 'blue'},
124
  title=f'{y_axis} vs {x_axis}'
125
  )
126
- fig.update_layout(showlegend=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  return fig
129
 
@@ -160,7 +179,7 @@ def update_leaderboard_and_plot(
160
  else:
161
  combined_df = filtered_df
162
 
163
- fig = create_scatter_plot(combined_df, x_axis, y_axis)
164
  display_df = combined_df.drop(columns=['highlighted'])
165
  display_df = display_df.style.apply(lambda x: ['background-color: #FFA500' if combined_df.loc[x.name, 'highlighted'] else '' for _ in x], axis=1).format(precision=2)
166
  return display_df, fig
 
109
  return filtered_df
110
 
111
 
112
+ def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter):
113
+ selected_color = 'orange'
114
+
115
  fig = px.scatter(
116
  df,
117
  x=x_axis,
 
122
  trendline='ols',
123
  trendline_options=dict(log_x=True, log_y=True),
124
  color='highlighted',
125
+ color_discrete_map={True: selected_color, False: 'blue'},
126
  title=f'{y_axis} vs {x_axis}'
127
  )
128
+
129
+ # Create legend labels
130
+ legend_labels = {}
131
+ if highlight_filter:
132
+ legend_labels[True] = f'{highlight_filter}'
133
+ legend_labels[False] = f'{model_filter or "all models"}'
134
+ else:
135
+ legend_labels[False] = f'{model_filter or "all models"}'
136
+
137
+ # Update legend
138
+ for trace in fig.data:
139
+ if isinstance(trace.marker.color, str): # This is for the scatter traces
140
+ trace.name = legend_labels.get(trace.marker.color == selected_color, '')
141
+
142
+ fig.update_layout(
143
+ showlegend=True,
144
+ legend_title_text='Model Selection'
145
+ )
146
 
147
  return fig
148
 
 
179
  else:
180
  combined_df = filtered_df
181
 
182
+ fig = create_scatter_plot(combined_df, x_axis, y_axis, model_name, highlight_name)
183
  display_df = combined_df.drop(columns=['highlighted'])
184
  display_df = display_df.style.apply(lambda x: ['background-color: #FFA500' if combined_df.loc[x.name, 'highlighted'] else '' for _ in x], axis=1).format(precision=2)
185
  return display_df, fig