bowdbeg committed
Commit d2b22fa
1 Parent(s): cd855de

make intermediate results optional for report

Files changed (1):
  1. matching_series.py +36 -13
matching_series.py CHANGED

@@ -136,6 +136,11 @@ class matching_series(evaluate.Metric):
         cuc_n_samples: Union[List[int], str] = "auto",
         metric: str = "mse",
         num_process: int = 1,
+        return_distance: bool = False,
+        return_matching: bool = False,
+        return_each_features: bool = False,
+        return_coverages: bool = False,
+        return_all: bool = False,
     ):
         """
         Compute the scores of the module given the predictions and references
@@ -149,6 +154,11 @@ class matching_series(evaluate.Metric):
             cuc_n_samples: number of samples to use for Coverage Under Curve calculation. If "auto", it uses the number of samples of the predictions.
         Returns:
         """
+        if return_all:
+            return_distance = True
+            return_matching = True
+            return_each_features = True
+            return_coverages = True
         predictions = np.array(predictions)
         references = np.array(references)
         if predictions.shape[1:] != references.shape[1:]:
@@ -271,16 +281,11 @@ class matching_series(evaluate.Metric):
 
         macro_cuc = statistics.mean(cuc_features)
         macro_coverages = [statistics.mean(c) for c in zip(*coverages_features)]
-
-        return {
+        out = {
             "precision_distance": precision_distance,
             "f1_distance": f1_distance,
             "recall_distance": recall_distance,
             "index_distance": index_distance,
-            "precision_distance_features": precision_distance_features,
-            "f1_distance_features": f1_distance_features,
-            "recall_distance_features": recall_distance_features,
-            "index_distance_features": index_distance_features,
             "macro_precision_distance": macro_precision_distance,
             "macro_recall_distance": macro_recall_distance,
             "macro_f1_distance": macro_f1_distance,
@@ -288,19 +293,37 @@ class matching_series(evaluate.Metric):
             "matching_precision": matching_precision,
             "matching_recall": matching_recall,
             "matching_f1": matching_f1,
-            "matching_precision_features": matching_precision_features,
-            "matching_recall_features": matching_recall_features,
-            "matching_f1_features": matching_f1_features,
             "macro_matching_precision": macro_matching_precision,
             "macro_matching_recall": macro_matching_recall,
             "macro_matching_f1": macro_matching_f1,
             "cuc": cuc,
-            "coverages": coverages,
             "macro_cuc": macro_cuc,
-            "macro_coverages": macro_coverages,
-            "cuc_features": cuc_features,
-            "coverages_features": coverages_features,
         }
+        if return_distance:
+            out["distance"] = distance
+        if return_matching:
+            out["match"] = best_match
+            out["match_inv"] = best_match_inv
+        if return_each_features:
+            if return_distance:
+                out["distance_features"] = distance_mean
+            out.update(
+                {
+                    "precision_distance_features": precision_distance_features,
+                    "f1_distance_features": f1_distance_features,
+                    "recall_distance_features": recall_distance_features,
+                    "index_distance_features": index_distance_features,
+                    "matching_precision_features": matching_precision_features,
+                    "matching_recall_features": matching_recall_features,
+                    "matching_f1_features": matching_f1_features,
+                    "cuc_features": cuc_features,
+                    "coverages_features": coverages_features,
+                }
+            )
+        if return_coverages:
+            out["coverages"] = coverages
+            out["macro_coverages"] = macro_coverages
+        return out
 
     def compute_cuc(
         self,
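
For context, a minimal usage sketch of the new flags, assuming the metric is loaded through the evaluate library. The repo id "bowdbeg/matching_series" and the input shapes are assumptions for illustration, not taken from this diff; the flag names come from the diff above. After this commit the default report contains only aggregate scores, and the return_* parameters opt back in to the intermediate results.

import evaluate
import numpy as np

# Hypothetical repo id; adjust to wherever this metric is actually published.
metric = evaluate.load("bowdbeg/matching_series")

# Assumed input layout: (num_series, seq_len, num_features); generated and
# reference series must share seq_len and num_features (shapes[1:] must match).
predictions = np.random.rand(8, 100, 3).tolist()
references = np.random.rand(10, 100, 3).tolist()

# Default report: aggregate scores only (e.g. "precision_distance", "cuc", ...).
report = metric.compute(predictions=predictions, references=references)

# Opt back in to intermediate results that are no longer returned by default.
detailed = metric.compute(
    predictions=predictions,
    references=references,
    return_each_features=True,  # per-feature scores such as "cuc_features"
    return_coverages=True,      # "coverages" and "macro_coverages"
)

# return_all=True switches on return_distance, return_matching,
# return_each_features, and return_coverages at once.
everything = metric.compute(predictions=predictions, references=references, return_all=True)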