File size: 4,576 Bytes
a03c9b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import unittest
import pytest
from numpy import random
from utils.note_event_dataclasses import NoteEvent, Event, EventRange
from utils.event_codec import FastCodec as Codec
from utils.tokenizer import EventTokenizer
from utils.tokenizer import NoteEventTokenizer
class TestEventTokenizerBase(unittest.TestCase):
    """Round-trip and constructor-validation tests for ``EventTokenizer``."""

    def test_encode_and_decode(self):
        """Events encoded by a no-argument tokenizer decode back to the same list."""
        tokenizer = EventTokenizer()
        events = [
            Event('pitch', 64),
            Event('velocity', 1),
            Event('tie', 0),
            Event('program', 10),
            Event('drum', 0),
        ]
        tokens = tokenizer.encode(events)
        decoded_events = tokenizer.decode(tokens)
        self.assertEqual(events, decoded_events)

    def test_unknown_codec_name(self):
        """A codec name the tokenizer does not know is rejected with ``ValueError``."""
        with self.assertRaises(ValueError):
            EventTokenizer(base_codec='unknown')

    def test_unknown_codec_type(self):
        """A ``base_codec`` of an unsupported type (here an int) raises ``TypeError``."""
        with self.assertRaises(TypeError):
            EventTokenizer(base_codec=123)

    def test_encode_and_decode_with_custom_codec(self):
        """A user-built codec yields the expected token ids and round-trips."""
        special_tokens = ['PAD', 'EOS', 'SOS', 'T']
        max_shift_steps = 100
        event_ranges = [
            EventRange('eat', min_value=0, max_value=9),
            EventRange('sleep', min_value=0, max_value=9),
            EventRange('play', min_value=0, max_value=1),
        ]
        my_codec = Codec(special_tokens, max_shift_steps, event_ranges)
        tokenizer = EventTokenizer(my_codec)
        events = [
            Event('eat', 3),
            Event('shift', 9),
            Event('sleep', 9),
            Event('shift', 20),
            Event('play', 1),
        ]
        tokens = tokenizer.encode(events)
        # Token layout consistent with the expected ids asserted below
        # (each ranged event type spans max_value - min_value + 1 ids):
        #   0~3:     special tokens (4 entries)
        #   4~103:   shift tokens   (shift 9 -> 13, shift 20 -> 24)
        #   104~113: eat tokens     (eat 3 -> 107)
        #   114~123: sleep tokens   (sleep 9 -> 123)
        #   124~125: play tokens    (play 1 -> 125)
        expected_tokens = [107, 13, 123, 24, 125]
        self.assertEqual(tokens, expected_tokens)
        decoded_events = tokenizer.decode(tokens)
        self.assertEqual(events, decoded_events)
class TestEventTokenizerBaseProcessTime(unittest.TestCase):
    """Speed checks for ``EventTokenizer`` encode/decode.

    Each test repeats one call ``NUM_REPEATS`` times under a
    ``pytest.mark.timeout`` budget, so a performance regression shows up as
    a timeout failure instead of an assertion failure.
    """

    # How many back-to-back calls each timed test performs.
    NUM_REPEATS = 64
    # Fixed seed so every run decodes the same token sequence; an unseeded
    # draw would make the measured workload differ between runs.
    RANDOM_SEED = 0

    def setUp(self) -> None:
        """Build the tokenizer and the fixed encode/decode workloads."""
        self.tokenizer = EventTokenizer('mt3')
        # 333 integer token ids drawn from [0, 500).
        rng = random.RandomState(self.RANDOM_SEED)
        self.random_tokens = rng.randint(0, 500, size=333)
        # Six events repeated 55 times -> 330 events per encode call.
        self.events = [
            Event(type='pitch', value=60),
            Event(type='velocity', value=1),
            Event(type='program', value=0),
            Event(type='shift', value=10),
            Event(type='tie', value=0),
            Event(type='drum', value=0),
        ] * 55

    @pytest.mark.timeout(0.008)  # 32 ms --> 8 ms
    def test_event_tokenizer_encode(self):
        """Encoding the event list ``NUM_REPEATS`` times stays within the budget."""
        for _ in range(self.NUM_REPEATS):
            self.tokenizer.encode(self.events)

    @pytest.mark.timeout(0.01)  # 40 ms --> 10 ms
    def test_event_tokenizer_decode(self):
        """Decoding the token array ``NUM_REPEATS`` times stays within the budget."""
        for _ in range(self.NUM_REPEATS):
            self.tokenizer.decode(self.random_tokens)
# yapf: disable
class NoteEventTokenizerTest(unittest.TestCase):
    """Round-trip tests for ``NoteEventTokenizer`` via ``encode`` and ``encode_plus``."""

    @staticmethod
    def _sample_note_events():
        """Return the three-note fixture shared by both tests (one entry is a drum hit)."""
        return [
            NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()),
            NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity=set()),
            NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity=set())
        ]

    def _assert_clean_round_trip(self, tokenizer, tokens, expected_events):
        """Decode ``tokens`` and check the four returned values.

        The decoded note events must equal ``expected_events``, the tie-event
        list must be empty, and both the trailing activity and the error
        counter must have length zero.
        """
        result = tokenizer.decode(tokens)
        decoded_events, decoded_tie_events, last_activity, err_cnt = result
        self.assertSequenceEqual(expected_events, decoded_events)
        self.assertSequenceEqual([], decoded_tie_events)
        self.assertEqual(len(last_activity), 0)
        self.assertEqual(len(err_cnt), 0)

    def test_note_event_tokenizer_encode(self):
        """Tokens from ``encode`` decode back to the original note events."""
        tokenizer = NoteEventTokenizer()
        note_events = self._sample_note_events()
        self._assert_clean_round_trip(tokenizer, tokenizer.encode(note_events), note_events)

    def test_note_event_tokenizer_encode_plus(self):
        """Tokens from ``encode_plus(max_length=30)`` decode back to the original note events."""
        tokenizer = NoteEventTokenizer()
        note_events = self._sample_note_events()
        padded_tokens = tokenizer.encode_plus(note_events, max_length=30)
        self._assert_clean_round_trip(tokenizer, padded_tokens, note_events)
# yapf: enable
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()