|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""event_codec_test.py: |
|
|
|
This file contains tests for the following classes: |
|
• Event |
|
• EventRange |
|
• FastCodec equivalent to MT3 author's Codec |
|
|
|
See tokenizer_test.py for the FastCodec performance benchmark |
|
|
|
""" |
|
import unittest |
|
from utils.note_event_dataclasses import Event, EventRange |
|
from utils.event_codec import FastCodec as Codec |
|
|
|
|
|
|
|
class TestEvent(unittest.TestCase):
    """Checks that an Event exposes the type and value it was built with."""

    def test_Event(self):
        shift_event = Event(type='shift', value=0)
        self.assertEqual(shift_event.type, 'shift')
        self.assertEqual(shift_event.value, 0)
|
|
|
|
|
class TestEventRange(unittest.TestCase):
    """Checks that an EventRange stores its type and inclusive bounds."""

    def test_EventRange(self):
        event_range = EventRange('abc', min_value=0, max_value=500)
        expectations = (
            ('type', 'abc'),
            ('min_value', 0),
            ('max_value', 500),
        )
        for attr_name, expected in expectations:
            self.assertEqual(getattr(event_range, attr_name), expected)
|
|
|
|
|
class TestEventCodec(unittest.TestCase):
    """Round-trip tests for Codec.encode_event / Codec.decode_event_index."""

    def test_event_codec(self):
        codec = Codec(
            special_tokens=['asd'],
            max_shift_steps=1001,
            event_ranges=[
                EventRange('pitch', min_value=0, max_value=127),
                EventRange('velocity', min_value=0, max_value=1),
                EventRange('tie', min_value=0, max_value=0),
                EventRange('program', min_value=0, max_value=127),
                EventRange('drum', min_value=0, max_value=127),
            ],
        )

        # Representative values for each event type, including the declared
        # min/max of every EventRange above.
        cases = [
            ('shift', 0), ('shift', 1), ('shift', 1000),
            ('pitch', 0), ('pitch', 60), ('pitch', 127),
            ('velocity', 0), ('velocity', 1),
            ('tie', 0),
            ('program', 0), ('program', 127),
            ('drum', 0), ('drum', 127),
        ]
        originals = [Event(type=kind, value=val) for kind, val in cases]

        # Encode everything first, then decode, so each event must survive
        # the full encode -> decode round trip unchanged.
        indices = [codec.encode_event(event) for event in originals]
        round_tripped = list(map(codec.decode_event_index, indices))

        self.assertSequenceEqual(originals, round_tripped)
|
|
|
|
|
class TestEventCodecErrorCases(unittest.TestCase):
    """Checks that Codec raises ValueError for inputs outside its vocabulary."""

    def setUp(self):
        self.event_ranges = [
            EventRange("program", min_value=0, max_value=127),
            EventRange("pitch", min_value=0, max_value=127),
            EventRange("velocity", min_value=0, max_value=3),
            EventRange("drum", min_value=0, max_value=127),
            EventRange("tie", min_value=0, max_value=1),
        ]
        self.ec = Codec(
            special_tokens=[],
            max_shift_steps=1000,
            event_ranges=self.event_ranges,
        )

    def test_encode_event_with_invalid_event_type(self):
        # An event type that was never registered as an EventRange.
        with self.assertRaises(ValueError):
            self.ec.encode_event(Event("unknown_event_type", 50))

    def test_encode_event_with_invalid_event_value(self):
        # 200 lies above the declared program range of 0..127.
        with self.assertRaises(ValueError):
            self.ec.encode_event(Event("program", 200))

    def test_event_type_range_with_invalid_event_type(self):
        with self.assertRaises(ValueError):
            self.ec.event_type_range("unknown_event_type")

    def test_decode_event_index_with_invalid_index(self):
        # An index far beyond anything this small vocabulary can produce.
        with self.assertRaises(ValueError):
            self.ec.decode_event_index(1000000)
|
|
|
|
|
class TestEventCodecVocabulary(unittest.TestCase):
    """Tests that program/drum vocabularies collapse grouped values onto one token."""

    @staticmethod
    def _make_codec(**vocabularies):
        """Build the shared test Codec, forwarding vocabulary kwargs.

        Args:
            **vocabularies: e.g. ``program_vocabulary=...`` or
                ``drum_vocabulary=...``, passed straight to Codec.
        """
        return Codec(
            special_tokens=['asd'],
            max_shift_steps=1001,
            event_ranges=[
                EventRange('pitch', min_value=0, max_value=127),
                EventRange('velocity', min_value=0, max_value=1),
                EventRange('tie', min_value=0, max_value=0),
                EventRange('program', min_value=0, max_value=127),
                EventRange('drum', min_value=0, max_value=127),
            ],
            **vocabularies,
        )

    def test_encode_event_using_program_vocabulary(self):
        prog_vocab = {"Piano": [0, 1, 2, 3, 4, 5, 6, 7], "xxx": [50, 30, 120]}
        ec = self._make_codec(program_vocabulary=prog_vocab)

        events = [
            Event(type='program', value=0),
            Event(type='program', value=7),
            Event(type='program', value=111),
            Event(type='program', value=30),
        ]
        encoded = [ec.encode_event(e) for e in events]

        # Relative to program 0 encoding to 1133, the expected indices say:
        #   7   -> 1133        (same "Piano" group as 0)
        #   111 -> 1133 + 111  (not in any group, encoded as itself)
        #   30  -> 1133 + 50   (member of "xxx", whose first entry is 50)
        expected = [1133, 1133, 1244, 1183]
        self.assertSequenceEqual(encoded, expected)

    def test_encode_event_using_drum_vocabulary(self):
        drum_vocab = {"Kick": [50, 51, 52], "Snare": [53, 54]}
        ec = self._make_codec(drum_vocabulary=drum_vocab)

        kick = [ec.encode_event(Event(type='drum', value=v))
                for v in drum_vocab["Kick"]]
        snare = [ec.encode_event(Event(type='drum', value=v))
                 for v in drum_vocab["Snare"]]

        # Every member of a group must share a single token...
        self.assertEqual(len(set(kick)), 1)
        self.assertEqual(len(set(snare)), 1)
        # ...but different groups must not be merged with each other.
        self.assertNotEqual(kick[0], snare[0])
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|
|