Gregor Betz commited on
Commit
ed93a1c
1 Parent(s): c58d788
Files changed (1) hide show
  1. backend/data.py +12 -3
backend/data.py CHANGED
@@ -18,12 +18,13 @@ def load_cot_data():
18
  ####
19
 
20
  # download raw data
 
21
  snapshot_download(
22
  repo_id=EVAL_DATASET,
23
  revision="main",
24
  local_dir=EVAL_RESULTS_PATH,
25
  repo_type="dataset",
26
- max_workers=60,
27
  token=TOKEN
28
  )
29
 
@@ -86,7 +87,7 @@ def load_cot_data():
86
  df_cot_avg["task"] = "all"
87
 
88
  # add average results to cot df
89
- df_cot = pd.concat([df_cot_avg, df_cot], ignore_index=True)
90
 
91
 
92
  ####
@@ -94,7 +95,8 @@ def load_cot_data():
94
  ####
95
 
96
  # load traces data and extract configs
97
- dataset = datasets.load_dataset(TRACES_DATASET, split="test", token=TOKEN)
 
98
  dataset = dataset.select_columns(["config_data"])
99
  df_cottraces = pd.DataFrame({"config_data": dataset["config_data"]})
100
  del dataset
@@ -126,6 +128,9 @@ def load_cot_data():
126
  for col in ['acc_base', 'acc_cot', 'acc_gain']:
127
  df_cot[col] = 100 * df_cot[col]
128
 
 
 
 
129
  ####
130
  # Create error dataframe
131
  ####
@@ -136,4 +141,8 @@ def load_cot_data():
136
  df_cot_err.reset_index(inplace=True)
137
  df_cot_err.rename(columns={"acc_base-mean": "base accuracy", "acc_cot-mean": "cot accuracy", "acc_gain-mean": "marginal acc. gain"}, inplace=True)
138
 
 
 
 
 
139
  return df_cot_err, df_cot
 
18
  ####
19
 
20
  # download raw data
21
+ print("Downloading evaluation results...")
22
  snapshot_download(
23
  repo_id=EVAL_DATASET,
24
  revision="main",
25
  local_dir=EVAL_RESULTS_PATH,
26
  repo_type="dataset",
27
+ max_workers=8,
28
  token=TOKEN
29
  )
30
 
 
87
  df_cot_avg["task"] = "all"
88
 
89
  # add average results to cot df
90
+ df_cot = pd.concat([df_cot_avg, df_cot], ignore_index=True)
91
 
92
 
93
  ####
 
95
  ####
96
 
97
  # load traces data and extract configs
98
+ print("Loading traces data...")
99
+ dataset = datasets.load_dataset(TRACES_DATASET, split="test", token=TOKEN, num_proc=8)
100
  dataset = dataset.select_columns(["config_data"])
101
  df_cottraces = pd.DataFrame({"config_data": dataset["config_data"]})
102
  del dataset
 
128
  for col in ['acc_base', 'acc_cot', 'acc_gain']:
129
  df_cot[col] = 100 * df_cot[col]
130
 
131
+ print("Regimes dataframe created:")
132
+ print(df_cot.head(3))
133
+
134
  ####
135
  # Create error dataframe
136
  ####
 
141
  df_cot_err.reset_index(inplace=True)
142
  df_cot_err.rename(columns={"acc_base-mean": "base accuracy", "acc_cot-mean": "cot accuracy", "acc_gain-mean": "marginal acc. gain"}, inplace=True)
143
 
144
+ print("Error dataframe created:")
145
+ print(df_cot_err.head(3))
146
+
147
+
148
  return df_cot_err, df_cot