YourMT3 / amt /src /tests /event_codec_test.py
mimbres's picture
.
a03c9b4
raw
history blame
5.89 kB
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""event_codec_test.py:
This file contains tests for the following classes:
• Event
• EventRange
• FastCodec (equivalent to the MT3 authors' Codec)
See tokenizer_test.py for the FastCodec performance benchmark
"""
import unittest
from utils.note_event_dataclasses import Event, EventRange
from utils.event_codec import FastCodec as Codec
# from utils.event_codec import Codec
class TestEvent(unittest.TestCase):
def test_Event(self):
e = Event(type='shift', value=0)
self.assertEqual(e.type, 'shift')
self.assertEqual(e.value, 0)
class TestEventRange(unittest.TestCase):
def test_EventRange(self):
er = EventRange('abc', min_value=0, max_value=500)
self.assertEqual(er.type, 'abc')
self.assertEqual(er.min_value, 0)
self.assertEqual(er.max_value, 500)
class TestEventCodec(unittest.TestCase):
def test_event_codec(self):
ec = Codec(
special_tokens=['asd'],
max_shift_steps=1001,
event_ranges=[
EventRange('pitch', min_value=0, max_value=127),
EventRange('velocity', min_value=0, max_value=1),
EventRange('tie', min_value=0, max_value=0),
EventRange('program', min_value=0, max_value=127),
EventRange('drum', min_value=0, max_value=127),
],
)
events = [
Event(type='shift', value=0), # actually not needed
Event(type='shift', value=1), # 10 ms shift
Event(type='shift', value=1000), # 10 s shift
Event(type='pitch', value=0), # lowest pitch 8.18 Hz
Event(type='pitch', value=60), # C4 or 261.63 Hz
Event(type='pitch', value=127), # highest pitch G9 or 12543.85 Hz
Event(type='velocity', value=0), # lowest velocity)
Event(type='velocity', value=1), # lowest velocity)
Event(type='tie', value=0), # tie
Event(type='program', value=0), # program
Event(type='program', value=127), # program
Event(type='drum', value=0), # drum
Event(type='drum', value=127), # drum
]
encoded = [ec.encode_event(e) for e in events]
decoded = [ec.decode_event_index(idx) for idx in encoded]
self.assertSequenceEqual(events, decoded)
class TestEventCodecErrorCases(unittest.TestCase):
def setUp(self):
self.event_ranges = [
EventRange("program", 0, 127),
EventRange("pitch", 0, 127),
EventRange("velocity", 0, 3),
EventRange("drum", 0, 127),
EventRange("tie", 0, 1),
]
self.ec = Codec([], 1000, self.event_ranges)
def test_encode_event_with_invalid_event_type(self):
with self.assertRaises(ValueError):
self.ec.encode_event(Event("unknown_event_type", 50))
def test_encode_event_with_invalid_event_value(self):
with self.assertRaises(ValueError):
self.ec.encode_event(Event("program", 200))
def test_event_type_range_with_invalid_event_type(self):
with self.assertRaises(ValueError):
self.ec.event_type_range("unknown_event_type")
def test_decode_event_index_with_invalid_index(self):
with self.assertRaises(ValueError):
self.ec.decode_event_index(1000000)
class TestEventCodecVocabulary(unittest.TestCase):
def test_encode_event_using_program_vocabulary(self):
prog_vocab = {"Piano": [0, 1, 2, 3, 4, 5, 6, 7], "xxx": [50, 30, 120]}
ec = Codec(special_tokens=['asd'],
max_shift_steps=1001,
event_ranges=[
EventRange('pitch', min_value=0, max_value=127),
EventRange('velocity', min_value=0, max_value=1),
EventRange('tie', min_value=0, max_value=0),
EventRange('program', min_value=0, max_value=127),
EventRange('drum', min_value=0, max_value=127),
],
program_vocabulary=prog_vocab)
events = [
Event(type='program', value=0), # 0 --> 0
Event(type='program', value=7), # 7 --> 0
Event(type='program', value=111), # 111 --> 111
Event(type='program', value=30), # 30 --> 50
]
encoded = [ec.encode_event(e) for e in events]
expected = [1133, 1133, 1244, 1183]
self.assertSequenceEqual(encoded, expected)
def test_encode_event_using_drum_vocabulary(self):
drum_vocab = {"Kick": [50, 51, 52], "Snare": [53, 54]}
ec = Codec(special_tokens=['asd'],
max_shift_steps=1001,
event_ranges=[
EventRange('pitch', min_value=0, max_value=127),
EventRange('velocity', min_value=0, max_value=1),
EventRange('tie', min_value=0, max_value=0),
EventRange('program', min_value=0, max_value=127),
EventRange('drum', min_value=0, max_value=127),
],
drum_vocabulary=drum_vocab)
events = [
Event(type='drum', value=50),
Event(type='drum', value=51),
Event(type='drum', value=53),
Event(type='drum', value=54),
]
encoded = [ec.encode_event(e) for e in events]
self.assertEqual(encoded[0], encoded[1])
self.assertEqual(encoded[2], encoded[3])
if __name__ == "__main__":
    # Allow running this module directly, e.g. ``python event_codec_test.py``.
    unittest.main()