|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" tokenizer.py: Encodes and decodes events to/from tokens. """ |
|
import numpy as np |
|
import warnings |
|
from abc import ABC, abstractmethod |
|
from utils.note_event_dataclasses import Event, EventRange, Note |
|
from utils.event_codec import FastCodec as Codec |
|
from utils.note_event_dataclasses import NoteEvent |
|
from utils.note2event import note_event2event |
|
from utils.event2note import event2note_event, note_event2note |
|
from typing import List, Optional, Union, Tuple, Dict, Counter |
|
|
|
|
|
|
|
class EventTokenizerBase(ABC):
    """
    A base class for encoding and decoding events to and from tokens.
    """

    def __init__(
        self,
        base_codec: Union[Codec, str] = 'mt3',
        special_tokens: Optional[List[str]] = None,
        extra_tokens: Optional[List[str]] = None,
        max_shift_steps: int = 206,
        program_vocabulary: Optional[Dict] = None,
        drum_vocabulary: Optional[Dict] = None,
    ) -> None:
        """
        Initializes the EventTokenizerBase object.

        :param base_codec: The codec instance to use for encoding and decoding,
            or the name of a predefined codec ('mt3').
        :param special_tokens: None or list of special tokens to include in the
            vocabulary. None defaults to ['PAD', 'EOS', 'UNK'].
        :param extra_tokens: None or list of tokens to be treated as additional
            special tokens. None defaults to [].
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :raises ValueError: If `base_codec` is an unknown codec name string.
        :raises TypeError: If `base_codec` is neither a string nor a Codec.
        """
        # Resolve None to fresh per-call values instead of using mutable
        # default arguments (which would be shared across all instances).
        if special_tokens is None:
            special_tokens = ['PAD', 'EOS', 'UNK']
        if extra_tokens is None:
            extra_tokens = []

        if isinstance(base_codec, str):
            if base_codec.lower() == 'mt3':
                # MT3-style event vocabulary definition.
                event_ranges = [
                    EventRange('pitch', min_value=0, max_value=127),
                    EventRange('velocity', min_value=0, max_value=1),
                    EventRange('tie', min_value=0, max_value=0),
                    EventRange('program', min_value=0, max_value=127),
                    EventRange('drum', min_value=0, max_value=127),
                ]
            else:
                raise ValueError(f'Unknown codec name: {base_codec}')

            self.codec = Codec(special_tokens=special_tokens + extra_tokens,
                               max_shift_steps=max_shift_steps,
                               event_ranges=event_ranges,
                               program_vocabulary=program_vocabulary,
                               drum_vocabulary=drum_vocabulary,
                               name='mt3')
        elif isinstance(base_codec, Codec):
            # A prebuilt codec is used as-is; vocabularies cannot be re-applied.
            self.codec = base_codec
            if program_vocabulary is not None or drum_vocabulary is not None:
                warnings.warn("Vocabulary cannot be applied when using a custom codec.")
        else:
            raise TypeError(f'Unknown codec type: {type(base_codec)}')
        self.num_tokens = self.codec._num_classes

    def _encode(self, events: List[Event]) -> List[int]:
        """Map each event to its integer token index via the codec."""
        return [self.codec.encode_event(e) for e in events]

    def _decode(self, tokens: List[int]) -> List[Event]:
        """Map each integer token index back to an event via the codec."""
        return [self.codec.decode_event_index(idx) for idx in tokens]

    @abstractmethod
    def encode(self):
        """ Encode your custom events to tokens. """
        pass

    @abstractmethod
    def decode(self):
        """ Decode your custom tokens to events."""
        pass
|
|
|
|
|
class EventTokenizer(EventTokenizerBase):
    """
    Encoding and decoding events to and from tokens.
    """

    def __init__(self,
                 base_codec: Union[Codec, str] = 'mt3',
                 special_tokens: Optional[List[str]] = None,
                 extra_tokens: Optional[List[str]] = None,
                 max_shift_steps: int = 206,
                 program_vocabulary: Optional[Dict] = None,
                 drum_vocabulary: Optional[Dict] = None) -> None:
        """
        Initializes the EventTokenizer object.

        :param base_codec: The codec to use for encoding and decoding, or the
            name of a predefined codec ('mt3').
        :param special_tokens: None or list of special tokens to include in the
            vocabulary. None defaults to ['PAD', 'EOS', 'UNK'].
        :param extra_tokens: None or list of tokens to be treated as additional
            special tokens. None defaults to [].
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        """
        # Resolve None defaults here (rather than via mutable default args) so
        # the parent always receives concrete, per-instance lists.
        if special_tokens is None:
            special_tokens = ['PAD', 'EOS', 'UNK']
        if extra_tokens is None:
            extra_tokens = []
        super().__init__(
            base_codec=base_codec,
            special_tokens=special_tokens,
            extra_tokens=extra_tokens,
            max_shift_steps=max_shift_steps,
            program_vocabulary=program_vocabulary,
            drum_vocabulary=drum_vocabulary,
        )

    def encode(self, events):
        """ Encode your custom events to tokens. """
        return super()._encode(events)

    def decode(self, tokens):
        """ Decode your custom tokens to events."""
        return super()._decode(tokens)
|
|
|
|
|
class NoteEventTokenizer(EventTokenizerBase):
    """ Encodes and decodes note events to/from tokens. """

    def __init__(
            self,
            base_codec: Union[Codec, str] = 'mt3',
            max_length: int = 1024,
            tps: int = 100,
            sort_note_event: bool = True,
            special_tokens: Optional[List[str]] = None,
            extra_tokens: Optional[List[str]] = None,
            max_shift_steps: int = 206,
            program_vocabulary: Optional[Dict] = None,
            drum_vocabulary: Optional[Dict] = None,
            ignore_decoding_tokens: Optional[List[str]] = None,
            ignore_decoding_tokens_from_and_to: Optional[List[str]] = None,
            debug_mode: bool = False) -> None:
        """
        Initializes the NoteEventTokenizer object.

        List[NoteEvent] -> encode -> List[int]

        np.ndarray[int] -> decode -> Tuple[List[NoteEvent], List[NoteEvent]]

        :param base_codec: The codec to use for encoding and decoding, or the
            name of a predefined codec ('mt3').
        :param max_length: Default sequence length used for truncation/padding
            in `encode_plus` when no explicit max_length is given.
        :param tps: Time steps per second used when converting note-event times
            to shift events.
        :param sort_note_event: Whether to sort note events before encoding.
        :param special_tokens: None or list of special tokens to include in the
            vocabulary. None defaults to ['PAD', 'EOS', 'UNK'].
        :param extra_tokens: None or list of tokens to be treated as additional
            special tokens. None defaults to [].
        :param max_shift_steps: The maximum number of shift steps to use for the codec.
        :param program_vocabulary: None or a dictionary mapping program names to program indices.
        :param drum_vocabulary: None or a dictionary mapping drum names to drum indices.
        :param ignore_decoding_tokens: List of special-token names whose ids are
            dropped from the input before decoding. None defaults to [].
        :param ignore_decoding_tokens_from_and_to: List of tokens to ignore
            during decoding, as a [from, to] pair. Stored but not consumed by
            the methods visible here.
        :param debug_mode: If True, `decode` prints the tokens it is about to ignore.
        :raises ValueError: If 'EOS'/'PAD' or any `ignore_decoding_tokens` entry
            is missing from the codec's special tokens (via list.index).
        """
        # Resolve None to fresh per-call values instead of mutable default args.
        if special_tokens is None:
            special_tokens = ['PAD', 'EOS', 'UNK']
        if extra_tokens is None:
            extra_tokens = []
        if ignore_decoding_tokens is None:
            ignore_decoding_tokens = []
        super().__init__(base_codec=base_codec,
                         special_tokens=special_tokens,
                         extra_tokens=extra_tokens,
                         max_shift_steps=max_shift_steps,
                         program_vocabulary=program_vocabulary,
                         drum_vocabulary=drum_vocabulary)
        self.max_length = max_length
        self.tps = tps
        self.sort = sort_note_event

        # Special-token id sequences added around encoded events, plus
        # lazily-resized PAD-id caches used for right-padding.
        self._prefix: List[int] = []
        self._suffix: List[int] = []
        self._zero_pad: List[int] = []  # pad cache for encode_plus
        # Bug fix: `_zero_pad_task` was previously never initialized, so the
        # first call to `encode_task` raised AttributeError.
        self._zero_pad_task: List[int] = []  # pad cache for encode_task
        for stk in self.codec.special_tokens:
            if stk == 'EOS':
                self._suffix.append(self.codec.special_tokens.index('EOS'))
            elif stk == 'PAD':
                # Bug fix: seed with the actual PAD id rather than a hard-coded
                # 0, so padding stays correct even if 'PAD' is not the first
                # special token. (With the default ordering, PAD id is 0 and
                # behavior is identical.)
                self._zero_pad = [self.codec.special_tokens.index('PAD')] * 1024
            elif stk == 'UNK':
                pass
            else:
                pass

        self.eos_id = self.codec.special_tokens.index('EOS')
        self.pad_id = self.codec.special_tokens.index('PAD')
        self.ids_to_ignore_decoding = [self.codec.special_tokens.index(t) for t in ignore_decoding_tokens]
        self.ignore_tokens_from_and_to = ignore_decoding_tokens_from_and_to
        self.debug_mode = debug_mode

    def _decode(self, tokens):
        """Delegate raw token-to-event decoding to the base class."""
        return super()._decode(tokens)

    def encode(
        self,
        note_events: List[NoteEvent],
        tie_note_events: Optional[List[NoteEvent]] = None,
        start_time: float = 0.,
    ) -> List[int]:
        """ Encodes note events and tie note events to tokens. """
        events = note_event2event(
            note_events=note_events,
            tie_note_events=tie_note_events,
            start_time=start_time,
            tps=self.tps,
            sort=self.sort)
        return super()._encode(events)

    def encode_plus(
            self,
            note_events: List[NoteEvent],
            tie_note_events: Optional[List[NoteEvent]] = None,
            start_times: float = 0.,
            add_special_tokens: Optional[bool] = True,
            max_length: Optional[int] = None,
            pad_to_max_length: Optional[bool] = True,
            return_attention_mask: bool = False) -> Union[List[int], Tuple[List[int], List[int]]]:
        """ Encodes note events and tie note info to padded tokens.

        :param note_events: Note events to encode.
        :param tie_note_events: Optional tie note events.
        :param start_times: Segment start time passed to `encode`.
        :param add_special_tokens: If True, wrap with the prefix/suffix ids
            (EOS) built in __init__.
        :param max_length: Truncation/padding length; None uses self.max_length.
        :param pad_to_max_length: If True, right-pad with PAD ids to max_length.
        :param return_attention_mask: If True, also return a 1/0 mask marking
            real tokens vs padding.
        """
        encoded = self.encode(note_events, tie_note_events, start_times)

        if add_special_tokens:
            if self._prefix:
                encoded = self._prefix + encoded
            if self._suffix:
                encoded = encoded + self._suffix

        if max_length is None:
            max_length = self.max_length

        # Truncate to max_length before (optionally) padding.
        length = len(encoded)
        if length >= max_length:
            encoded = encoded[:max_length]
            length = max_length

        if return_attention_mask:
            attention_mask = [1] * length

        if pad_to_max_length is True:
            # Resize the shared pad cache only when max_length changes.
            if len(self._zero_pad) != max_length:
                self._zero_pad = [self.pad_id] * max_length
            if return_attention_mask:
                # Bug fix: mask positions for padding must be 0, not pad_id
                # (identical behavior when pad_id == 0, as with the defaults).
                attention_mask += [0] * (max_length - length)
            encoded = encoded + self._zero_pad[length:]

        if return_attention_mask:
            return encoded, attention_mask

        return encoded

    def encode_task(self, task_events: List[Event], max_length: Optional[int] = None) -> List[int]:
        """Encodes task (prompt) events to tokens, optionally right-padded.

        :param task_events: Events to encode.
        :param max_length: If given, pad the result with PAD ids to this length.
        """
        encoded = super()._encode(task_events)

        if max_length is not None:
            if len(self._zero_pad_task) != max_length:
                self._zero_pad_task = [self.pad_id] * max_length
            length = len(encoded)
            # Bug fix: pad from the task cache; previously this read the
            # unrelated `_zero_pad` cache (and `_zero_pad_task` was never
            # initialized, so the length check above raised AttributeError).
            encoded = encoded + self._zero_pad_task[length:]

        return encoded

    def decode(
        self,
        tokens: List[int],
        start_time: float = 0.,
        return_events: bool = False,
    ) -> Union[Tuple[List[NoteEvent], List[NoteEvent]], Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]],
                                                              List[Event], int]]:
        """Decodes a sequence of tokens into note events.

        Args:
            tokens (List[int]): The list of tokens to be decoded.
            start_time (float, optional): The starting time for the note events. Defaults to 0.
            return_events (bool, optional): Indicates whether to include the raw events in the return value.
                Defaults to False.

        Returns:
            Union[Tuple[List[NoteEvent], List[NoteEvent]],
                Tuple[List[NoteEvent], List[NoteEvent], List[Event], int]]: The decoded note events.
                If `return_events` is False, the returned tuple contains `note_events`, `tie_note_events`,
                `last_activity`, and `err_cnt`.
                If `return_events` is True, the returned tuple contains `note_events`, `tie_note_events`,
                `last_activity`, `events`, and `err_cnt`.
        """
        if self.debug_mode:
            ignored_tokens_from_input = [t for t in tokens if t in self.ids_to_ignore_decoding]
            print(ignored_tokens_from_input)

        # Drop ids the caller asked to ignore before decoding.
        if self.ids_to_ignore_decoding:
            tokens = [t for t in tokens if t not in self.ids_to_ignore_decoding]

        events = super()._decode(tokens)
        note_events, tie_note_events, last_activity, err_cnt = event2note_event(events, start_time, True, self.tps)
        if return_events:
            return note_events, tie_note_events, last_activity, events, err_cnt
        else:
            return note_events, tie_note_events, last_activity, err_cnt

    def decode_batch(
        self,
        batch_tokens: Union[List[List[int]], np.ndarray],
        start_times: List[float],
        return_events: bool = False
    ) -> Union[Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]], int],
               Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]], List[List[Event]],
                     Counter[str]]]:
        """
        Decodes a batch of tokens to note_events and tie_note_events.

        Args:
            batch_tokens (List[List[int]] or np.ndarray): Tokens to be decoded.
            start_times (List[float]): List of start times for each token set.
            return_events (bool, optional): Flag to determine if events should be returned. Defaults to False.

        Raises:
            ValueError: If batch_tokens and start_times differ in length.
        """
        if isinstance(batch_tokens, np.ndarray):
            batch_tokens = batch_tokens.tolist()

        if len(batch_tokens) != len(start_times):
            raise ValueError('The length of batch_tokens and start_times must be same.')

        zipped_note_events_and_tie = []
        list_events = []
        total_err_cnt = 0

        for tokens, start_time in zip(batch_tokens, start_times):
            if return_events:
                note_events, tie_note_events, last_activity, events, err_cnt = self.decode(
                    tokens, start_time, return_events)
                list_events.append(events)
            else:
                note_events, tie_note_events, last_activity, err_cnt = self.decode(tokens, start_time, return_events)

            zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time))
            total_err_cnt += err_cnt

        if return_events:
            return zipped_note_events_and_tie, list_events, total_err_cnt
        else:
            return zipped_note_events_and_tie, total_err_cnt

    def decode_list_batches(
        self,
        list_batch_tokens: Union[List[List[List[int]]], List[np.ndarray]],
        list_start_times: Union[List[List[float]], List[float]],
        return_events: bool = False
    ) -> Union[Tuple[List[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]]], Counter[str]],
               Tuple[List[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]]],
                     List[List[Event]], Counter[str]]]:
        """
        Decodes a list of variable-size batches of token array to a list of
        zipped note_events and tie_note_events.

        Args:
            list_batch_tokens: List[np.ndarray], where array shape is (batch_size, variable_length)
            list_start_times: List[float], where the length is sum of all batch_sizes.
            return_events: bool, Defaults to False.

        Returns:
            list_list_zipped_note_events_and_tie:
                List[
                    Tuple[
                        List[NoteEvent]: A list of note events.
                        List[NoteEvent]: A list of tie note events.
                        List[Tuple[int]]: A list of last activity of segment. [(program, pitch), ...]. This is useful
                            for validating notes within a batch of segments extracted from a file.
                        List[float]: A list of segment start times.
                    ]
                ]
            (Optional) list_events:
                List[List[Event]]
            total_err_cnt:
                Counter[str]: error counter.
        """
        # Flatten the batches into one list of token sequences.
        list_tokens = []
        for arr in list_batch_tokens:
            for tokens in arr:
                list_tokens.append(tokens)
        assert (len(list_tokens) == len(list_start_times))

        zipped_note_events_and_tie = []
        list_events = []
        # NOTE(review): this accumulates err_cnt into a Counter while
        # decode_batch accumulates into an int 0 — confirm the actual type
        # returned by event2note_event; one of the two is likely wrong.
        total_err_cnt = Counter()
        for tokens, start_time in zip(list_tokens, list_start_times):
            if return_events:
                note_events, tie_note_events, last_activity, events, err_cnt = self.decode(
                    tokens, start_time, return_events)
                list_events.append(events)
            else:
                # Bug fix: decode() returns 4 values when return_events is
                # False; the previous code unconditionally unpacked 5 and
                # raised ValueError on that path.
                note_events, tie_note_events, last_activity, err_cnt = self.decode(tokens, start_time, return_events)
            zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time))
            total_err_cnt += err_cnt

        if return_events:
            return zipped_note_events_and_tie, list_events, total_err_cnt
        else:
            return zipped_note_events_and_tie, total_err_cnt
|
|