import numpy as np
from typing import Optional, Union, Tuple, Dict, Any, List, Counter

from utils.note_event_dataclasses import NoteEvent, Event, NoteEventListsBundle
from config.task import task_cfg
from config.config import model_cfg
from utils.tokenizer import NoteEventTokenizer
from utils.utils import create_program2channel_vocab
from utils.note2event import separate_channel_by_program_group_from_note_event_lists_bundle
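
# Special program IDs used by the task definitions in config/task.py:
# 100 marks singing voice here, while 128 (drums) and 129 (unannotated) lie
# just beyond the 0-127 General MIDI program range.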
SINGING_PROGRAM = 100
DRUM_PROGRAM = 128
UNANNOTATED_PROGRAM = 129


class TaskManager:
    """
    The TaskManager class manages tasks for training. It is initialized with a task name and retrieves
    the corresponding configuration from the task_cfg dictionary defined in config/task.py.

    Attributes:
        # Basic
        task_name (str): The name of the task being managed.
        base_codec (str): The base codec associated with the task.
        train_program_vocab (dict): The program vocabulary used for training.
        train_drum_vocab (dict): The drum vocabulary used for training.
        subtask_tokens (list): Additional tokens specific to subtasks, if any.
        extra_tokens (list): Extra tokens used in the task, including subtask tokens.
        ignore_decoding_tokens (list): Tokens to ignore during decoding.
        ignore_decoding_tokens_from_and_to (Optional[List[str]]): A [start, end] delimiter pair; tokens
            between the delimiters are ignored during decoding. Default is None.
        tokenizer (NoteEventTokenizer): An instance of the NoteEventTokenizer class for tokenizing note events.
        eval_subtask_prefix_dict (dict): A dictionary mapping evaluation subtask names to encoded prefix tokens.

        # Multi-channel decoding task exclusive
        num_decoding_channels (int): The number of decoding channels.
        max_note_token_length_per_ch (int): The maximum note token length per channel.
        mask_loss_strategy (str): The mask loss strategy to use. NOT IMPLEMENTED YET.
        program2channel_vocab (dict): A dictionary mapping programs to channels.

    Methods:
        get_tokenizer(): Returns the tokenizer instance associated with the task.
        set_tokenizer(): Initializes the tokenizer using the NoteEventTokenizer class with the appropriate parameters.
    """

    def __init__(self, task_name: str = "mt3_full_plus", max_shift_steps: int = 206, debug_mode: bool = False):
        """
        Initializes a TaskManager object with the specified task name.

        Args:
            task_name (str): The name of the task to manage.
            max_shift_steps (int): The maximum shift steps for the tokenizer. Default is 206. Definable in config/config.py.
            debug_mode (bool): Whether to enable debug mode. Default is False.
        """
        self.debug_mode = debug_mode
        self.task_name = task_name

        if task_name not in task_cfg:
            raise ValueError(f"Invalid task name: {task_name}")
        self.task = task_cfg[task_name]

        # Basic task attributes
        self.base_codec = self.task.get("base_codec", "mt3")
        self.train_program_vocab = self.task["train_program_vocab"]
        self.train_drum_vocab = self.task["train_drum_vocab"]
        self.subtask_tokens = self.task.get("subtask_tokens", [])
        self.extra_tokens = self.subtask_tokens + self.task.get("extra_tokens", [])
        self.ignore_decoding_tokens = self.task.get("ignore_decoding_tokens", [])
        self.ignore_decoding_tokens_from_and_to = self.task.get("ignore_decoding_tokens_from_and_to", None)
        self.max_note_token_length = self.task.get("max_note_token_length", model_cfg["event_length"])
        self.max_task_token_length = self.task.get("max_task_token_length", 0)
        self.padding_task_token = self.task.get("padding_task_token", False)
        self._eval_subtask_prefix = self.task.get("eval_subtask_prefix", None)
        self.eval_subtask_prefix_dict = {}

        # Multi-channel decoding attributes
        self.num_decoding_channels = self.task.get("num_decoding_channels", 1)
        if self.num_decoding_channels > 1:
            program2channel_vocab_source = self.task.get("program2channel_vocab_source", None)
            if program2channel_vocab_source is None:
                program2channel_vocab_source = self.train_program_vocab

            # Expect one decoding channel per program group, plus one extra channel.
            if self.num_decoding_channels == len(program2channel_vocab_source) + 1:
                self.program2channel_vocab, _ = create_program2channel_vocab(program2channel_vocab_source)
            else:
                raise ValueError("Invalid num_decoding_channels, or program2channel_vocab not provided")

            self.max_note_token_length_per_ch = self.task.get("max_note_token_length_per_ch")
            self.mask_loss_strategy = self.task.get("mask_loss_strategy", None)  # NOT IMPLEMENTED YET
        else:
            self.max_note_token_length_per_ch = self.max_note_token_length

        self.max_total_token_length = self.max_note_token_length_per_ch + self.max_task_token_length
        self.max_shift_steps = max_shift_steps

        # Initialize the tokenizer and evaluation task prefixes.
        self.set_tokenizer()
        self.set_eval_task_prefix()
        self.num_tokens = self.tokenizer.num_tokens
        self.inverse_vocab_program = self.tokenizer.codec.inverse_vocab_program

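    # Illustrative example of the multi-channel setup above (hypothetical vocab):
    # with train_program_vocab = {"Piano": range(0, 8), "Strings": range(40, 52)},
    # a task declaring num_decoding_channels = 3 (one channel per program group
    # plus one extra) passes the check, and create_program2channel_vocab derives
    # the program-to-channel mapping from the two groups.
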
    def set_eval_task_prefix(self) -> None:
        """
        Sets the evaluation task prefix tokens for the task, by encoding prefix events such as:

            {
                "default": [Event("transcribe_all", 0), Event("task", 0)],
                "singing-only": [Event("transcribe_singing", 0), Event("task", 0)]
            }
        """
        if self._eval_subtask_prefix is not None:
            assert "default" in self._eval_subtask_prefix.keys()
            for key, val in self._eval_subtask_prefix.items():
                if self.padding_task_token:
                    self.eval_subtask_prefix_dict[key] = self.tokenizer.encode_task(
                        val, max_length=self.max_task_token_length)
                else:
                    self.eval_subtask_prefix_dict[key] = self.tokenizer.encode_task(val)
        else:
            self.eval_subtask_prefix_dict["default"] = []

    def get_eval_subtask_prefix_dict(self) -> dict:
        return self.eval_subtask_prefix_dict

    def get_tokenizer(self) -> NoteEventTokenizer:
        """
        Returns the tokenizer instance associated with the task.

        Returns:
            NoteEventTokenizer: The tokenizer instance.
        """
        return self.tokenizer

    def set_tokenizer(self) -> None:
        """
        Initializes the tokenizer using the NoteEventTokenizer class with the appropriate parameters.
        """
        self.tokenizer = NoteEventTokenizer(base_codec=self.base_codec,
                                            max_length=self.max_total_token_length,
                                            program_vocabulary=self.train_program_vocab,
                                            drum_vocabulary=self.train_drum_vocab,
                                            special_tokens=['PAD', 'EOS', 'UNK'],
                                            extra_tokens=self.extra_tokens,
                                            max_shift_steps=self.max_shift_steps,
                                            ignore_decoding_tokens=self.ignore_decoding_tokens,
                                            ignore_decoding_tokens_from_and_to=self.ignore_decoding_tokens_from_and_to,
                                            debug_mode=self.debug_mode)

    def tokenize_task_and_note_events_batch(
            self,
            programs_segments: List[List[int]],
            has_unannotated_segments: List[bool],
            note_event_segments: NoteEventListsBundle,
            subunit_programs_segments: Optional[List[List[np.ndarray]]] = None,
            subunit_note_event_segments: Optional[List[NoteEventListsBundle]] = None,
            stage: str = 'train') -> np.ndarray:
        """Tokenizes a batch of note events into a batch of encoded tokens.
        Optionally, appends task tokens to the note event tokens.

        Args:
            programs_segments (List[List[int]]): A list of program-number lists, one per segment.
            has_unannotated_segments (List[bool]): Whether each segment has unannotated programs.
            note_event_segments (NoteEventListsBundle): A bundle of note events.
            subunit_programs_segments (Optional[List[List[np.ndarray]]]): A list of subunit programs.
            subunit_note_event_segments (Optional[List[NoteEventListsBundle]]): A list of subunit note events.
            stage (str): The training stage. Default is 'train'.

        Returns:
            np.ndarray: A batch of encoded tokens, with shape (B, C, L).
        """
        if self.task_name == 'exclusive':
            raise NotImplementedError("Exclusive transcription task is not implemented yet.")
        else:
            # Default: tokenize note events only; no task tokens are appended.
            return self.tokenize_note_events_batch(note_event_segments)

    def tokenize_note_events_batch(self,
                                   note_event_segments: NoteEventListsBundle,
                                   start_time_to_zero: bool = False,
                                   sort: bool = True) -> np.ndarray:
        """Tokenizes a batch of note events into a batch of encoded tokens.

        Args:
            note_event_segments (NoteEventListsBundle): A bundle of note events.
            start_time_to_zero (bool): Whether to shift each segment's start time to zero
                before multi-channel separation. Default is False.
            sort (bool): Whether to sort note events during multi-channel separation. Default is True.

        Returns:
            np.ndarray: A batch of encoded tokens, with shape (B, C, L).
        """
        batch_sz = len(note_event_segments["note_events"])
        note_token_array = np.zeros((batch_sz, self.num_decoding_channels, self.max_note_token_length_per_ch),
                                    dtype=np.int32)

        if self.num_decoding_channels == 1:
            # Single-channel: encode each segment directly.
            zipped_events = list(zip(*note_event_segments.values()))
            for b in range(batch_sz):
                note_token_array[b, 0, :] = self.tokenizer.encode_plus(*zipped_events[b],
                                                                       max_length=self.max_note_token_length,
                                                                       pad_to_max_length=True)
        elif self.num_decoding_channels > 1:
            # Multi-channel: first split each segment's events into program groups.
            ch_sep_ne_bundle = separate_channel_by_program_group_from_note_event_lists_bundle(
                source_note_event_lists_bundle=note_event_segments,
                num_program_groups=self.num_decoding_channels,
                program2channel_vocab=self.program2channel_vocab,
                start_time_to_zero=start_time_to_zero,
                sort=sort)

            for b in range(batch_sz):
                zipped_channel = list(zip(*ch_sep_ne_bundle[b].values()))
                for c in range(self.num_decoding_channels):
                    note_token_array[b, c, :] = self.tokenizer.encode_plus(*zipped_channel[c],
                                                                           max_length=self.max_note_token_length_per_ch,
                                                                           pad_to_max_length=True)
        return note_token_array

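    # A sketch of the expected input above, assuming a NoteEventListsBundle behaves
    # as a dict of parallel per-segment lists (the keys below are illustrative; see
    # utils/note_event_dataclasses.py for the actual definition):
    #
    #   note_event_segments = {
    #       "note_events": [...],      # B lists of NoteEvent
    #       "tie_note_events": [...],  # B lists of tied NoteEvent
    #       "start_times": [...],      # B floats
    #   }
    #   tokens = task_manager.tokenize_note_events_batch(note_event_segments)
    #   tokens.shape  # (B, num_decoding_channels, max_note_token_length_per_ch)
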
    def tokenize_note_events(self,
                             note_events: List[NoteEvent],
                             tie_note_events: Optional[List[NoteEvent]] = None,
                             start_time: float = 0.,
                             **kwargs: Any) -> List[int]:
        """(Deprecated) Tokenizes a sequence of note events into a sequence of encoded tokens."""
        return self.tokenizer.encode_plus(note_events, tie_note_events, start_time, **kwargs)

    def tokenize_task_events(self, programs: List[int], has_unannotated: bool) -> List[int]:
        """Tokenizes a sequence of programs into a sequence of encoded task tokens. Used for training."""
        if self.task_name != 'singing_drum_v1' or not has_unannotated:
            # Tasks without subtask prefixes, and fully annotated segments, take no task tokens.
            return []

        if SINGING_PROGRAM in programs:
            task_events = [Event('transcribe_singing', 0), Event('task', 0)]
        elif DRUM_PROGRAM in programs:
            task_events = [Event('transcribe_drum', 0), Event('task', 0)]
        else:
            task_events = [Event('transcribe_all', 0), Event('task', 0)]

        if self.padding_task_token:
            return self.tokenizer.encode_task(task_events, max_length=self.max_task_token_length)
        else:
            return self.tokenizer.encode_task(task_events)

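    # Illustrative example (not executed): for 'singing_drum_v1', an unannotated
    # segment whose programs include SINGING_PROGRAM yields the task prefix
    # [Event('transcribe_singing', 0), Event('task', 0)], returned as encoded
    # token IDs to be combined with the note tokens.
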
    def detokenize(
        self,
        tokens: List[int],
        start_time: float = 0.,
        return_events: bool = False
    ) -> Union[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int, int]], int],
               Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int, int]], List[Event], int]]:
        """Decodes a sequence of tokens into note events, ignoring specific token IDs.

        Returns:
            If `return_events` is False, a tuple of `note_events`, `tie_note_events`,
            `last_activity`, and `err_cnt`.
            If `return_events` is True, a tuple of `note_events`, `tie_note_events`,
            `last_activity`, `events`, and `err_cnt`.

        Notes:
            This decoding process skips the token IDs that the tokenizer was constructed to
            ignore (`ignore_decoding_tokens` and `ignore_decoding_tokens_from_and_to`).
        """
        return self.tokenizer.decode(tokens=tokens, start_time=start_time, return_events=return_events)

    def detokenize_list_batches(
        self,
        list_batch_tokens: Union[List[List[List[int]]], List[np.ndarray]],
        list_start_times: Union[List[List[float]], List[float]],
        return_events: bool = False
    ) -> Union[Tuple[List[List[Tuple[List[NoteEvent], List[NoteEvent], int, float]]], Counter[str]],
               Tuple[List[List[Tuple[List[NoteEvent], List[NoteEvent], int, float]]], List[List[Event]],
                     Counter[str]]]:
        """Decodes a list of variable-size batches of token arrays into a list of
        zipped note_events and tie_note_events.

        Args:
            list_batch_tokens: List[np.ndarray], where each array has shape (batch_size, variable_length).
            list_start_times: List[float], whose length is the sum of all batch sizes.
            return_events: bool

        Returns:
            list_list_zipped_note_events_and_tie:
                List[
                    Tuple[
                        List[NoteEvent]: A list of note events.
                        List[NoteEvent]: A list of tie note events.
                        List[Tuple[int, int]]: A list of last activities of the segment, as
                            [(program, pitch), ...]. This is useful for validating notes within
                            a batch of segments extracted from a file.
                        List[float]: A list of segment start times.
                    ]
                ]
            (Optional) list_events:
                List[List[Event]]
            total_err_cnt:
                Counter[str]: An error counter.
        """
        return self.tokenizer.decode_list_batches(list_batch_tokens, list_start_times, return_events)
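

# Minimal smoke test. This is an illustrative sketch, assuming the default
# "mt3_full_plus" task is defined in config/task.py (it is the constructor
# default above); it only exercises attributes set in __init__.
if __name__ == "__main__":
    tm = TaskManager(task_name="mt3_full_plus")
    print("vocab size:", tm.num_tokens)
    print("decoding channels:", tm.num_decoding_channels)
    print("max total token length:", tm.max_total_token_length)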