Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -109,7 +109,9 @@ def filter_leaderboard(df, model_name, sort_by):
|
|
109 |
return filtered_df
|
110 |
|
111 |
|
112 |
-
def create_scatter_plot(df, x_axis, y_axis):
|
|
|
|
|
113 |
fig = px.scatter(
|
114 |
df,
|
115 |
x=x_axis,
|
@@ -120,10 +122,27 @@ def create_scatter_plot(df, x_axis, y_axis):
|
|
120 |
trendline='ols',
|
121 |
trendline_options=dict(log_x=True, log_y=True),
|
122 |
color='highlighted',
|
123 |
-
color_discrete_map={True:
|
124 |
title=f'{y_axis} vs {x_axis}'
|
125 |
)
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
return fig
|
129 |
|
@@ -160,7 +179,7 @@ def update_leaderboard_and_plot(
|
|
160 |
else:
|
161 |
combined_df = filtered_df
|
162 |
|
163 |
-
fig = create_scatter_plot(combined_df, x_axis, y_axis)
|
164 |
display_df = combined_df.drop(columns=['highlighted'])
|
165 |
display_df = display_df.style.apply(lambda x: ['background-color: #FFA500' if combined_df.loc[x.name, 'highlighted'] else '' for _ in x], axis=1).format(precision=2)
|
166 |
return display_df, fig
|
|
|
109 |
return filtered_df
|
110 |
|
111 |
|
112 |
+
def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter):
|
113 |
+
selected_color = 'orange'
|
114 |
+
|
115 |
fig = px.scatter(
|
116 |
df,
|
117 |
x=x_axis,
|
|
|
122 |
trendline='ols',
|
123 |
trendline_options=dict(log_x=True, log_y=True),
|
124 |
color='highlighted',
|
125 |
+
color_discrete_map={True: selected_color, False: 'blue'},
|
126 |
title=f'{y_axis} vs {x_axis}'
|
127 |
)
|
128 |
+
|
129 |
+
# Create legend labels
|
130 |
+
legend_labels = {}
|
131 |
+
if highlight_filter:
|
132 |
+
legend_labels[True] = f'{highlight_filter}'
|
133 |
+
legend_labels[False] = f'{model_filter or "all models"}'
|
134 |
+
else:
|
135 |
+
legend_labels[False] = f'{model_filter or "all models"}'
|
136 |
+
|
137 |
+
# Update legend
|
138 |
+
for trace in fig.data:
|
139 |
+
if isinstance(trace.marker.color, str): # This is for the scatter traces
|
140 |
+
trace.name = legend_labels.get(trace.marker.color == selected_color, '')
|
141 |
+
|
142 |
+
fig.update_layout(
|
143 |
+
showlegend=True,
|
144 |
+
legend_title_text='Model Selection'
|
145 |
+
)
|
146 |
|
147 |
return fig
|
148 |
|
|
|
179 |
else:
|
180 |
combined_df = filtered_df
|
181 |
|
182 |
+
fig = create_scatter_plot(combined_df, x_axis, y_axis, model_name, highlight_name)
|
183 |
display_df = combined_df.drop(columns=['highlighted'])
|
184 |
display_df = display_df.style.apply(lambda x: ['background-color: #FFA500' if combined_df.loc[x.name, 'highlighted'] else '' for _ in x], axis=1).format(precision=2)
|
185 |
return display_df, fig
|