File size: 3,875 Bytes
a03c9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Pitch-shift check: for the first 1000 samples, hand every segment to
# pitch_shift_note_events with a shift of +2 and compare the pitches against
# an untouched deep copy of the same segments.
for n in range(1000):
    sample = ds[n]

    # Two independent copies: `reference` is never modified here, `shifted`
    # is the one passed to pitch_shift_note_events.
    reference = deepcopy(sample['note_event_segments'])
    shifted = deepcopy(sample['note_event_segments'])

    # NOTE(review): the return values are not written back into `shifted`, so
    # the comparison below only holds if pitch_shift_note_events modifies its
    # argument in place -- confirm against its implementation.
    for seg_notes, seg_ties, _start in zip(*shifted.values()):
        pitch_shift_note_events(seg_notes, 2)
        pitch_shift_note_events(seg_ties, 2)

    # Compare each segment of `shifted` with the same segment of `reference`.
    for i, (seg_notes, seg_ties, _start) in enumerate(zip(*shifted.values())):
        ref_notes = reference['note_events'][i]
        for j, ne in enumerate(seg_notes):
            # Only events whose is_drum is exactly False are checked.
            if ne.is_drum is not False:
                continue
            expected_pitch = ref_notes[j].pitch + 2
            if ne.pitch != expected_pitch:
                # Report the (segment, event) position before failing.
                print(i, j)
            assert ne.pitch == expected_pitch

        # Tie note events are compared without any is_drum filter.
        ref_ties = reference['tie_note_events'][i]
        for k, tne in enumerate(seg_ties):
            assert tne.pitch == ref_ties[k].pitch + 2

    print('test {} passed'.format(n))


def assert_note_events_almost_equal(actual_note_events,
                                    predicted_note_events,
                                    ignore_time=False,
                                    ignore_activity=True,
                                    delta=5.1e-3):
    """
    Assert that two lists of note events are equal up to a small tolerance on
    the time field, similar to `assertAlmostEqual` of `unittest`.

    The two lists must have the same length and are compared pairwise, in
    order. For each pair the `is_drum`, `pitch` and `velocity` fields must be
    equal; `program` is compared only for non-drum events.

    Args:
        actual_note_events: sequence of note events (objects exposing `time`,
            `is_drum`, `program`, `pitch`, `velocity` and, when
            `ignore_activity` is false, `activity`).
        predicted_note_events: sequence of note events to compare against.
        ignore_time: if true, the `time` field is not compared (useful for
            comparing tie note events). Default is False.
        ignore_activity: if true, the `activity` field is not compared.
            Default is True.
        delta: maximum allowed absolute difference between `time` values.
            Default is 5.1e-3, i.e. slightly more than 5 ms (one half tick at
            100 ticks-per-second).

    Raises:
        AssertionError: if the lengths differ or any compared field of any
            pair of events does not match.
    """
    n_actual = len(actual_note_events)
    n_predicted = len(predicted_note_events)
    assert n_actual == n_predicted, \
        'number of note events differs: {} != {}'.format(n_actual, n_predicted)

    for j, (actual, predicted) in enumerate(zip(actual_note_events, predicted_note_events)):
        # Flags are tested by truthiness rather than `is False`, so that
        # falsy non-bool values (e.g. numpy.bool_) do not silently disable
        # a check.
        if not ignore_time:
            assert abs(actual.time - predicted.time) <= delta, \
                'note event #{}: time {} != {} (delta={})'.format(
                    j, actual.time, predicted.time, delta)

        assert actual.is_drum == predicted.is_drum, \
            'note event #{}: is_drum {} != {}'.format(j, actual.is_drum, predicted.is_drum)

        # is_drum is already known to match on both sides; program is only
        # compared for non-drum events.
        if not actual.is_drum:
            assert actual.program == predicted.program, \
                'note event #{}: program {} != {}'.format(j, actual.program, predicted.program)

        assert actual.pitch == predicted.pitch, \
            'note event #{}: pitch {} != {}'.format(j, actual.pitch, predicted.pitch)
        assert actual.velocity == predicted.velocity, \
            'note event #{}: velocity {} != {}'.format(j, actual.velocity, predicted.velocity)

        if not ignore_activity:
            assert actual.activity == predicted.activity, \
                'note event #{}: activity {} != {}'.format(
                    j, actual.activity, predicted.activity)


# Cache consistency check: after each of the first 500 item fetches, every
# key of ds.cache that was also present in the previous snapshot must still
# hold the same data.
prev_cache = deepcopy(dict(ds.cache))
for n in range(500):
    # The returned sample is not used; only ds.cache is inspected afterwards.
    _ = ds[n]
    current_cache = ds.cache
    n_shared = 0  # keys present in both the previous snapshot and the cache

    for key, new_entry in current_cache.items():
        if key not in prev_cache:
            continue
        n_shared += 1
        old_entry = prev_cache[key]

        # Element-wise `==` followed by .all() for programs / is_drum /
        # audio_array; plain equality for the two has_* flags.
        assert (new_entry['programs'] == old_entry['programs']).all()
        assert (new_entry['is_drum'] == old_entry['is_drum']).all()
        assert (new_entry['has_stems'] == old_entry['has_stems'])
        assert (new_entry['has_unannotated'] == old_entry['has_unannotated'])
        assert (new_entry['audio_array'] == old_entry['audio_array']).all()

        new_segments = new_entry['note_event_segments']
        old_segments = old_entry['note_event_segments']

        # Note events: compared per segment, including time (with tolerance).
        for nes_new, nes_old in zip(new_segments['note_events'],
                                    old_segments['note_events']):
            assert_note_events_almost_equal(nes_new, nes_old)

        # Tie note events: compared per segment with the time field ignored.
        for tnes_new, tnes_old in zip(new_segments['tie_note_events'],
                                      old_segments['tie_note_events']):
            assert_note_events_almost_equal(tnes_new, tnes_old, ignore_time=True)

        # Segment start times must match exactly.
        for s_new, s_old in zip(new_segments['start_times'],
                                old_segments['start_times']):
            assert s_new == s_old

    # Take a fresh snapshot for the next iteration's comparison.
    prev_cache = deepcopy(dict(ds.cache))
    print(n, n_shared)