File size: 21,758 Bytes
a03c9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""metrics.py"""

from typing import List, Any, Dict, Optional, Tuple, Union
import numpy as np
import copy
from torch.nn import Module

from utils.note_event_dataclasses import NoteEvent, Note
from utils.note2event import sort_notes, notes2pc_notes
from utils.event2note import note_event2note
from sklearn.metrics import average_precision_score
from utils.metrics_helper import (f1_measure, round_float, mir_eval_note_f1, mir_eval_frame_f1, mir_eval_melody_metric,
                                  extract_pitches_intervals_from_notes, extract_frame_time_freq_from_notes)
from torchmetrics import MeanMetric, SumMetric


class UpdatedMeanMetric(MeanMetric):
    """torchmetrics.MeanMetric that also remembers whether it has received data.

    The ``_updated`` flag lets callers tell a metric that has accumulated at
    least one value apart from one that has not (the flag is cleared externally,
    e.g. by ``AMTMetrics.bulk_reset``).
    """

    def __init__(self, nan_strategy: str = 'ignore', **kwargs) -> None:
        super().__init__(nan_strategy=nan_strategy, **kwargs)
        # Becomes True on the first successful update().
        self._updated = False

    def update(self, *args, **kwargs) -> None:
        """Forward to MeanMetric.update and mark this metric as updated."""
        super().update(*args, **kwargs)
        self._updated = True

    def is_updated(self) -> bool:
        """Return True if update() has run since the flag was last cleared."""
        return self._updated


class UpdatedSumMetric(SumMetric):
    """torchmetrics.SumMetric that also remembers whether it has received data.

    Mirrors ``UpdatedMeanMetric``: the ``_updated`` flag distinguishes a metric
    that has accumulated at least one value from one that has not.
    """

    def __init__(self, nan_strategy: str = 'ignore', **kwargs) -> None:
        super().__init__(nan_strategy=nan_strategy, **kwargs)
        # Becomes True on the first successful update().
        self._updated = False

    def update(self, *args, **kwargs) -> None:
        """Forward to SumMetric.update and mark this metric as updated."""
        super().update(*args, **kwargs)
        self._updated = True

    def is_updated(self) -> bool:
        """Return True if update() has run since the flag was last cleared."""
        return self._updated


class AMTMetrics(Module):
    """
    Automatic music transcription (AMT) evaluation metrics for music transcription 
    tasks with DDP support, following the convention of AMT. The average of file-wise
    metrics is calculated.
    
    Metrics:
    --------
    
    Instrument-agnostic note onset and note offset metrics:
    (Drum notes are generally excluded)
    
    - onset_f: the most conventional, often called Note F1
    - offset_f: a pair of onset + offset matching metric
    
    Multi-instrument note on-offset Macro-micro F1 metric, multi-F1 (of MT3):
    
    - multi_f: counts for onset + offset + program (instrument class) matching.
               For drum notes, we only count onset. macro-micro means that we
               calculate weighted precision and recall by counting each note 
               instrument class per file, and calculate micro F1. We then 
               calculate average F1 for all files with equal weights (Macro).          

    Instrument-group note onset and offset metrics are defined by extra_classes:

    e.g. extra_classes = ['piano', 'guitar']
    - onset_f_piano: piano instrument
    - onset_f_guitar: guitar instrument
    - offset_f_piano: piano instrument
    - offset_f_guitar: guitar instrument
    also p, r metrics follow...


    Usage:
    ------
    
    Each metric instance can be individually updated and reset for computation.
    
    ```
    my_metric = AMTMetrics()
    my_metric.onset_f.update(0.5)
    my_metric.onset_f(0.5) # same
    my_metric.onset_f(0, weight=1.0) # same and weighted by 1.0 (default)
    my_metric.onset_f.compute() # return 0.333..
    my_metric.onset_f.reset() # reset the metric

    ```
    • {attribute}.update(value: float, weight: Optional[float]): Here weight is an
        optional argument for weighted average.
    • {attribute}.(...): Same as update method.
    • {attribute}.compute(): Return the average value of the metric.
    • {attribute}.reset(): Reset the metric.

    Class methods:
    ---------------
    
    ```
    d = {'onset_f': 0.5, 'offset_f': 0.5}
    my_metric.bulk_update(d)
    d = {'onset_f': {'value': 0.5, 'weight': 1.0}, 'offset_f': {'value': 0.5, 'weight': 1.0}}
    my_metric.bulk_update(d)
    ```

    • bulk_update(metrics: Dict[str, Union[float, Dict[str, float]]]): Update metrics with a
        dictionary as an argument.
    • bulk_compute(): Return a dictionary of any non-empty metrics with average values. 
    • bulk_reset(): Reset all metrics.

    """

    def __init__(self,
                 prefix: str = '',
                 nan_strategy: str = 'ignore',
                 extra_classes: Optional[List[str]] = None,
                 extra_metrics: Optional[List[str]] = None,
                 error_types: Optional[List[str]] = None,
                 **kwargs) -> None:
        """
        Args:
            prefix: prefix for the metric name, e.g. 'validation/' or 'test/'.
                Prepended verbatim to every key returned by bulk_compute().
            nan_strategy: 'warn' or 'raise' or 'ignore' (passed to torchmetrics)
            extra_classes: instrument/group names for per-instrument metrics
            extra_metrics: names of additional mean metrics to register as-is
            error_types: names of error counters to register as mean metrics

        """
        super().__init__(**kwargs)
        self._prefix = prefix
        self.nan_strategy = nan_strategy

        # Instrument-agnostic Note onsets and Note on-offset metrics for non-drum notes
        self.onset_f = UpdatedMeanMetric(nan_strategy=nan_strategy)
        self.offset_f = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Instrument-agnostic Frame F1 (skip in validation)
        self.frame_f = UpdatedMeanMetric(nan_strategy=nan_strategy)
        self.frame_f_pc = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Drum Onset metrics
        self.onset_f_drum = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Multi F1 (Macro-micro F1 of MT3)
        self.multi_f = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Per-instrument metrics: macro onset/offset F1, plus micro onset/offset
        # P and R (the matching micro F1 values are derived in bulk_compute()).
        self.extra_classes = extra_classes
        if extra_classes is not None:
            for class_name in extra_classes:
                if hasattr(self, class_name):
                    raise ValueError(f"Metric '{class_name}' already exists.")
                for onoff in ['onset', 'offset']:
                    setattr(self, onoff + '_f_' + class_name,
                            UpdatedMeanMetric(nan_strategy=nan_strategy))
                for onoff in ['micro_onset', 'micro_offset']:
                    for pr in ['p', 'r']:
                        setattr(self, onoff + '_' + pr + '_' + class_name,
                                UpdatedMeanMetric(nan_strategy=nan_strategy))

        # Initialize drum micro P,R (F is computed later in bulk_compute())
        self.micro_onset_p_drum = UpdatedMeanMetric(nan_strategy=nan_strategy)
        self.micro_onset_r_drum = UpdatedMeanMetric(nan_strategy=nan_strategy)

        # Initialize extra metrics directly
        if extra_metrics is not None:
            for metric_name in extra_metrics:
                setattr(self, metric_name, UpdatedMeanMetric(nan_strategy=nan_strategy))

        # Initialize error counters
        self.error_types = error_types
        if error_types is not None:
            for error_type in error_types:
                setattr(self, error_type, UpdatedMeanMetric(nan_strategy=nan_strategy))

    def bulk_update(self, metrics: Dict[str, Union[float, Dict[str, float], Tuple[float, ...]]]) -> None:
        """ Update metrics with a dictionary as an argument.

        metrics:
            {'onset_f': 0.5, 'offset_f': 0.5}
            or {'onset_f': {'value': 0.5, 'weight': 1.0}, 'offset_f': {'value': 0.5, 'weight': 1.0}} 
            or {'onset_p': (0.3, 5)}

        """
        for k, v in metrics.items():
            # Dict values are keyword args (value/weight); tuples are positional.
            if isinstance(v, dict):
                getattr(self, k).update(**v)
            elif isinstance(v, tuple):
                getattr(self, k).update(*v)
            else:
                getattr(self, k).update(v)

    def bulk_update_errors(self, errors: Dict[str, Union[int, float]]) -> None:
        """ Update error counts with a dictionary as an argument. 

        errors:
            {'error_type_or_message_1': (int | float) count,
             'error_type_or_message_2': (int | float) count,}
        
        """
        for error_type, count in errors.items():
            # Update the error count
            if isinstance(count, (int, float)):
                getattr(self, error_type).update(count)
            else:
                raise ValueError(f"Count of error type '{error_type}' must be an integer or a float.")

    def bulk_compute(self) -> Dict[str, float]:
        """Compute averages of all updated metrics, deriving micro F1 from micro P/R.

        Returns a dict keyed by ``prefix + metric_name``; micro P/R entries are
        replaced by the corresponding micro F1 values.
        """
        computed_metrics = {}
        # Only metrics that actually received data are reported.
        for k, v in self._modules.items():
            if isinstance(v, UpdatedMeanMetric) and v.is_updated():
                computed_metrics[self._prefix + k] = v.compute()
        # Create micro onset F1 for each instrument. Only when micro metrics are updated.
        extra_classes = self.extra_classes if self.extra_classes is not None else []
        for class_name in extra_classes + ['drum']:
            # micro onset F1 for each instrument.
            _micro_onset_p_instr = computed_metrics.get(self._prefix + 'micro_onset_p_' + class_name, None)
            _micro_onset_r_instr = computed_metrics.get(self._prefix + 'micro_onset_r_' + class_name, None)
            if _micro_onset_p_instr is not None and _micro_onset_r_instr is not None:
                computed_metrics[self._prefix + 'micro_onset_f_' + class_name] = f1_measure(
                    _micro_onset_p_instr.item(), _micro_onset_r_instr.item())
            # micro offset F1 for each instrument. 'drum' is usually not included.
            _micro_offset_p_instr = computed_metrics.get(self._prefix + 'micro_offset_p_' + class_name, None)
            _micro_offset_r_instr = computed_metrics.get(self._prefix + 'micro_offset_r_' + class_name, None)
            if _micro_offset_p_instr is not None and _micro_offset_r_instr is not None:
                computed_metrics[self._prefix + 'micro_offset_f_' + class_name] = f1_measure(
                    _micro_offset_p_instr.item(), _micro_offset_r_instr.item())

        # Remove micro onset and offset P,R (Now we have F1)
        for class_name in extra_classes + ['drum']:
            for onoff in ['micro_onset', 'micro_offset']:
                for pr in ['p', 'r']:
                    computed_metrics.pop(self._prefix + onoff + '_' + pr + '_' + class_name, None)

        return computed_metrics

    def bulk_reset(self) -> None:
        """Reset every registered metric and clear its updated flag."""
        for k, v in self._modules.items():
            if isinstance(v, UpdatedMeanMetric):
                v.reset()
                # reset() alone does not clear the flag; clear it explicitly.
                v._updated = False


def compute_track_metrics(pred_notes: List[Note],
                          ref_notes: List[Note],
                          eval_vocab: Optional[Dict] = None,
                          eval_drum_vocab: Optional[Dict] = None,
                          onset_tolerance: float = 0.05,
                          add_pitch_class_metric: Optional[List[str]] = None,
                          add_melody_metric: Optional[List[str]] = None,
                          add_frame_metric: bool = False,
                          add_micro_metric: bool = False,
                          add_multi_f_metric: bool = False,
                          extra_info: Optional[Any] = None):
    """ Track metrics

    Args:
        pred_notes: (List[Note]) predicted sequence of notes for a track
        ref_notes: (List[Note]) reference sequence of notes for a track
        eval_vocab: (Dict or None) program group for instrument-specific metrics
            {
                instrument_or_group_name:
                  [program_number_0, program_number_1  ...]
            }
            If None, instrument-specific metrics are skipped.

            ex) eval_vocab = {"piano": np.arange(0, 8), ...}
        eval_drum_vocab: (Dict or None) note (pitch) group for drum-specific metrics
            {
                instrument_or_group_name:
                    [note_number_0, note_number_1  ...]
            }
        onset_tolerance: (float) onset matching tolerance in seconds
        add_pitch_class_metric: (List[str] or None) add pitch class metrics for the
            given instruments. The instrument names are defined in config/vocabulary.py.
            ex) ['Bass', 'Guitar']
        add_melody_metric: (List[str] or None) add melody metrics (RPA, RCA, OA) for
            the given instruments. The instrument names are defined in config/vocabulary.py.
            ex) ['Singing Voice']
            (https://craffel.github.io/mir_eval/#mir_eval.melody.overall_accuracy)
        add_frame_metric: (bool) add frame-wise metrics
        add_micro_metric: (bool) also collect micro precision/recall entries
        add_multi_f_metric: (bool) add the track-level multi-F metric (MT3-style)
        extra_info: (Any) extra information for debugging. Currently not implemented

    Returns:
        (drum_metric, non_drum_metric, instr_metric): (Tuple[Dict, Dict, Dict]) track
            metrics in the AMTMetric format with attribute names such as
            'onset_f_{instrument_or_group_name}'. instr_metric is empty when
            eval_vocab is None.


    @dataclass
    class Note:
        is_drum: bool
        program: int
        onset: float
        offset: float
        pitch: int
        velocity: int

    Caution: Note is mutable instance, even if we use copy().

    """

    # Split a note list into (drum_notes, non_drum_notes) by the is_drum flag.
    def extract_drum_and_non_drum_notes(notes: List[Note]):
        drum_notes, non_drum_notes = [], []
        for note in notes:
            if note.is_drum:
                drum_notes.append(note)
            else:
                non_drum_notes.append(note)
        return drum_notes, non_drum_notes

    pns_drum, pns_non_drum = extract_drum_and_non_drum_notes(pred_notes)
    rns_drum, rns_non_drum = extract_drum_and_non_drum_notes(ref_notes)

    # Map each drum note's pitch onto the first pitch of its vocab group; notes whose
    # pitch is in no group are dropped. deepcopy guards against Note mutability.
    def reduce_drum_notes_to_drum_vocab(notes: List[Note], drum_vocab: Dict):
        reduced_notes = []
        for note in notes:
            for drum_name, pitches in drum_vocab.items():
                if note.pitch in pitches:
                    new_note = copy.deepcopy(note)
                    new_note.pitch = pitches[0]
                    reduced_notes.append(new_note)
        return sort_notes(reduced_notes)

    if eval_drum_vocab is not None:
        pns_drum = reduce_drum_notes_to_drum_vocab(pns_drum, eval_drum_vocab)
        rns_drum = reduce_drum_notes_to_drum_vocab(rns_drum, eval_drum_vocab)

    # Extract Pitches (freq) and Intervals
    pns_drum_pi = extract_pitches_intervals_from_notes(pns_drum, is_drum=True)
    pns_non_drum_pi = extract_pitches_intervals_from_notes(pns_non_drum)
    rns_drum_pi = extract_pitches_intervals_from_notes(rns_drum, is_drum=True)
    rns_non_drum_pi = extract_pitches_intervals_from_notes(rns_non_drum)

    # Compute file-wise PRF for drums
    drum_metric = mir_eval_note_f1(pns_drum_pi['pitches'],
                                   pns_drum_pi['intervals'],
                                   rns_drum_pi['pitches'],
                                   rns_drum_pi['intervals'],
                                   onset_tolerance=onset_tolerance,
                                   is_drum=True,
                                   add_micro_metric=add_micro_metric)

    # Compute file-wise PRF for non-drums
    non_drum_metric = mir_eval_note_f1(pns_non_drum_pi['pitches'],
                                       pns_non_drum_pi['intervals'],
                                       rns_non_drum_pi['pitches'],
                                       rns_non_drum_pi['intervals'],
                                       onset_tolerance=onset_tolerance,
                                       is_drum=False)

    # Compute file-wise frame PRF for non-drums
    if add_frame_metric:
        # Extract frame-level Pitches (freq) and Intervals
        pns_non_drum_tf = extract_frame_time_freq_from_notes(pns_non_drum)
        rns_non_drum_tf = extract_frame_time_freq_from_notes(rns_non_drum)

        res = mir_eval_frame_f1(pns_non_drum_tf, rns_non_drum_tf)
        non_drum_metric = {**non_drum_metric, **res}  # merge dicts

    ############## Compute instrument-wise PRF for non-drums ##############

    if eval_vocab is None:
        return drum_metric, non_drum_metric, {}

    # Hoist case-insensitive lookups out of the per-group loop.
    pc_groups_lower = [g.lower() for g in add_pitch_class_metric] if add_pitch_class_metric is not None else []
    melody_groups_lower = [g.lower() for g in add_melody_metric] if add_melody_metric is not None else []

    instr_metric = {}
    for group_name, programs in eval_vocab.items():
        # Extract notes for each instrument
        # bug fix for piano/drum overlap on slakh
        pns_group = [note for note in pns_non_drum if note.program in programs]
        rns_group = [note for note in rns_non_drum if note.program in programs]

        # Compute PC instrument-wise PRF using pitch class (currently for bass)
        if group_name.lower() in pc_groups_lower:
            # pc: pitch information is converted to pitch classes e.g. 0-11
            pns_pc_group = extract_pitches_intervals_from_notes(notes2pc_notes(pns_group))
            rns_pc_group = extract_pitches_intervals_from_notes(notes2pc_notes(rns_group))

            _instr_pc_metric = mir_eval_note_f1(pns_pc_group['pitches'],
                                                pns_pc_group['intervals'],
                                                rns_pc_group['pitches'],
                                                rns_pc_group['intervals'],
                                                onset_tolerance=onset_tolerance,
                                                is_drum=False,
                                                add_micro_metric=add_micro_metric,
                                                suffix=group_name + '_pc')
            # Add to instrument-wise PRF
            instr_metric.update(_instr_pc_metric)

        # Extract Pitches (freq) and Intervals
        pns_group = extract_pitches_intervals_from_notes(pns_group)
        rns_group = extract_pitches_intervals_from_notes(rns_group)

        # Compute instrument-wise PRF
        _instr_metric = mir_eval_note_f1(pns_group['pitches'],
                                         pns_group['intervals'],
                                         rns_group['pitches'],
                                         rns_group['intervals'],
                                         onset_tolerance=onset_tolerance,
                                         is_drum=False,
                                         add_micro_metric=add_micro_metric,
                                         suffix=group_name)

        # Merge instrument-wise PRF
        instr_metric.update(_instr_metric)

        # Optionally compute melody metrics: RPA, RCA, OA
        if group_name.lower() in melody_groups_lower:
            _melody_metric = mir_eval_melody_metric(pns_group['pitches'],
                                                    pns_group['intervals'],
                                                    rns_group['pitches'],
                                                    rns_group['intervals'],
                                                    cent_tolerance=50,
                                                    suffix=group_name)
            instr_metric.update(_melody_metric)

    # Calculate multi_f metric for this track: micro P/R entries carry
    # {'value', 'weight'} pairs; re-aggregate them into a single track-level F1.
    if add_multi_f_metric:
        drum_micro_onset_tp_sum, drum_micro_onset_tpfp_sum, drum_micro_onset_tpfn_sum = 0., 0., 0.
        non_drum_micro_offset_tp_sum, non_drum_micro_offset_tpfp_sum, non_drum_micro_offset_tpfn_sum = 0., 0., 0.
        # Collect offset metric for non-drum notes
        for k, v in instr_metric.items():
            if 'micro_offset_p_' in k and not np.isnan(v['value']):
                non_drum_micro_offset_tp_sum += v['value'] * v['weight']
                non_drum_micro_offset_tpfp_sum += v['weight']
            if 'micro_offset_r_' in k and not np.isnan(v['value']):
                non_drum_micro_offset_tpfn_sum += v['weight']
        # Collect onset metric for drum notes
        for k, v in drum_metric.items():
            if 'micro_onset_p_drum' in k and not np.isnan(v['value']):
                drum_micro_onset_tp_sum += v['value'] * v['weight']
                drum_micro_onset_tpfp_sum += v['weight']
            if 'micro_onset_r_drum' in k and not np.isnan(v['value']):
                drum_micro_onset_tpfn_sum += v['weight']

        tp = non_drum_micro_offset_tp_sum + drum_micro_onset_tp_sum
        tpfp = non_drum_micro_offset_tpfp_sum + drum_micro_onset_tpfp_sum
        tpfn = non_drum_micro_offset_tpfn_sum + drum_micro_onset_tpfn_sum
        multi_p_track = tp / tpfp if tpfp > 0 else np.nan
        multi_r_track = tp / tpfn if tpfn > 0 else np.nan
        multi_f_track = f1_measure(multi_p_track, multi_r_track)
        instr_metric['multi_f'] = multi_f_track

    return drum_metric, non_drum_metric, instr_metric