File size: 2,642 Bytes
a03c9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
""" assert_fn.py """
import numpy as np


def assert_notes_almost_equal(actual_notes, predicted_notes, delta=5e-3):
    """
    Asserts that the given lists of Note instances are equal up to a small
    floating-point tolerance, similar to `assertAlmostEqual` of `unittest`.
    Tolerance is 5e-3 by default, which is 5 ms for 100 ticks-per-second.

    Onsets and offsets must differ by strictly less than `delta`. Pitch,
    `is_drum` and velocity must match exactly. Programs are compared only
    when both notes are non-drum notes.

    Args:
        actual_notes: Sequence of reference notes.
        predicted_notes: Sequence of notes to compare, in the same order.
        delta: Exclusive absolute tolerance applied to onset and offset.

    Raises:
        AssertionError: On the first mismatch; the message names the index
            of the offending note and the differing values.
    """
    assert len(actual_notes) == len(predicted_notes), (
        f"length mismatch: {len(actual_notes)} != {len(predicted_notes)}")
    for i, (actual, predicted) in enumerate(zip(actual_notes, predicted_notes)):
        assert abs(actual.onset - predicted.onset) < delta, (
            f"note {i}: onset {actual.onset} vs {predicted.onset}")
        assert abs(actual.offset - predicted.offset) < delta, (
            f"note {i}: offset {actual.offset} vs {predicted.offset}")
        assert actual.pitch == predicted.pitch, (
            f"note {i}: pitch {actual.pitch} vs {predicted.pitch}")
        assert actual.is_drum == predicted.is_drum, (
            f"note {i}: is_drum {actual.is_drum} vs {predicted.is_drum}")
        # Use truthiness, not `is False`: numpy booleans (np.False_) are not
        # the `False` singleton, so an identity test would silently skip the
        # program comparison for them.
        if not actual.is_drum and not predicted.is_drum:
            assert actual.program == predicted.program, (
                f"note {i}: program {actual.program} vs {predicted.program}")
        assert actual.velocity == predicted.velocity, (
            f"note {i}: velocity {actual.velocity} vs {predicted.velocity}")


def assert_note_events_almost_equal(actual_note_events,
                                    predicted_note_events,
                                    ignore_time=False,
                                    ignore_activity=True,
                                    delta=5.1e-3):
    """
    Asserts that the given lists of note events are equal up to a small
    floating-point tolerance, similar to `assertAlmostEqual` of `unittest`.

    Times must differ by at most `delta` (inclusive). Tolerance is 5.1e-3 by
    default, i.e. slightly above 5 ms (5 ms corresponds to 100
    ticks-per-second). The margin presumably absorbs floating-point error at
    exactly 5 ms; confirm before relying on it.

    `is_drum`, pitch and velocity must match exactly. Programs are compared
    only when both events are non-drum events.

    Args:
        actual_note_events: Sequence of reference note events.
        predicted_note_events: Sequence of note events to compare, in the
            same order.
        ignore_time: If truthy, the `time` field is not compared (useful for
            comparing tie note events). Default False.
        ignore_activity: If truthy, the `activity` field is not compared.
            Default True.
        delta: Inclusive absolute tolerance applied to `time`.

    Raises:
        AssertionError: On the first mismatch; the message names the index
            of the offending event and the differing values.
    """
    assert len(actual_note_events) == len(predicted_note_events), (
        f"length mismatch: {len(actual_note_events)} != {len(predicted_note_events)}")
    for i, (actual, predicted) in enumerate(zip(actual_note_events, predicted_note_events)):
        # Truthiness instead of `is False` so that 0 / np.False_ behave like False.
        if not ignore_time:
            assert abs(actual.time - predicted.time) <= delta, (
                f"event {i}: time {actual.time} vs {predicted.time}")
        assert actual.is_drum == predicted.is_drum, (
            f"event {i}: is_drum {actual.is_drum} vs {predicted.is_drum}")
        if not actual.is_drum and not predicted.is_drum:
            assert actual.program == predicted.program, (
                f"event {i}: program {actual.program} vs {predicted.program}")
        assert actual.pitch == predicted.pitch, (
            f"event {i}: pitch {actual.pitch} vs {predicted.pitch}")
        assert actual.velocity == predicted.velocity, (
            f"event {i}: velocity {actual.velocity} vs {predicted.velocity}")
        if not ignore_activity:
            assert actual.activity == predicted.activity, (
                f"event {i}: activity {actual.activity} vs {predicted.activity}")


def assert_track_metrics_score1(metrics) -> None:
    """
    Asserts that every non-NaN value in `metrics` equals exactly 1.0.

    NaN entries are skipped. Values are assumed to be scalars accepted by
    `np.isnan` (TODO confirm with callers).

    Args:
        metrics: Mapping from metric name to numeric value.

    Raises:
        AssertionError: If any non-NaN value differs from 1.0; the message
            names the offending metric.
    """
    for name, value in metrics.items():
        # `np.isnan` returns `np.bool_`, which is never the `False` singleton,
        # so an `is False` identity test would always be false and nothing
        # would ever be checked. Use truthiness instead.
        if not np.isnan(value):
            assert value == 1.0, f"metric {name!r} = {value}, expected 1.0"