File size: 21,758 Bytes
a03c9b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 |
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""metrics.py"""
from typing import List, Any, Dict, Optional, Tuple, Union
import numpy as np
import copy
from torch.nn import Module
from utils.note_event_dataclasses import NoteEvent, Note
from utils.note2event import sort_notes, notes2pc_notes
from utils.event2note import note_event2note
from sklearn.metrics import average_precision_score
from utils.metrics_helper import (f1_measure, round_float, mir_eval_note_f1, mir_eval_frame_f1, mir_eval_melody_metric,
extract_pitches_intervals_from_notes, extract_frame_time_freq_from_notes)
from torchmetrics import MeanMetric, SumMetric
class UpdatedMeanMetric(MeanMetric):
    """
    A wrapper of torchmetrics.MeanMetric that records whether it has been updated.

    A boolean flag is raised by every call to `update` and lowered again by
    `reset`, so that callers can tell a metric that has accumulated values
    since the last reset from one that has not (see `is_updated`).
    """

    def __init__(self, nan_strategy: str = 'ignore', **kwargs) -> None:
        """
        Args:
            nan_strategy: forwarded to `MeanMetric` (defaults to 'ignore' here).
            **kwargs: forwarded unchanged to `MeanMetric`.
        """
        super().__init__(nan_strategy=nan_strategy, **kwargs)
        self._updated = False

    def update(self, *args, **kwargs):
        """Forward all arguments to `MeanMetric.update` and mark the metric as updated."""
        super().update(*args, **kwargs)
        self._updated = True

    def reset(self) -> None:
        """Reset the accumulated state and clear the updated flag.

        Without this override a direct `reset()` would leave `is_updated()`
        returning True for a metric that holds no values any more.
        """
        super().reset()
        self._updated = False

    def is_updated(self) -> bool:
        """Return True if `update` has been called since construction or the last reset."""
        return self._updated
class UpdatedSumMetric(SumMetric):
    """
    A wrapper of torchmetrics.SumMetric that records whether it has been updated.

    A boolean flag is raised by every call to `update` and lowered again by
    `reset`, so that callers can tell a metric that has accumulated values
    since the last reset from one that has not (see `is_updated`).
    """

    def __init__(self, nan_strategy: str = 'ignore', **kwargs) -> None:
        """
        Args:
            nan_strategy: forwarded to `SumMetric` (defaults to 'ignore' here).
            **kwargs: forwarded unchanged to `SumMetric`.
        """
        super().__init__(nan_strategy=nan_strategy, **kwargs)
        self._updated = False

    def update(self, *args, **kwargs):
        """Forward all arguments to `SumMetric.update` and mark the metric as updated."""
        super().update(*args, **kwargs)
        self._updated = True

    def reset(self) -> None:
        """Reset the accumulated state and clear the updated flag.

        Without this override a direct `reset()` would leave `is_updated()`
        returning True for a metric that holds no values any more.
        """
        super().reset()
        self._updated = False

    def is_updated(self) -> bool:
        """Return True if `update` has been called since construction or the last reset."""
        return self._updated
class AMTMetrics(Module):
    """
    Automatic music transcription (AMT) evaluation metrics for music transcription
    tasks with DDP support, following the convention of AMT. The average of file-wise
    metrics is calculated.

    Metrics:
    --------

    Instrument-agnostic note onset and note offset metrics:
    (Drum notes are generally excluded)
    - onset_f: the most conventional, often called Note F1
    - offset_f: a pair of onset + offset matching metric

    Multi-instrument note on-offset Macro-micro F1 metric, multi-F1 (of MT3):
    - multi_f: counts for onset + offset + program (instrument class) matching.
               For drum notes, we only count onset. macro-micro means that we
               calculate weighted precision and recall by counting each note
               instrument class per file, and calculate micro F1. We then
               calculate average F1 for all files with equal weights (Macro).

    Instrument-group note onset and offset metrics are defined by extra_classes:
    e.g. extra_classes = ['piano', 'guitar']
    - onset_f_piano: piano instrument
    - onset_f_guitar: guitar instrument
    - offset_f_piano: piano instrument
    - offset_f_guitar: guitar instrument
    also p, r metrics follow...

    Usage:
    ------

    Each metric instance can be individually updated and reset for computation.

    ```
    my_metric = AMTMetrics()
    my_metric.onset_f.update(0.5)
    my_metric.onset_f(0.5)            # same
    my_metric.onset_f(0, weight=1.0)  # same and weighted by 1.0 (default)
    my_metric.onset_f.compute()       # return 0.333..
    my_metric.onset_f.reset()         # reset the metric
    ```
    • {attribute}.update(value: float, weight: Optional[float]): Here weight is an
      optional argument for weighted average.
    • {attribute}.(...): Same as update method.
    • {attribute}.compute(): Return the average value of the metric.
    • {attribute}.reset(): Reset the metric.

    Class methods:
    ---------------

    ```
    d = {'onset_f': 0.5, 'offset_f': 0.5}
    my_metric.bulk_update(d)
    d = {'onset_f': {'value': 0.5, 'weight': 1.0}, 'offset_f': {'value': 0.5, 'weight': 1.0}}
    my_metric.bulk_update(d)
    ```
    • bulk_update(metrics: Dict[str, Union[float, Dict[str, float]]]): Update metrics with a
      dictionary as an argument.
    • bulk_update_errors(errors: Dict[str, Union[int, float]]): Update error counters with a
      dictionary as an argument.
    • bulk_compute(): Return a dictionary of any non-empty metrics with average values.
    • bulk_reset(): Reset all metrics.
    """

    def __init__(self,
                 prefix: str = '',
                 nan_strategy: str = 'ignore',
                 extra_classes: Optional[List[str]] = None,
                 extra_metrics: Optional[List[str]] = None,
                 error_types: Optional[List[str]] = None,
                 **kwargs) -> None:
        """
        Args:
            prefix: string prepended verbatim to every key returned by `bulk_compute`,
                e.g. 'val_' or 'test_'. No separator is inserted, so include one
                in the prefix if it is wanted.
            nan_strategy: 'warn' or 'raise' or 'ignore'; passed to every metric created here.
            extra_classes: instrument (group) names. For each name, the metrics
                onset_f_{name}, offset_f_{name}, micro_onset_{p,r}_{name} and
                micro_offset_{p,r}_{name} are created.
            extra_metrics: additional metric names, each created as-is.
            error_types: names of error counters, each created as-is.
            **kwargs: forwarded to `torch.nn.Module.__init__`.

        Raises:
            ValueError: if an entry of extra_classes is already an attribute of this module.
        """
        super().__init__(**kwargs)
        self._prefix = prefix
        self.nan_strategy = nan_strategy

        # Instrument-agnostic Note onsets and Note on-offset metrics for non-drum notes
        self.onset_f = UpdatedMeanMetric(nan_strategy=nan_strategy)
        self.offset_f = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Instrument-agnostic Frame F1 (skip in validation)
        self.frame_f = UpdatedMeanMetric(nan_strategy=nan_strategy)
        self.frame_f_pc = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Drum Onset metrics
        self.onset_f_drum = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Multi F1 (Macro-micro F1 of MT3)
        self.multi_f = UpdatedMeanMetric(nan_strategy=nan_strategy)

        self.extra_classes = extra_classes
        class_names = extra_classes if extra_classes is not None else []

        # The two passes are kept separate (all macro F first, then all micro P/R)
        # so that the sub-module registration order, and therefore the key order
        # of `bulk_compute`, stays stable.
        # Extra metrics for instrument macro F1
        self._add_class_metrics(class_names, ('onset', 'offset'), ('f',), nan_strategy)
        # Extra micro P,R metrics for instruments (F is computed later in bulk_compute)
        self._add_class_metrics(class_names, ('micro_onset', 'micro_offset'), ('p', 'r'), nan_strategy)

        # Initialize drum micro P,R (F is computed later)
        self.micro_onset_p_drum = UpdatedMeanMetric(nan_strategy=nan_strategy)
        self.micro_onset_r_drum = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Initialize extra metrics directly
        if extra_metrics is not None:
            for metric_name in extra_metrics:
                setattr(self, metric_name, UpdatedMeanMetric(nan_strategy=nan_strategy))

        # Initialize error counters
        self.error_types = error_types
        if error_types is not None:
            for error_type in error_types:
                setattr(self, error_type, UpdatedMeanMetric(nan_strategy=nan_strategy))

    def _add_class_metrics(self, class_names, stems, kinds, nan_strategy: str) -> None:
        """Create one metric named '{stem}_{kind}_{class_name}' per combination.

        Raises:
            ValueError: if `class_name` itself is already an attribute of this module.
        """
        for class_name in class_names:
            # NOTE: the guard tests the bare class name, not the generated attribute names.
            if hasattr(self, class_name):
                raise ValueError(f"Metric '{class_name}' already exists.")
            for stem in stems:
                for kind in kinds:
                    setattr(self, stem + '_' + kind + '_' + class_name,
                            UpdatedMeanMetric(nan_strategy=nan_strategy))

    def bulk_update(self, metrics: Dict[str, Union[float, Dict[str, float], Tuple[float, ...]]]) -> None:
        """ Update metrics with a dictionary as an argument.

        metrics:
            {'onset_f': 0.5, 'offset_f': 0.5}
            or {'onset_f': {'value': 0.5, 'weight': 1.0}, 'offset_f': {'value': 0.5, 'weight': 1.0}}
            or {'onset_p': (0.3, 5)}

        A dict value is expanded as keyword arguments of the metric's `update`,
        a tuple value as positional arguments, anything else is passed as the
        single positional argument. Every key must name an existing attribute.
        """
        for name, value in metrics.items():
            metric = getattr(self, name)
            if isinstance(value, dict):
                metric.update(**value)
            elif isinstance(value, tuple):
                metric.update(*value)
            else:
                metric.update(value)

    def bulk_update_errors(self, errors: Dict[str, Union[int, float]]) -> None:
        """ Update error counts with a dictionary as an argument.

        errors:
            {'error_type_or_message_1': (int | float) count,
             'error_type_or_message_2': (int | float) count,}

        Raises:
            ValueError: if a count is neither an int nor a float.
        """
        for error_type, count in errors.items():
            if not isinstance(count, (int, float)):
                raise ValueError(f"Count of error type '{error_type}' must be an integer or a float.")
            # Update the error count
            getattr(self, error_type).update(count)

    def bulk_compute(self) -> Dict[str, float]:
        """ Compute every metric that has been updated and return them in a dictionary.

        Keys are the attribute names with the prefix prepended. For each extra
        class and for 'drum', micro onset/offset F1 is derived from the computed
        micro precision and recall when both are available; the micro
        precision/recall entries themselves are then removed from the result.
        """
        computed_metrics = {}
        for k, v in self._modules.items():
            if isinstance(v, UpdatedMeanMetric) and v.is_updated():
                computed_metrics[self._prefix + k] = v.compute()

        # list(...) so that any sequence of class names (not only a list) can be
        # concatenated with 'drum'. 'drum' usually has no micro offset metrics.
        class_names = (list(self.extra_classes) if self.extra_classes is not None else []) + ['drum']
        micro_stems = ('micro_onset', 'micro_offset')

        # Create micro onset/offset F1 for each instrument. Only when micro metrics are updated.
        for class_name in class_names:
            for stem in micro_stems:
                precision = computed_metrics.get(self._prefix + stem + '_p_' + class_name, None)
                recall = computed_metrics.get(self._prefix + stem + '_r_' + class_name, None)
                if precision is not None and recall is not None:
                    computed_metrics[self._prefix + stem + '_f_' + class_name] = f1_measure(
                        precision.item(), recall.item())

        # Remove micro onset and offset P,R (Now we have F1)
        for class_name in class_names:
            for stem in micro_stems:
                for pr in ('p', 'r'):
                    computed_metrics.pop(self._prefix + stem + '_' + pr + '_' + class_name, None)

        return computed_metrics

    def bulk_reset(self) -> None:
        """ Reset all metrics and clear their updated flags. """
        for v in self._modules.values():
            if isinstance(v, UpdatedMeanMetric):
                v.reset()
                v._updated = False
def _sum_micro_counts(metrics: Dict, p_key: str, r_key: str) -> Tuple[float, float, float]:
    """Accumulate weighted micro precision/recall entries of `metrics`.

    Every entry whose key contains `p_key` contributes value * weight to the
    first sum and weight to the second; every entry whose key contains `r_key`
    contributes weight to the third. Entries with a NaN value are skipped.
    The sums are used by the caller as (TP, TP + FP, TP + FN).
    """
    tp_sum, tpfp_sum, tpfn_sum = 0., 0., 0.
    for k, v in metrics.items():
        # The key test comes first: only matching entries are indexed with 'value'/'weight'.
        if p_key in k and not np.isnan(v['value']):
            tp_sum += v['value'] * v['weight']
            tpfp_sum += v['weight']
        if r_key in k and not np.isnan(v['value']):
            tpfn_sum += v['weight']
    return tp_sum, tpfp_sum, tpfn_sum


def compute_track_metrics(pred_notes: List[Note],
                          ref_notes: List[Note],
                          eval_vocab: Optional[Dict] = None,
                          eval_drum_vocab: Optional[Dict] = None,
                          onset_tolerance: float = 0.05,
                          add_pitch_class_metric: Optional[List[str]] = None,
                          add_melody_metric: Optional[List[str]] = None,
                          add_frame_metric: bool = False,
                          add_micro_metric: bool = False,
                          add_multi_f_metric: bool = False,
                          extra_info: Optional[Any] = None):
    """ Track metrics

    Args:
        pred_notes: (List[Note]) predicted sequence of notes for a track
        ref_notes: (List[Note]) reference sequence of notes for a track
        eval_vocab: (Dict or None) program group for instrument-specific metrics
            {
                instrument_or_group_name:
                    [program_number_0, program_number_1 ...]
            }
            If None, no instrument-specific metrics are computed and the third
            return value is an empty dict.
            ex) eval_vocab = {"piano": np.arange(0, 8), ...}
        eval_drum_vocab: (Dict or None) note (pitch) group for drum-specific metrics
            {
                instrument_or_group_name:
                    [note_number_0, note_number_1 ...]
            }
            When given, every drum note whose pitch belongs to a group is
            replaced by a copy carrying the first pitch of that group; drum
            notes that belong to no group are dropped.
        onset_tolerance: (float) onset tolerance passed to the note F1 computation
        add_pitch_class_metric: (List[str] or None) add pitch class metrics for the
            given instruments. The instrument names are defined in config/vocabulary.py.
            Names are compared case-insensitively.
            ex) ['Bass', 'Guitar']
        add_melody_metric: (List[str] or None) add melody metrics (RPA, RCA, OA) for the
            given instruments. The instrument names are defined in config/vocabulary.py.
            Names are compared case-insensitively.
            ex) ['Singing Voice']
            (https://craffel.github.io/mir_eval/#mir_eval.melody.overall_accuracy)
        add_frame_metric: (bool) add frame-wise metrics for non-drum notes
        add_micro_metric: (bool) passed to the note F1 computation for drums and
            instrument groups
        add_multi_f_metric: (bool) add 'multi_f' to the instrument metrics. It is built
            from the micro offset P/R entries of the instrument metrics and the micro
            onset P/R entries of the drum metrics. Only computed when eval_vocab is
            given.
        extra_info: (Any) extra information for debugging. Currently not implemented

    Returns:
        (drum_metric, non_drum_metric, instr_metric): three dicts of track metrics in
        the AMTMetric format with attribute names such as
        'onset_f_{instrument_or_group_name}'

    @dataclass
    class Note:
        is_drum: bool
        program: int
        onset: float
        offset: float
        pitch: int
        velocity: int

    Caution: Note is mutable instance, even if we use copy().
    """

    # Extract drum and non-drum notes
    def extract_drum_and_non_drum_notes(notes: List[Note]):
        drum_notes = [note for note in notes if note.is_drum]
        non_drum_notes = [note for note in notes if not note.is_drum]
        return drum_notes, non_drum_notes

    pns_drum, pns_non_drum = extract_drum_and_non_drum_notes(pred_notes)
    rns_drum, rns_non_drum = extract_drum_and_non_drum_notes(ref_notes)

    # Reduce drum notes to drum vocab
    def reduce_drum_notes_to_drum_vocab(notes: List[Note], drum_vocab: Dict):
        reduced_notes = []
        for note in notes:
            for pitches in drum_vocab.values():
                if note.pitch in pitches:
                    # deepcopy: the caller's Note objects must not be modified
                    new_note = copy.deepcopy(note)
                    new_note.pitch = pitches[0]
                    reduced_notes.append(new_note)
        return sort_notes(reduced_notes)

    if eval_drum_vocab is not None:
        pns_drum = reduce_drum_notes_to_drum_vocab(pns_drum, eval_drum_vocab)
        rns_drum = reduce_drum_notes_to_drum_vocab(rns_drum, eval_drum_vocab)

    # Extract Pitches (freq) and Intervals
    pns_drum_pi = extract_pitches_intervals_from_notes(pns_drum, is_drum=True)
    pns_non_drum_pi = extract_pitches_intervals_from_notes(pns_non_drum)
    rns_drum_pi = extract_pitches_intervals_from_notes(rns_drum, is_drum=True)
    rns_non_drum_pi = extract_pitches_intervals_from_notes(rns_non_drum)

    # Compute file-wise PRF for drums
    drum_metric = mir_eval_note_f1(pns_drum_pi['pitches'],
                                   pns_drum_pi['intervals'],
                                   rns_drum_pi['pitches'],
                                   rns_drum_pi['intervals'],
                                   onset_tolerance=onset_tolerance,
                                   is_drum=True,
                                   add_micro_metric=add_micro_metric)

    # Compute file-wise PRF for non-drums
    non_drum_metric = mir_eval_note_f1(pns_non_drum_pi['pitches'],
                                       pns_non_drum_pi['intervals'],
                                       rns_non_drum_pi['pitches'],
                                       rns_non_drum_pi['intervals'],
                                       onset_tolerance=onset_tolerance,
                                       is_drum=False)

    # Compute file-wise frame PRF for non-drums
    if add_frame_metric is True:
        # Extract frame-level Pitches (freq) and Intervals
        pns_non_drum_tf = extract_frame_time_freq_from_notes(pns_non_drum)
        rns_non_drum_tf = extract_frame_time_freq_from_notes(rns_non_drum)

        res = mir_eval_frame_f1(pns_non_drum_tf, rns_non_drum_tf)
        non_drum_metric = {**non_drum_metric, **res}  # merge dicts

    ############## Compute instrument-wise PRF for non-drums ##############
    if eval_vocab is None:
        return drum_metric, non_drum_metric, {}

    # Lower-cased name sets, built once instead of once per instrument group
    pc_groups = set() if add_pitch_class_metric is None else {g.lower() for g in add_pitch_class_metric}
    melody_groups = set() if add_melody_metric is None else {g.lower() for g in add_melody_metric}

    instr_metric = {}
    for group_name, programs in eval_vocab.items():
        # Extract notes for each instrument
        # bug fix for piano/drum overlap on slakh
        pns_group = [note for note in pns_non_drum if note.program in programs]
        rns_group = [note for note in rns_non_drum if note.program in programs]

        # Compute PC instrument-wise PRF using pitch class (currently for bass)
        if group_name.lower() in pc_groups:
            # pc: pitch information is converted to pitch class e.g. 0-11
            pns_pc_group = extract_pitches_intervals_from_notes(notes2pc_notes(pns_group))
            rns_pc_group = extract_pitches_intervals_from_notes(notes2pc_notes(rns_group))

            _instr_pc_metric = mir_eval_note_f1(pns_pc_group['pitches'],
                                                pns_pc_group['intervals'],
                                                rns_pc_group['pitches'],
                                                rns_pc_group['intervals'],
                                                onset_tolerance=onset_tolerance,
                                                is_drum=False,
                                                add_micro_metric=add_micro_metric,
                                                suffix=group_name + '_pc')
            # Add to instrument-wise PRF
            instr_metric.update(_instr_pc_metric)

        # Extract Pitches (freq) and Intervals
        pns_group_pi = extract_pitches_intervals_from_notes(pns_group)
        rns_group_pi = extract_pitches_intervals_from_notes(rns_group)

        # Compute instrument-wise PRF
        _instr_metric = mir_eval_note_f1(pns_group_pi['pitches'],
                                         pns_group_pi['intervals'],
                                         rns_group_pi['pitches'],
                                         rns_group_pi['intervals'],
                                         onset_tolerance=onset_tolerance,
                                         is_drum=False,
                                         add_micro_metric=add_micro_metric,
                                         suffix=group_name)

        # Merge instrument-wise PRF
        instr_metric.update(_instr_metric)

        # Optionally compute melody metrics: RPA, RCA, OA
        if group_name.lower() in melody_groups:
            _melody_metric = mir_eval_melody_metric(pns_group_pi['pitches'],
                                                    pns_group_pi['intervals'],
                                                    rns_group_pi['pitches'],
                                                    rns_group_pi['intervals'],
                                                    cent_tolerance=50,
                                                    suffix=group_name)
            instr_metric.update(_melody_metric)

    # Calculate multi_f metric for this track
    if add_multi_f_metric is True:
        # Collect offset metric for non-drum notes
        (non_drum_micro_offset_tp_sum, non_drum_micro_offset_tpfp_sum,
         non_drum_micro_offset_tpfn_sum) = _sum_micro_counts(instr_metric, 'micro_offset_p_', 'micro_offset_r_')
        # Collect onset metric for drum notes
        (drum_micro_onset_tp_sum, drum_micro_onset_tpfp_sum,
         drum_micro_onset_tpfn_sum) = _sum_micro_counts(drum_metric, 'micro_onset_p_drum', 'micro_onset_r_drum')

        tp = non_drum_micro_offset_tp_sum + drum_micro_onset_tp_sum
        tpfp = non_drum_micro_offset_tpfp_sum + drum_micro_onset_tpfp_sum
        tpfn = non_drum_micro_offset_tpfn_sum + drum_micro_onset_tpfn_sum
        # NaN when there is nothing to divide by
        multi_p_track = tp / tpfp if tpfp > 0 else np.nan
        multi_r_track = tp / tpfn if tpfn > 0 else np.nan
        multi_f_track = f1_measure(multi_p_track, multi_r_track)
        instr_metric['multi_f'] = multi_f_track

    return drum_metric, non_drum_metric, instr_metric
|