# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""metrics.py"""

from typing import List, Any, Dict, Optional, Tuple, Union
import numpy as np
import copy
from torch.nn import Module

from utils.note_event_dataclasses import NoteEvent, Note
from utils.note2event import sort_notes, notes2pc_notes
from utils.event2note import note_event2note
from sklearn.metrics import average_precision_score
from utils.metrics_helper import (f1_measure, round_float, mir_eval_note_f1, mir_eval_frame_f1,
                                  mir_eval_melody_metric, extract_pitches_intervals_from_notes,
                                  extract_frame_time_freq_from_notes)
from torchmetrics import MeanMetric, SumMetric


class UpdatedMeanMetric(MeanMetric):
    """
    A wrapper of torchmetrics.MeanMetric that remembers whether `update` has
    been called since construction or since the last `reset`.

    The flag lets a container skip metrics that never received a value
    (see `AMTMetrics.bulk_compute`).
    """

    def __init__(self, nan_strategy: str = 'ignore', **kwargs) -> None:
        super().__init__(nan_strategy=nan_strategy, **kwargs)
        self._updated = False

    def update(self, *args, **kwargs):
        """Forward to `MeanMetric.update` and mark the metric as updated."""
        super().update(*args, **kwargs)
        self._updated = True

    def reset(self) -> None:
        """Reset the accumulated state and clear the 'updated' flag.

        Without clearing the flag here, calling `reset()` directly on a metric
        would leave it reported as updated although it holds no values.
        """
        super().reset()
        self._updated = False

    def is_updated(self):
        """Return True if `update` has been called since the last reset."""
        return self._updated


class UpdatedSumMetric(SumMetric):
    """
    A wrapper of torchmetrics.SumMetric that remembers whether `update` has
    been called since construction or since the last `reset`.
    """

    def __init__(self, nan_strategy: str = 'ignore', **kwargs) -> None:
        super().__init__(nan_strategy=nan_strategy, **kwargs)
        self._updated = False

    def update(self, *args, **kwargs):
        """Forward to `SumMetric.update` and mark the metric as updated."""
        super().update(*args, **kwargs)
        self._updated = True

    def reset(self) -> None:
        """Reset the accumulated state and clear the 'updated' flag."""
        super().reset()
        self._updated = False

    def is_updated(self):
        """Return True if `update` has been called since the last reset."""
        return self._updated


class AMTMetrics(Module):
    """
    Automatic music transcription (AMT) evaluation metrics for music transcription
    tasks with DDP support, following the convention of AMT. The average of file-wise
    metrics is calculated.

    Metrics:
    --------

    Instrument-agnostic note onset and note offset metrics: (Drum notes are generally excluded)
    - onset_f: the most conventional, often called Note F1
    - offset_f: a pair of onset + offset matching metric

    Multi-instrument note on-offset Macro-micro F1 metric, multi-F1 (of MT3):
    - multi_f: counts for onset + offset + program (instrument class) matching.
               For drum notes, we only count onset. macro-micro means that we
               calculate weighted precision and recall by counting each note
               instrument class per file, and calculate micro F1. We then
               calculate average F1 for all files with equal weights (Macro).

    Instrument-group note onset and offset metrics are defined by extra_classes:
        e.g. extra_classes = ['piano', 'guitar']
    - onset_f_piano: piano instrument
    - onset_f_guitar: guitar instrument
    - offset_f_piano: piano instrument
    - offset_f_guitar: guitar instrument
    also p, r metrics follow...

    Usage:
    ------

    Each metric instance can be individually updated and reset for computation.

    ```
    my_metric = AMTMetrics()
    my_metric.onset_f.update(0.5)
    my_metric.onset_f(0.5)  # same
    my_metric.onset_f(0, weight=1.0)  # same and weighted by 1.0 (default)
    my_metric.onset_f.compute()  # return 0.333..
    my_metric.onset_f.reset()  # reset the metric
    ```
        • {attribute}.update(value: float, weight: Optional[float]): Here weight is an
          optional argument for weighted average.
        • {attribute}.(...): Same as update method.
        • {attribute}.compute(): Return the average value of the metric.
        • {attribute}.reset(): Reset the metric.

    Class methods:
    ---------------

    ```
    d = {'onset_f': 0.5, 'offset_f': 0.5}
    my_metric.bulk_update(d)
    d = {'onset_f': {'value': 0.5, 'weight': 1.0}, 'offset_f': {'value': 0.5, 'weight': 1.0}}
    my_metric.bulk_update(d)
    ```
        • bulk_update(metrics: Dict[str, Union[float, Dict[str, float]]]): Update metrics with a
          dictionary as an argument.
        • bulk_compute(): Return a dictionary of any non-empty metrics with average values.
        • bulk_reset(): Reset all metrics.

    """

    def __init__(self,
                 prefix: str = '',
                 nan_strategy: str = 'ignore',
                 extra_classes: Optional[List[str]] = None,
                 extra_metrics: Optional[List[str]] = None,
                 error_types: Optional[List[str]] = None,
                 **kwargs) -> None:
        """
        Args:
            prefix: string prepended to every metric name returned by `bulk_compute`,
                e.g. 'val_' or 'test_'. It is concatenated as-is; this class does not
                insert a separator.
            nan_strategy: 'warn' or 'raise' or 'ignore'. Passed to every metric created here.
            extra_classes: instrument (group) names. For each name, the metrics
                'onset_f_{name}', 'offset_f_{name}' and the micro precision/recall metrics
                'micro_{onset,offset}_{p,r}_{name}' are created.
            extra_metrics: additional metric names to create as attributes.
            error_types: error counter names to create as attributes
                (updated through `bulk_update_errors`).
            **kwargs: forwarded to `torch.nn.Module.__init__`.

        Raises:
            ValueError: if an entry of `extra_classes` is already an attribute of this module.
        """
        super().__init__(**kwargs)
        self._prefix = prefix
        self.nan_strategy = nan_strategy

        # Instrument-agnostic Note onsets and Note on-offset metrics for non-drum notes
        self.onset_f = UpdatedMeanMetric(nan_strategy=nan_strategy)
        self.offset_f = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Instrument-agnostic Frame F1 (skip in validation)
        self.frame_f = UpdatedMeanMetric(nan_strategy=nan_strategy)
        self.frame_f_pc = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Drum Onset metrics
        self.onset_f_drum = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Multi F1 (Macro-micro F1 of MT3)
        self.multi_f = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Per-class metrics are registered in two passes so that the registration
        # order (and therefore the key order of `bulk_compute`) stays:
        #   1) macro F1:      onset_f_{c}, offset_f_{c}           for every class c
        #   2) micro P and R: micro_{onset,offset}_{p,r}_{c}      for every class c
        # The micro F1 is not stored; it is derived from P and R in `bulk_compute`.
        self.extra_classes = extra_classes
        class_names = extra_classes if extra_classes is not None else []
        per_class_stems = (
            ('onset_f', 'offset_f'),
            ('micro_onset_p', 'micro_onset_r', 'micro_offset_p', 'micro_offset_r'),
        )
        for stems in per_class_stems:
            for class_name in class_names:
                # The guard checks the bare class name, not the derived metric names.
                if hasattr(self, class_name):
                    raise ValueError(f"Metric '{class_name}' already exists.")
                for stem in stems:
                    setattr(self, stem + '_' + class_name, UpdatedMeanMetric(nan_strategy=nan_strategy))

        # Initialize drum micro P,R (F is computed later)
        self.micro_onset_p_drum = UpdatedMeanMetric(nan_strategy=nan_strategy)
        self.micro_onset_r_drum = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Initialize extra metrics directly
        if extra_metrics is not None:
            for metric_name in extra_metrics:
                setattr(self, metric_name, UpdatedMeanMetric(nan_strategy=nan_strategy))

        # Initialize error counters
        self.error_types = error_types
        if error_types is not None:
            for error_type in error_types:
                setattr(self, error_type, UpdatedMeanMetric(nan_strategy=nan_strategy))

    def bulk_update(self, metrics: Dict[str, Union[float, Dict[str, float], Tuple[float, ...]]]) -> None:
        """ Update metrics with a dictionary as an argument.

        Each key must be the attribute name of a metric on this module. The value is
        passed to that metric's `update` as keyword arguments (dict), positional
        arguments (tuple), or a single positional argument (anything else).

        metrics:
            {'onset_f': 0.5, 'offset_f': 0.5}
            or {'onset_f': {'value': 0.5, 'weight': 1.0}, 'offset_f': {'value': 0.5, 'weight': 1.0}}
            or {'onset_p': (0.3, 5)}

        """
        for k, v in metrics.items():
            metric = getattr(self, k)
            if isinstance(v, dict):
                metric.update(**v)
            elif isinstance(v, tuple):
                metric.update(*v)
            else:
                metric.update(v)

    def bulk_update_errors(self, errors: Dict[str, Union[int, float]]) -> None:
        """ Update error counts with a dictionary as an argument.

        errors:
            {'error_type_or_message_1': (int | float) count,
             'error_type_or_message_2': (int | float) count,}

        Raises:
            ValueError: if a count is neither an int nor a float.
        """
        for error_type, count in errors.items():
            if not isinstance(count, (int, float)):
                raise ValueError(f"Count of error type '{error_type}' must be an integer or a float.")
            getattr(self, error_type).update(count)

    def bulk_compute(self) -> Dict[str, float]:
        """ Compute every metric that has been updated.

        Returns:
            A dict mapping `prefix + metric_name` to the value returned by the metric's
            `compute()`. For every class in `extra_classes` plus 'drum', the micro
            precision/recall entries are replaced by a single
            'micro_{onset,offset}_f_{class}' entry computed with `f1_measure`, provided
            both precision and recall were updated.
        """
        computed_metrics = {}
        for k, v in self._modules.items():
            if isinstance(v, UpdatedMeanMetric) and v.is_updated():
                computed_metrics[self._prefix + k] = v.compute()

        # 'drum' usually only has onset P/R; missing offset entries are simply skipped.
        class_names = list(self.extra_classes if self.extra_classes is not None else []) + ['drum']
        micro_kinds = ('micro_onset', 'micro_offset')

        # Create micro F1 for each instrument. Only when both micro P and R are updated.
        for class_name in class_names:
            for onoff in micro_kinds:
                micro_p = computed_metrics.get(self._prefix + onoff + '_p_' + class_name, None)
                micro_r = computed_metrics.get(self._prefix + onoff + '_r_' + class_name, None)
                if micro_p is not None and micro_r is not None:
                    computed_metrics[self._prefix + onoff + '_f_' + class_name] = f1_measure(
                        micro_p.item(), micro_r.item())

        # Remove micro onset and offset P,R (Now we have F1)
        for class_name in class_names:
            for onoff in micro_kinds:
                for pr in ('p', 'r'):
                    computed_metrics.pop(self._prefix + onoff + '_' + pr + '_' + class_name, None)

        return computed_metrics

    def bulk_reset(self) -> None:
        """ Reset every `UpdatedMeanMetric` registered on this module. """
        for v in self._modules.values():
            if isinstance(v, UpdatedMeanMetric):
                v.reset()  # also clears the 'updated' flag


def _split_drum_notes(notes: List[Note]) -> Tuple[List[Note], List[Note]]:
    """Split notes into (drum_notes, non_drum_notes) by `note.is_drum`, keeping order."""
    drum_notes, non_drum_notes = [], []
    for note in notes:
        if note.is_drum:
            drum_notes.append(note)
        else:
            non_drum_notes.append(note)
    return drum_notes, non_drum_notes


def _reduce_drum_notes_to_vocab(notes: List[Note], drum_vocab: Dict) -> List[Note]:
    """Map each drum note to the first pitch of the vocab group containing its pitch.

    Notes whose pitch is in no group are dropped. Notes are deep-copied before the
    pitch is rewritten, so the input notes are not modified. The result is passed
    through `sort_notes`.
    """
    reduced_notes = []
    for note in notes:
        # NOTE(review): there is no `break`, so a pitch listed in several groups yields
        # one copy per group. Kept as in the original; confirm this is intended.
        for pitches in drum_vocab.values():
            if note.pitch in pitches:
                new_note = copy.deepcopy(note)
                new_note.pitch = pitches[0]
                reduced_notes.append(new_note)
    return sort_notes(reduced_notes)


def _sum_micro_counts(metrics: Dict, p_key: str, r_key: str) -> Tuple[float, float, float]:
    """Accumulate (tp, tp+fp, tp+fn) from weighted micro precision/recall entries.

    Every entry whose key contains `p_key` adds `value * weight` to tp and `weight`
    to tp+fp; every entry whose key contains `r_key` adds `weight` to tp+fn. Entries
    with a NaN value are skipped. Matching entries are expected to be dicts with
    'value' and 'weight' keys.
    """
    tp_sum, tpfp_sum, tpfn_sum = 0., 0., 0.
    for k, v in metrics.items():
        if p_key in k and not np.isnan(v['value']):
            tp_sum += v['value'] * v['weight']
            tpfp_sum += v['weight']
        if r_key in k and not np.isnan(v['value']):
            tpfn_sum += v['weight']
    return tp_sum, tpfp_sum, tpfn_sum


def compute_track_metrics(pred_notes: List[Note],
                          ref_notes: List[Note],
                          eval_vocab: Optional[Dict] = None,
                          eval_drum_vocab: Optional[Dict] = None,
                          onset_tolerance: float = 0.05,
                          add_pitch_class_metric: Optional[List[str]] = None,
                          add_melody_metric: Optional[List[str]] = None,
                          add_frame_metric: bool = False,
                          add_micro_metric: bool = False,
                          add_multi_f_metric: bool = False,
                          extra_info: Optional[Any] = None):
    """ Track metrics

    Args:
        pred_notes: (List[Note]) predicted sequence of notes for a track
        ref_notes: (List[Note]) reference sequence of notes for a track
        eval_vocab: (Dict or None) program group for instrument-specific metrics
            {
                instrument_or_group_name:
                [program_number_0, program_number_1 ...]
            }
            If None, no instrument-specific metrics are computed and the third
            return value is an empty dict.
            ex) eval_vocab = {"piano": np.arange(0, 8), ...}
        eval_drum_vocab: (Dict or None) note (pitch) group for drum-specific metrics
            {
                instrument_or_group_name:
                [note_number_0, note_number_1 ...]
            }
            If given, each drum note is mapped to the first pitch of its group and
            drum notes outside every group are dropped before evaluation.
        onset_tolerance: (float) onset tolerance passed to `mir_eval_note_f1`.
        add_pitch_class_metric: (List[str] or None) add pitch class metrics for the given
            instruments (case-insensitive match against the keys of eval_vocab).
            ex) ['Bass', 'Guitar']
        add_melody_metric: (List[str] or None) add melody metrics (computed by
            `mir_eval_melody_metric`) for the given instruments (case-insensitive match
            against the keys of eval_vocab). ex) ['Singing Voice']
            (https://craffel.github.io/mir_eval/#mir_eval.melody.overall_accuracy)
        add_frame_metric: (bool) add frame-wise metrics for non-drum notes
        add_micro_metric: (bool) passed to `mir_eval_note_f1` for the drum and the
            instrument-wise metrics. The multi-F computation reads the resulting
            'micro_*' entries.
        add_multi_f_metric: (bool) add 'multi_f' to the instrument metrics. Only
            effective when eval_vocab is given.
        extra_info: (Any) extra information for debugging. Currently not implemented

    Returns:
        (drum_metric, non_drum_metric, instr_metric): three dicts of track metrics in
        the AMTMetric format, with attribute names such as
        'onset_f_{instrument_or_group_name}'.

    @dataclass
    class Note:
        is_drum: bool
        program: int
        onset: float
        offset: float
        pitch: int
        velocity: int

    Caution: Note is mutable instance, even if we use copy().

    """
    # Extract drum and non-drum notes
    pns_drum, pns_non_drum = _split_drum_notes(pred_notes)
    rns_drum, rns_non_drum = _split_drum_notes(ref_notes)

    # Reduce drum notes to drum vocab
    if eval_drum_vocab is not None:
        pns_drum = _reduce_drum_notes_to_vocab(pns_drum, eval_drum_vocab)
        rns_drum = _reduce_drum_notes_to_vocab(rns_drum, eval_drum_vocab)

    # Extract Pitches (freq) and Intervals
    pns_drum_pi = extract_pitches_intervals_from_notes(pns_drum, is_drum=True)
    pns_non_drum_pi = extract_pitches_intervals_from_notes(pns_non_drum)
    rns_drum_pi = extract_pitches_intervals_from_notes(rns_drum, is_drum=True)
    rns_non_drum_pi = extract_pitches_intervals_from_notes(rns_non_drum)

    # Compute file-wise PRF for drums
    drum_metric = mir_eval_note_f1(pns_drum_pi['pitches'],
                                   pns_drum_pi['intervals'],
                                   rns_drum_pi['pitches'],
                                   rns_drum_pi['intervals'],
                                   onset_tolerance=onset_tolerance,
                                   is_drum=True,
                                   add_micro_metric=add_micro_metric)

    # Compute file-wise PRF for non-drums
    non_drum_metric = mir_eval_note_f1(pns_non_drum_pi['pitches'],
                                       pns_non_drum_pi['intervals'],
                                       rns_non_drum_pi['pitches'],
                                       rns_non_drum_pi['intervals'],
                                       onset_tolerance=onset_tolerance,
                                       is_drum=False)

    # Compute file-wise frame PRF for non-drums
    if add_frame_metric is True:
        # Extract frame-level Pitches (freq) and Intervals
        pns_non_drum_tf = extract_frame_time_freq_from_notes(pns_non_drum)
        rns_non_drum_tf = extract_frame_time_freq_from_notes(rns_non_drum)

        res = mir_eval_frame_f1(pns_non_drum_tf, rns_non_drum_tf)
        non_drum_metric = {**non_drum_metric, **res}  # merge dicts

    ############## Compute instrument-wise PRF for non-drums ##############

    if eval_vocab is None:
        return drum_metric, non_drum_metric, {}

    # Lower-cased name lists are loop-invariant, so build them once.
    pc_group_names = None
    if add_pitch_class_metric is not None:
        pc_group_names = [g.lower() for g in add_pitch_class_metric]
    melody_group_names = None
    if add_melody_metric is not None:
        melody_group_names = [g.lower() for g in add_melody_metric]

    instr_metric = {}
    for group_name, programs in eval_vocab.items():
        # Extract notes for each instrument
        # bug fix for piano/drum overlap on slakh
        pns_group = [note for note in pns_non_drum if note.program in programs]
        rns_group = [note for note in rns_non_drum if note.program in programs]

        # Compute PC instrument-wise PRF using pitch class (currently for bass)
        if pc_group_names is not None and group_name.lower() in pc_group_names:
            # pc: pitch information is converted to pitch class e.g. 0-11
            pns_pc_group = extract_pitches_intervals_from_notes(notes2pc_notes(pns_group))
            rns_pc_group = extract_pitches_intervals_from_notes(notes2pc_notes(rns_group))

            _instr_pc_metric = mir_eval_note_f1(pns_pc_group['pitches'],
                                                pns_pc_group['intervals'],
                                                rns_pc_group['pitches'],
                                                rns_pc_group['intervals'],
                                                onset_tolerance=onset_tolerance,
                                                is_drum=False,
                                                add_micro_metric=add_micro_metric,
                                                suffix=group_name + '_pc')
            # Add to instrument-wise PRF
            instr_metric.update(_instr_pc_metric)

        # Extract Pitches (freq) and Intervals
        pns_group_pi = extract_pitches_intervals_from_notes(pns_group)
        rns_group_pi = extract_pitches_intervals_from_notes(rns_group)

        # Compute instrument-wise PRF
        _instr_metric = mir_eval_note_f1(pns_group_pi['pitches'],
                                         pns_group_pi['intervals'],
                                         rns_group_pi['pitches'],
                                         rns_group_pi['intervals'],
                                         onset_tolerance=onset_tolerance,
                                         is_drum=False,
                                         add_micro_metric=add_micro_metric,
                                         suffix=group_name)

        # Merge instrument-wise PRF
        instr_metric.update(_instr_metric)

        # Optionally compute melody metrics: RPA, RCA, OA
        if melody_group_names is not None and group_name.lower() in melody_group_names:
            _melody_metric = mir_eval_melody_metric(pns_group_pi['pitches'],
                                                    pns_group_pi['intervals'],
                                                    rns_group_pi['pitches'],
                                                    rns_group_pi['intervals'],
                                                    cent_tolerance=50,
                                                    suffix=group_name)
            instr_metric.update(_melody_metric)

    # Calculate multi_f metric for this track
    if add_multi_f_metric is True:
        # Collect offset metric for non-drum notes
        # NOTE(review): keys are matched by substring, so pitch-class entries (suffix
        # '{group}_pc') would also be counted here if `mir_eval_note_f1` emits
        # 'micro_offset_*' keys for them. Kept as in the original; confirm.
        non_drum_tp, non_drum_tpfp, non_drum_tpfn = _sum_micro_counts(instr_metric, 'micro_offset_p_',
                                                                      'micro_offset_r_')
        # Collect onset metric for drum notes
        drum_tp, drum_tpfp, drum_tpfn = _sum_micro_counts(drum_metric, 'micro_onset_p_drum',
                                                          'micro_onset_r_drum')

        tp = non_drum_tp + drum_tp
        tpfp = non_drum_tpfp + drum_tpfp
        tpfn = non_drum_tpfn + drum_tpfn

        # NaN marks an undefined precision/recall (no counted notes on that side).
        multi_p_track = tp / tpfp if tpfp > 0 else np.nan
        multi_r_track = tp / tpfn if tpfn > 0 else np.nan

        instr_metric['multi_f'] = f1_measure(multi_p_track, multi_r_track)

    return drum_metric, non_drum_metric, instr_metric