Corey Morris commited on
Commit
7ed3839
1 Parent(s): 31bed1a

Add dashed line at the appropriate scale of the largest and smallest values on the plot so that plotly still zooms in to show that

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py CHANGED
@@ -66,8 +66,41 @@ def create_plot(df, arc_column, moral_column, models=None):
66
  xaxis = dict(),
67
  yaxis = dict())
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  return fig
70
 
 
71
  st.header('Custom scatter plots')
72
  selected_x_column = st.selectbox('Select x-axis', filtered_data.columns.tolist(), index=0)
73
  selected_y_column = st.selectbox('Select y-axis', filtered_data.columns.tolist(), index=1)
@@ -78,6 +111,8 @@ if selected_x_column != selected_y_column: # Avoid creating a plot with the s
78
  else:
79
  st.write("Please select different columns for the x and y axes.")
80
 
 
 
81
  st.header('Overall evaluation comparisons')
82
 
83
  fig = create_plot(filtered_data, 'arc:challenge|25', 'hellaswag|10')
 
66
  xaxis = dict(),
67
  yaxis = dict())
68
 
69
+ # Add a dashed line at 0.25 for the moral columns
70
+ x_min = df[arc_column].min()
71
+ x_max = df[arc_column].max()
72
+
73
+ y_min = df[moral_column].min()
74
+ y_max = df[moral_column].max()
75
+
76
+ if arc_column.startswith('MMLU'):
77
+ fig.add_shape(
78
+ type='line',
79
+ x0=0.25, x1=0.25,
80
+ y0=y_min, y1=y_max,
81
+ line=dict(
82
+ color='red',
83
+ width=2,
84
+ dash='dash'
85
+ )
86
+ )
87
+
88
+ if moral_column.startswith('MMLU'):
89
+ fig.add_shape(
90
+ type='line',
91
+ x0=x_min, x1=x_max,
92
+ y0=0.25, y1=0.25,
93
+ line=dict(
94
+ color='red',
95
+ width=2,
96
+ dash='dash'
97
+ )
98
+ )
99
+
100
+
101
  return fig
102
 
103
+ # Custom scatter plots
104
  st.header('Custom scatter plots')
105
  selected_x_column = st.selectbox('Select x-axis', filtered_data.columns.tolist(), index=0)
106
  selected_y_column = st.selectbox('Select y-axis', filtered_data.columns.tolist(), index=1)
 
111
  else:
112
  st.write("Please select different columns for the x and y axes.")
113
 
114
+ # end of custom scatter plots
115
+
116
  st.header('Overall evaluation comparisons')
117
 
118
  fig = create_plot(filtered_data, 'arc:challenge|25', 'hellaswag|10')