|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""ymt3.py""" |
|
import os |
|
from typing import Union, Optional, Tuple, Dict, List, Any |
|
from collections import Counter |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import CrossEntropyLoss |
|
import torchaudio |
|
import pytorch_lightning as pl |
|
import numpy as np |
|
import wandb |
|
from einops import rearrange |
|
|
|
from transformers import T5Config |
|
from model.t5mod import T5EncoderYMT3, T5DecoderYMT3, MultiChannelT5Decoder |
|
from model.t5mod_helper import task_cond_dec_generate |
|
from model.perceiver_mod import PerceiverTFEncoder |
|
from model.perceiver_helper import PerceiverTFConfig |
|
from model.conformer_mod import ConformerYMT3Encoder |
|
from model.conformer_helper import ConformerYMT3Config |
|
from model.lm_head import LMHead |
|
from model.pitchshift_layer import PitchShiftLayer |
|
from model.spectrogram import get_spectrogram_layer_from_audio_cfg |
|
from model.conv_block import PreEncoderBlockRes3B |
|
from model.conv_block import PreEncoderBlockHFTT, PreEncoderBlockRes3BHFTT |
|
from model.projection_layer import get_projection_layer, get_multi_channel_projection_layer |
|
from model.optimizers import get_optimizer |
|
from model.lr_scheduler import get_lr_scheduler |
|
|
|
from utils.note_event_dataclasses import Note |
|
from utils.note2event import mix_notes |
|
from utils.event2note import merge_zipped_note_events_and_ties_to_notes, DECODING_ERR_TYPES |
|
from utils.metrics import compute_track_metrics |
|
from utils.metrics import AMTMetrics |
|
|
|
from utils.utils import write_model_output_as_midi, create_inverse_vocab, write_err_cnt_as_json |
|
from utils.utils import Timer |
|
from utils.task_manager import TaskManager |
|
|
|
from config.config import audio_cfg as default_audio_cfg |
|
from config.config import model_cfg as default_model_cfg |
|
from config.config import shared_cfg as default_shared_cfg |
|
from config.config import T5_BASE_CFG |
|
|
|
|
|
class YourMT3(pl.LightningModule): |
|
"""YourMT3: |
|
|
|
    A PyTorch Lightning wrapper for the multi-task music transcription Transformer.
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
audio_cfg: Optional[Dict] = None, |
|
model_cfg: Optional[Dict] = None, |
|
shared_cfg: Optional[Dict] = None, |
|
pretrained: bool = False, |
|
optimizer_name: str = 'adamwscale', |
|
scheduler_name: str = 'cosine', |
|
            base_lr: Optional[float] = None,
|
max_steps: Optional[int] = None, |
|
weight_decay: float = 0.0, |
|
init_factor: Optional[Union[str, float]] = None, |
|
task_manager: TaskManager = TaskManager(), |
|
eval_subtask_key: Optional[str] = "default", |
|
eval_vocab: Optional[Dict] = None, |
|
eval_drum_vocab: Optional[Dict] = None, |
|
write_output_dir: Optional[str] = None, |
|
write_output_vocab: Optional[Dict] = None, |
|
onset_tolerance: float = 0.05, |
|
add_pitch_class_metric: Optional[List[str]] = None, |
|
add_melody_metric_to_singing: bool = True, |
|
test_optimal_octave_shift: bool = False, |
|
test_pitch_shift_layer: Optional[str] = None, |
|
**kwargs: Any) -> None: |
|
super().__init__() |
|
        if pretrained:
|
raise NotImplementedError("Pretrained model is not supported in this version.") |
|
self.test_pitch_shift_layer = test_pitch_shift_layer |
|
|
|
|
|
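        # Fall back to the package-default configs when none are provided.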
if model_cfg is None: |
|
model_cfg = default_model_cfg |
|
if audio_cfg is None: |
|
audio_cfg = default_audio_cfg |
|
if shared_cfg is None: |
|
shared_cfg = default_shared_cfg |
|
|
|
|
|
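        # Spectrogram front-end; spec_output_shape is (time_frames, freq_bins) for the configured input segment.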
self.spectrogram, spec_output_shape = get_spectrogram_layer_from_audio_cfg( |
|
audio_cfg) |
|
model_cfg["feat_length"] = spec_output_shape[0] |
|
|
|
|
|
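        # The task manager defines the tokenization scheme and token-length budgets.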
self.task_manager = task_manager |
|
self.max_total_token_length = self.task_manager.max_total_token_length |
|
|
|
|
|
self.use_task_cond_encoder = bool(model_cfg["use_task_conditional_encoder"]) |
|
self.use_task_cond_decoder = bool(model_cfg["use_task_conditional_decoder"]) |
|
|
|
|
|
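        # Encoder/decoder architecture selection and their sub-configs.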
assert model_cfg["encoder_type"] in ["t5", "perceiver-tf", "conformer"] |
|
assert model_cfg["decoder_type"] in ["t5", "multi-t5"] |
|
self.encoder_type = model_cfg["encoder_type"] |
|
self.decoder_type = model_cfg["decoder_type"] |
|
encoder_config = model_cfg["encoder"][self.encoder_type] |
|
decoder_config = model_cfg["decoder"][self.decoder_type] |
|
|
|
|
|
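        # 'auto' positions: the encoder covers spectrogram frames plus the task prefix; the decoder covers the total token budget (with a small margin).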
if isinstance(model_cfg["num_max_positions"], str) and model_cfg["num_max_positions"] == 'auto': |
|
encoder_config["num_max_positions"] = int(model_cfg["feat_length"] + |
|
self.task_manager.max_task_token_length + 10) |
|
decoder_config["num_max_positions"] = int(self.max_total_token_length + 10) |
|
else: |
|
assert isinstance(model_cfg["num_max_positions"], int) |
|
encoder_config["num_max_positions"] = model_cfg["num_max_positions"] |
|
decoder_config["num_max_positions"] = model_cfg["num_max_positions"] |
|
|
|
|
|
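        # Resolve 'default' pre-encoder/pre-decoder types from the per-architecture lookup tables in the model config.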
if model_cfg["pre_encoder_type"] == "default": |
|
model_cfg["pre_encoder_type"] = model_cfg["pre_encoder_type_default"].get(model_cfg["encoder_type"], None) |
|
elif model_cfg["pre_encoder_type"] in [None, "none", "None", "0"]: |
|
model_cfg["pre_encoder_type"] = None |
|
if model_cfg["pre_decoder_type"] == "default": |
|
model_cfg["pre_decoder_type"] = model_cfg["pre_decoder_type_default"].get(model_cfg["encoder_type"]).get( |
|
model_cfg["decoder_type"], None) |
|
elif model_cfg["pre_decoder_type"] in [None, "none", "None", "0"]: |
|
model_cfg["pre_decoder_type"] = None |
|
self.pre_encoder_type = model_cfg["pre_encoder_type"] |
|
self.pre_decoder_type = model_cfg["pre_decoder_type"] |
|
|
|
|
|
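        # Pre-encoder: optional convolutional front-end applied to the spectrogram before the main encoder.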
self.pre_encoder = nn.Sequential() |
|
if self.pre_encoder_type in ["conv", "conv1d_t", "conv1d_f"]: |
|
kernel_size = (3, 3) |
|
avp_kernel_size = (1, 2) |
|
if self.pre_encoder_type == "conv1d_t": |
|
kernel_size = (3, 1) |
|
elif self.pre_encoder_type == "conv1d_f": |
|
kernel_size = (1, 3) |
|
self.pre_encoder.append( |
|
PreEncoderBlockRes3B(1, |
|
model_cfg["conv_out_channels"], |
|
kernel_size=kernel_size, |
|
avp_kernerl_size=avp_kernel_size, |
|
activation="relu")) |
|
pre_enc_output_shape = (spec_output_shape[0], spec_output_shape[1] // 2**3, model_cfg["conv_out_channels"] |
|
) |
|
elif self.pre_encoder_type == "hftt": |
|
self.pre_encoder.append(PreEncoderBlockHFTT()) |
|
pre_enc_output_shape = (spec_output_shape[0], spec_output_shape[1], 128) |
|
elif self.pre_encoder_type == "res3b_hftt": |
|
self.pre_encoder.append(PreEncoderBlockRes3BHFTT()) |
|
pre_enc_output_shape = (spec_output_shape[0], spec_output_shape[1] // 2**3, 128) |
|
else: |
|
pre_enc_output_shape = spec_output_shape |
|
|
|
|
|
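        # 'auto' d_feat: use the frequency-bin axis when Perceiver-TF attends over channels, otherwise the channel axis of the pre-encoder output.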
if isinstance(model_cfg["d_feat"], str) and model_cfg["d_feat"] == 'auto': |
|
if self.encoder_type == "perceiver-tf" and encoder_config["attention_to_channel"] is True: |
|
model_cfg["d_feat"] = pre_enc_output_shape[-2] |
|
else: |
|
model_cfg["d_feat"] = pre_enc_output_shape[-1] |
|
|
|
if self.encoder_type == "perceiver-tf" and isinstance(encoder_config["d_model"], str): |
|
if encoder_config["d_model"] == 'q': |
|
encoder_config["d_model"] = encoder_config["d_latent"] |
|
elif encoder_config["d_model"] == 'kv': |
|
encoder_config["d_model"] = model_cfg["d_feat"] |
|
else: |
|
raise ValueError(f"Unknown d_model: {encoder_config['d_model']}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
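        # 'auto' vocab size is taken from the task manager's tokenizer.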
if isinstance(model_cfg["vocab_size"], str) and model_cfg["vocab_size"] == 'auto': |
|
model_cfg["vocab_size"] = task_manager.num_tokens |
|
|
|
if isinstance(model_cfg["num_max_positions"], str) and model_cfg["num_max_positions"] == 'auto': |
|
model_cfg["num_max_positions"] = int( |
|
max(model_cfg["feat_length"], model_cfg["event_length"]) + self.task_manager.max_task_token_length + 10) |
|
|
|
|
|
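        # Pre-decoder: projects encoder outputs to the decoder input space; the projection depends on the encoder/decoder pairing.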
self.pre_decoder = nn.Sequential() |
|
if self.encoder_type == "perceiver-tf" and self.decoder_type == "t5": |
|
t, f, c = pre_enc_output_shape |
|
encoder_output_shape = (t, encoder_config["num_latents"], encoder_config["d_latent"]) |
|
decoder_input_shape = (t, decoder_config["d_model"]) |
|
proj_layer = get_projection_layer(input_shape=encoder_output_shape, |
|
output_shape=decoder_input_shape, |
|
proj_type=self.pre_decoder_type) |
|
self.pre_encoder_output_shape = pre_enc_output_shape |
|
self.encoder_output_shape = encoder_output_shape |
|
self.decoder_input_shape = decoder_input_shape |
|
self.pre_decoder.append(proj_layer) |
|
elif self.encoder_type in ["t5", "conformer"] and self.decoder_type == "t5": |
|
pass |
|
elif self.encoder_type == "perceiver-tf" and self.decoder_type == "multi-t5": |
|
|
|
assert encoder_config["num_latents"] % decoder_config["num_channels"] == 0 |
|
encoder_output_shape = (encoder_config["num_latents"], encoder_config["d_model"]) |
|
decoder_input_shape = (decoder_config["num_channels"], decoder_config["d_model"]) |
|
proj_layer = get_multi_channel_projection_layer(input_shape=encoder_output_shape, |
|
output_shape=decoder_input_shape, |
|
proj_type=self.pre_decoder_type) |
|
self.pre_decoder.append(proj_layer) |
|
else: |
|
            raise NotImplementedError(
                f"Encoder type {self.encoder_type} with decoder type {self.decoder_type} is not supported yet.")
|
|
|
|
|
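        # Keep positional and vocabulary sizes consistent between the encoder and decoder configs.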
if self.encoder_type in ["t5", "conformer"]: |
|
encoder_config["num_max_positions"] = decoder_config["num_max_positions"] = model_cfg["num_max_positions"] |
|
else: |
|
encoder_config["num_max_positions"] = model_cfg["feat_length"] |
|
decoder_config["num_max_positions"] = model_cfg["num_max_positions"] |
|
encoder_config["vocab_size"] = decoder_config["vocab_size"] = model_cfg["vocab_size"] |
|
|
|
|
|
self.audio_cfg = audio_cfg |
|
self.model_cfg = model_cfg |
|
self.shared_cfg = shared_cfg |
|
self.save_hyperparameters() |
|
if self.global_rank == 0: |
|
print(self.hparams) |
|
|
|
|
|
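        # The encoder/decoder are instantiated in set_encoder_decoder(); the shared token embedding and LM head are created here.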
self.encoder = None |
|
self.decoder = None |
|
self.lm_head = LMHead(decoder_config, 1.0, model_cfg["tie_word_embeddings"]) |
|
self.embed_tokens = nn.Embedding(decoder_config["vocab_size"], decoder_config["d_model"]) |
|
self.embed_tokens.weight.data.normal_(mean=0.0, std=1.0) |
|
self.shift_right_fn = None |
|
self.set_encoder_decoder() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
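        # W&B table for logging qualitative validation samples.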
columns = ['Ep', 'Track ID', 'Pred Events', 'Actual Events', 'Pred Notes', 'Actual Notes'] |
|
self.sample_table = wandb.Table(columns=columns) |
|
|
|
|
|
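        # Vocabulary used when writing model output as MIDI (defaults to the 'gm_ext_plus' preset).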
if write_output_dir is not None: |
|
if write_output_vocab is None: |
|
from config.vocabulary import program_vocab_presets |
|
self.midi_output_vocab = program_vocab_presets["gm_ext_plus"] |
|
else: |
|
self.midi_output_vocab = write_output_vocab |
|
self.midi_output_inverse_vocab = create_inverse_vocab(self.midi_output_vocab) |
|
|
|
def set_encoder_decoder(self) -> None: |
|
"""Set encoder, decoder, lm_head and emb_tokens from self.model_cfg""" |
|
|
|
|
|
t5_basename = self.model_cfg["t5_basename"] |
|
if t5_basename in T5_BASE_CFG.keys(): |
|
|
|
t5_config = T5Config(**T5_BASE_CFG[t5_basename]) |
|
else: |
|
|
|
t5_config = T5Config.from_pretrained(t5_basename) |
|
|
|
|
|
if self.encoder_type == "t5": |
|
self.encoder = T5EncoderYMT3(self.model_cfg["encoder"]["t5"], t5_config) |
|
elif self.encoder_type == "perceiver-tf": |
|
perceivertf_config = PerceiverTFConfig() |
|
perceivertf_config.update(self.model_cfg["encoder"]["perceiver-tf"]) |
|
self.encoder = PerceiverTFEncoder(perceivertf_config) |
|
elif self.encoder_type == "conformer": |
|
conformer_config = ConformerYMT3Config() |
|
conformer_config.update(self.model_cfg["encoder"]["conformer"]) |
|
self.encoder = ConformerYMT3Encoder(conformer_config) |
|
|
|
if self.decoder_type == "t5": |
|
self.decoder = T5DecoderYMT3(self.model_cfg["decoder"]["t5"], t5_config) |
|
elif self.decoder_type == "multi-t5": |
|
self.decoder = MultiChannelT5Decoder(self.model_cfg["decoder"]["multi-t5"], t5_config) |
|
|
|
|
|
self.shift_right_fn = self.decoder._shift_right |
|
|
|
def setup(self, stage: str) -> None: |
|
|
|
if self.hparams.eval_vocab is None: |
|
extra_classes_per_dataset = [None] |
|
else: |
|
extra_classes_per_dataset = [ |
|
list(v.keys()) if v is not None else None for v in self.hparams.eval_vocab |
|
] |
|
|
|
|
|
extra_metrics = None |
|
if self.hparams.add_melody_metric_to_singing is True: |
|
extra_metrics = ["melody_rpa_Singing Voice", "melody_rca_Singing Voice", "melody_oa_Singing Voice"] |
|
|
|
|
|
if self.hparams.add_pitch_class_metric is not None: |
|
for sublist in extra_classes_per_dataset: |
|
for name in self.hparams.add_pitch_class_metric: |
|
if sublist is not None and name in sublist: |
|
sublist += [name + "_pc"] |
|
|
|
extra_classes_unique = list( |
|
set(item for sublist in extra_classes_per_dataset if sublist is not None |
|
for item in sublist)) |
|
dm = self.trainer.datamodule |
|
|
|
|
|
if stage == "fit": |
|
            self.val_metrics_macro = AMTMetrics(prefix='validation/macro_', extra_classes=extra_classes_unique)
|
self.val_metrics = nn.ModuleList() |
|
for i in range(dm.num_val_dataloaders): |
|
self.val_metrics.append( |
|
AMTMetrics(prefix=f'validation/({dm.get_val_dataset_name(i)})', |
|
extra_classes=extra_classes_per_dataset[i], |
|
error_types=DECODING_ERR_TYPES)) |
|
|
|
|
|
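            # Pitch-shift augmentation for training; disabled when the configured range is None or [0, 0].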
if self.shared_cfg["AUGMENTATION"]["train_pitch_shift_range"] in [None, [0, 0]]: |
|
self.pitchshift = None |
|
else: |
|
|
|
input_shape = (self.shared_cfg["BSZ"]["train_local"], 1, self.audio_cfg["input_frames"]) |
|
self.pitchshift = PitchShiftLayer( |
|
pshift_range=self.shared_cfg["AUGMENTATION"]["train_pitch_shift_range"], |
|
expected_input_shape=input_shape, |
|
device=self.device) |
|
|
|
|
|
elif stage == "test": |
|
|
|
|
|
self.test_metrics = nn.ModuleList() |
|
for i in range(dm.num_test_dataloaders): |
|
self.test_metrics.append( |
|
AMTMetrics(prefix=f'test/({dm.get_test_dataset_name(i)})', |
|
extra_classes=extra_classes_per_dataset[i], |
|
extra_metrics=extra_metrics, |
|
error_types=DECODING_ERR_TYPES)) |
|
|
|
|
|
if self.test_pitch_shift_layer is not None: |
|
self.test_pitch_shift_semitone = int(self.test_pitch_shift_layer) |
|
self.pitchshift = PitchShiftLayer( |
|
pshift_range=[self.test_pitch_shift_semitone, self.test_pitch_shift_semitone]) |
|
|
|
    def configure_optimizers(self):
|
"""Configure optimizer and scheduler""" |
|
optimizer, base_lr = get_optimizer(models_dict=self.named_parameters(), |
|
optimizer_name=self.hparams.optimizer_name, |
|
base_lr=self.hparams.base_lr, |
|
weight_decay=self.hparams.weight_decay) |
|
|
|
        if self.hparams.optimizer_name.lower() == 'adafactor' and self.hparams.base_lr is None:
|
print("Using AdaFactor with auto learning rate and no scheduler") |
|
return [optimizer] |
|
if self.hparams.optimizer_name.lower() == 'dadaptadam': |
|
print("Using dAdaptAdam with auto learning rate and no scheduler") |
|
return [optimizer] |
|
        elif self.hparams.base_lr is None:
|
print(f"Using default learning rate {base_lr} of {self.hparams.optimizer_name} as base learning rate.") |
|
self.hparams.base_lr = base_lr |
|
|
|
scheduler_cfg = self.shared_cfg["LR_SCHEDULE"] |
|
if self.hparams.max_steps != -1: |
|
|
|
scheduler_cfg["total_steps"] = self.hparams.max_steps |
|
_lr_scheduler = get_lr_scheduler(optimizer, |
|
scheduler_name=self.hparams.scheduler_name, |
|
base_lr=base_lr, |
|
scheduler_cfg=scheduler_cfg) |
|
|
|
lr_scheduler = {'scheduler': _lr_scheduler, 'interval': 'step', 'frequency': 1} |
|
return [optimizer], [lr_scheduler] |
|
|
|
def forward( |
|
self, |
|
x: torch.FloatTensor, |
|
target_tokens: torch.LongTensor, |
|
|
|
**kwargs) -> Dict: |
|
""" Forward pass with teacher-forcing for training and validation. |
|
        Args:
            x: (B, 1, T) waveform with default T=32767
            target_tokens: (B, C, N) tokenized target sequence of length N=event_length
            task_tokens: (B, C, task_len) tokenized task (accepted via **kwargs and unused in this version)

        Returns:
            {
                'logits': (B, N + task_len + 1, vocab_size),
                'loss': (1,)
            }

        NOTE: the commented shapes refer to the original MT3 setup.
        """
|
x = self.spectrogram(x) |
|
x = self.pre_encoder(x) |
|
|
|
|
|
|
|
|
|
|
|
|
|
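        # Encode the pre-encoded spectrogram features and project them to the decoder input space.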
enc_hs = self.encoder(inputs_embeds=x)["last_hidden_state"] |
|
enc_hs = self.pre_decoder(enc_hs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
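        # Teacher forcing: shift targets right, embed, decode against the encoder states, and project to logits.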
labels = target_tokens |
|
if labels.shape[1] == 1: |
|
labels = labels.squeeze(1) |
|
|
|
dec_input_ids = self.shift_right_fn(labels) |
|
dec_inputs_embeds = self.embed_tokens(dec_input_ids) |
|
dec_hs, _ = self.decoder(inputs_embeds=dec_inputs_embeds, encoder_hidden_states=enc_hs, return_dict=False) |
|
|
|
if self.model_cfg["tie_word_embeddings"] is True: |
|
dec_hs = dec_hs * (self.model_cfg["decoder"][self.decoder_type]["d_model"]**-0.5) |
|
|
|
logits = self.lm_head(dec_hs) |
|
|
|
loss = None |
|
labels = labels.masked_fill(labels == 0, value=-100) |
|
loss_fct = CrossEntropyLoss(ignore_index=-100) |
|
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) |
|
return {"logits": logits, "loss": loss} |
|
|
|
def inference(self, |
|
x: torch.FloatTensor, |
|
task_tokens: Optional[torch.LongTensor] = None, |
|
max_token_length: Optional[int] = None, |
|
**kwargs: Any) -> torch.Tensor: |
|
""" Inference from audio batch by cached autoregressive decoding. |
|
        Args:
            x: (b, 1, t) waveform with t=32767
            task_tokens: (b, c, task_len) tokenized task. If None, task embeddings are not prepended to the encoder input.
            max_token_length: maximum length of the generated sequence. If None, uses self.max_total_token_length.
            **kwargs: see https://huggingface.co/docs/transformers/v4.27.2/en/main_classes/text_generation#transformers.GenerationMixin.generate

        Returns:
            pred_ids: (b, c, n) tokenized sequence of variable length n < max_token_length.
            When test_pitch_shift_layer is set, returns (pred_ids, x_ps) with the pitch-shifted input audio.
        """
|
if self.test_pitch_shift_layer is not None: |
|
x_ps = self.pitchshift(x, self.test_pitch_shift_semitone) |
|
x = x_ps |
|
|
|
|
|
x = self.spectrogram(x) |
|
x = self.pre_encoder(x) |
|
if task_tokens is not None and task_tokens.numel() > 0 and self.use_task_cond_encoder is True: |
|
|
|
task_embed = self.embed_tokens(task_tokens) |
|
x = torch.cat([task_embed, x], dim=1) |
|
enc_hs = self.encoder(inputs_embeds=x)["last_hidden_state"] |
|
enc_hs = self.pre_decoder(enc_hs) |
|
|
|
|
|
if max_token_length is None: |
|
max_token_length = self.max_total_token_length |
|
|
|
pred_ids = task_cond_dec_generate(decoder=self.decoder, |
|
decoder_type=self.decoder_type, |
|
embed_tokens=self.embed_tokens, |
|
lm_head=self.lm_head, |
|
encoder_hidden_states=enc_hs, |
|
shift_right_fn=self.shift_right_fn, |
|
prefix_ids=task_tokens, |
|
max_length=max_token_length) |
|
if pred_ids.dim() == 2: |
|
pred_ids = pred_ids.unsqueeze(1) |
|
|
|
if self.test_pitch_shift_layer is None: |
|
return pred_ids |
|
else: |
|
return pred_ids, x_ps |
|
|
|
def inference_file( |
|
self, |
|
bsz: int, |
|
audio_segments: torch.FloatTensor, |
|
note_token_array: Optional[torch.LongTensor] = None, |
|
task_token_array: Optional[torch.LongTensor] = None, |
|
|
|
) -> Tuple[List[np.ndarray], Optional[torch.Tensor]]: |
|
""" Inference from audio batch by autoregressive decoding: |
|
        Args:
            bsz: batch size
            audio_segments: (n_items, 1, segment_len): segmented audio from a single file
            note_token_array: (n_items, max_token_len): Optional. If None, loss is not computed.
            task_token_array: (n_items, task_len): Optional. If None or empty, no task conditioning is applied.

        Returns:
            (pred_token_array_file, loss), plus the concatenated pitch-shifted audio when
            test_pitch_shift_layer and write_output_dir are both set.
        """
|
|
|
|
|
|
|
|
|
n_items = audio_segments.shape[0] |
|
loss = 0. |
|
pred_token_array_file = [] |
|
x_ps_concat = [] |
|
|
|
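        # Run inference in mini-batches of size bsz; the final batch may be smaller than bsz.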
for i in range(0, n_items, bsz): |
|
if i + bsz > n_items: |
|
x = audio_segments[i:n_items].to(self.device) |
|
|
|
|
|
|
|
if note_token_array is not None: |
|
target_tokens = note_token_array[i:n_items].to(self.device) |
|
if task_token_array is not None and task_token_array.numel() > 0: |
|
task_tokens = task_token_array[i:n_items].to(self.device) |
|
else: |
|
task_tokens = None |
|
else: |
|
x = audio_segments[i:i + bsz].to(self.device) |
|
|
|
|
|
if note_token_array is not None: |
|
target_tokens = note_token_array[i:i + bsz].to(self.device) |
|
if task_token_array is not None and task_token_array.numel() > 0: |
|
task_tokens = task_token_array[i:i + bsz].to(self.device) |
|
else: |
|
task_tokens = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.test_pitch_shift_layer is not None: |
|
preds, x_ps = self.inference(x, task_tokens) |
|
preds = preds.detach().cpu().numpy() |
|
x_ps_concat.append(x_ps.detach().cpu()) |
|
else: |
|
preds = self.inference(x, task_tokens).detach().cpu().numpy() |
|
if len(preds) != len(x): |
|
                    raise ValueError(f'Batch size mismatch: {len(preds)} predictions for {len(x)} inputs.')
|
pred_token_array_file.append(preds) |
|
|
|
|
|
if note_token_array is not None: |
|
loss_weight = x.shape[0] / n_items |
|
loss += self(x, target_tokens)['loss'] * loss_weight |
|
|
|
else: |
|
loss = None |
|
|
|
if self.test_pitch_shift_layer is not None: |
|
if self.hparams.write_output_dir is not None: |
|
x_ps_concat = torch.cat(x_ps_concat, dim=0) |
|
return pred_token_array_file, loss, x_ps_concat.flatten().unsqueeze(0) |
|
else: |
|
return pred_token_array_file, loss |
|
|
|
def training_step(self, batch, batch_idx) -> torch.Tensor: |
|
|
|
|
|
|
|
|
|
|
|
audio_segments, note_tokens, pshift_steps = [torch.cat(t, dim=0) for t in zip(*batch.values())] |
|
|
|
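        # Group-wise pitch shift: each dataset group in the batch shares a single shift amount.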
if self.pitchshift is not None: |
|
|
|
n_groups = len(batch) |
|
audio_segments = torch.chunk(audio_segments, n_groups, dim=0) |
|
pshift_steps = torch.chunk(pshift_steps, n_groups, dim=0) |
|
for p in pshift_steps: |
|
assert p.eq(p[0]).all().item() |
|
|
|
audio_segments = torch.cat([self.pitchshift(a, p[0].item()) for a, p in zip(audio_segments, pshift_steps)], |
|
dim=0) |
|
|
|
loss = self(audio_segments, note_tokens)['loss'] |
|
self.log('train_loss', |
|
loss, |
|
on_step=True, |
|
on_epoch=True, |
|
prog_bar=True, |
|
batch_size=note_tokens.shape[0], |
|
sync_dist=True) |
|
|
|
return loss |
|
|
|
    def validation_step(self, batch, batch_idx, dataloader_idx=0) -> None:
|
|
|
if self.task_manager.num_decoding_channels == 1: |
|
bsz = self.shared_cfg["BSZ"]["validation"] |
|
else: |
|
bsz = self.shared_cfg["BSZ"]["validation"] // self.task_manager.num_decoding_channels * 3 |
|
|
|
audio_segments, notes_dict, note_token_array = batch |
|
task_token_array = None |
|
|
|
|
|
n_items = audio_segments.shape[0] |
|
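        # Segment start times in seconds, assuming non-overlapping 32767-sample segments at 16 kHz.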
start_secs_file = [32767 * i / 16000 for i in range(n_items)] |
|
with Timer() as t: |
|
pred_token_array_file, loss = self.inference_file(bsz, audio_segments, note_token_array, task_token_array) |
|
""" |
|
notes_dict: # Ground truth notes |
|
{ |
|
'mtrack_id': int, |
|
'program': List[int], |
|
'is_drum': bool, |
|
'duration_sec': float, |
|
'notes': List[Note], |
|
} |
|
""" |
|
|
|
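        # Detokenize each decoding channel separately, then mix the per-channel notes into a single note list.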
num_channels = self.task_manager.num_decoding_channels |
|
pred_notes_in_file = [] |
|
n_err_cnt = Counter() |
|
for ch in range(num_channels): |
|
pred_token_array_ch = [arr[:, ch, :] for arr in pred_token_array_file] |
|
zipped_note_events_and_tie, list_events, ne_err_cnt = self.task_manager.detokenize_list_batches( |
|
pred_token_array_ch, start_secs_file, return_events=True) |
|
pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie) |
|
pred_notes_in_file.append(pred_notes_ch) |
|
n_err_cnt += n_err_cnt_ch |
|
pred_notes = mix_notes(pred_notes_in_file) |
|
|
|
if self.hparams.write_output_dir is not None: |
|
track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")][0] |
|
dataset_info = [k for k in notes_dict.keys() if k.endswith('_id')][0][:-3] |
|
|
|
|
|
write_model_output_as_midi(pred_notes, |
|
self.hparams.write_output_dir, |
|
track_info, |
|
self.midi_output_inverse_vocab, |
|
output_dir_suffix=str(dataset_info) + '_' + |
|
str(self.hparams.eval_subtask_key)) |
|
|
|
|
|
|
|
|
|
|
|
drum_metric, non_drum_metric, instr_metric = compute_track_metrics( |
|
pred_notes, |
|
notes_dict['notes'], |
|
eval_vocab=self.hparams.eval_vocab[dataloader_idx], |
|
eval_drum_vocab=self.hparams.eval_drum_vocab, |
|
onset_tolerance=self.hparams.onset_tolerance, |
|
add_pitch_class_metric=self.hparams.add_pitch_class_metric) |
|
self.val_metrics[dataloader_idx].bulk_update(drum_metric) |
|
self.val_metrics[dataloader_idx].bulk_update(non_drum_metric) |
|
self.val_metrics[dataloader_idx].bulk_update(instr_metric) |
|
self.val_metrics_macro.bulk_update(drum_metric) |
|
self.val_metrics_macro.bulk_update(non_drum_metric) |
|
self.val_metrics_macro.bulk_update(instr_metric) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
decoding_time_sec = t.elapsed_time() |
|
self.log('val_loss', loss, prog_bar=True, batch_size=n_items, sync_dist=True) |
|
|
|
|
|
def on_validation_epoch_end(self) -> None: |
|
for val_metrics in self.val_metrics: |
|
self.log_dict(val_metrics.bulk_compute(), sync_dist=True) |
|
val_metrics.bulk_reset() |
|
self.log_dict(self.val_metrics_macro.bulk_compute(), sync_dist=True) |
|
self.val_metrics_macro.bulk_reset() |
|
|
|
    def test_step(self, batch, batch_idx, dataloader_idx=0) -> None:
|
|
|
if self.task_manager.num_decoding_channels == 1: |
|
bsz = self.shared_cfg["BSZ"]["validation"] |
|
else: |
|
bsz = self.shared_cfg["BSZ"]["validation"] // self.task_manager.num_decoding_channels * 3 |
|
|
|
audio_segments, notes_dict, note_token_array = batch |
|
task_token_array = None |
|
|
|
|
|
if self.test_pitch_shift_layer is not None and self.test_pitch_shift_semitone != 0: |
|
for n in notes_dict['notes']: |
|
                if not n.is_drum:
                    n.pitch += self.test_pitch_shift_semitone
|
|
|
|
|
n_items = audio_segments.shape[0] |
|
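        # Segment start times in seconds, assuming non-overlapping 32767-sample segments at 16 kHz.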
start_secs_file = [32767 * i / 16000 for i in range(n_items)] |
|
|
|
if self.test_pitch_shift_layer is not None and self.hparams.write_output_dir is not None: |
|
pred_token_array_file, loss, x_ps = self.inference_file(bsz, audio_segments, None, None) |
|
else: |
|
pred_token_array_file, loss = self.inference_file(bsz, audio_segments, None, None) |
|
if len(pred_token_array_file) > 0: |
|
|
|
|
|
num_channels = self.task_manager.num_decoding_channels |
|
pred_notes_in_file = [] |
|
n_err_cnt = Counter() |
|
for ch in range(num_channels): |
|
pred_token_array_ch = [arr[:, ch, :] for arr in pred_token_array_file] |
|
zipped_note_events_and_tie, list_events, ne_err_cnt = self.task_manager.detokenize_list_batches( |
|
pred_token_array_ch, start_secs_file, return_events=True) |
|
pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie) |
|
pred_notes_in_file.append(pred_notes_ch) |
|
n_err_cnt += n_err_cnt_ch |
|
pred_notes = mix_notes(pred_notes_in_file) |
|
|
|
            if self.test_pitch_shift_layer is not None and self.hparams.write_output_dir is not None:
                # Save the pitch-shifted input audio; derive track/dataset ids the same way as in the MIDI-writing block below.
                track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")][0]
                dataset_info = [k for k in notes_dict.keys() if k.endswith('_id')][0][:-3]
                wav_output_dir = os.path.join(self.hparams.write_output_dir, f"model_output_{dataset_info}")
                os.makedirs(wav_output_dir, exist_ok=True)
                wav_output_file = os.path.join(wav_output_dir, f"{track_info}_ps_{self.test_pitch_shift_semitone}.wav")
                torchaudio.save(wav_output_file, x_ps.squeeze(1), 16000, bits_per_sample=16)
|
|
|
drum_metric, non_drum_metric, instr_metric = compute_track_metrics( |
|
pred_notes, |
|
notes_dict['notes'], |
|
eval_vocab=self.hparams.eval_vocab[dataloader_idx], |
|
eval_drum_vocab=self.hparams.eval_drum_vocab, |
|
onset_tolerance=self.hparams.onset_tolerance, |
|
add_pitch_class_metric=self.hparams.add_pitch_class_metric, |
|
add_melody_metric=['Singing Voice'] if self.hparams.add_melody_metric_to_singing else None, |
|
add_frame_metric=True, |
|
add_micro_metric=True, |
|
add_multi_f_metric=True) |
|
|
|
if self.hparams.write_output_dir is not None and self.global_rank == 0: |
|
|
|
track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")][0] |
|
dataset_info = [k for k in notes_dict.keys() if k.endswith('_id')][0][:-3] |
|
f_score = f"OnF{non_drum_metric['onset_f']:.2f}_MulF{instr_metric['multi_f']:.2f}" |
|
write_model_output_as_midi(pred_notes, |
|
self.hparams.write_output_dir, |
|
track_info, |
|
self.midi_output_inverse_vocab, |
|
output_dir_suffix=str(dataset_info) + '_' + |
|
str(self.hparams.eval_subtask_key) + '_' + f_score) |
|
write_err_cnt_as_json(track_info, self.hparams.write_output_dir, |
|
str(dataset_info) + '_' + str(self.hparams.eval_subtask_key) + '_' + f_score, |
|
n_err_cnt, ne_err_cnt) |
|
|
|
|
|
if self.hparams.test_optimal_octave_shift: |
|
track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")][0] |
|
dataset_info = [k for k in notes_dict.keys() if k.endswith('_id')][0][:-3] |
|
score = [instr_metric['onset_f_Bass']] |
|
ref_notes_plus = [] |
|
ref_notes_minus = [] |
|
for note in notes_dict['notes']: |
|
                    if note.is_drum:
|
ref_notes_plus.append(note) |
|
ref_notes_minus.append(note) |
|
else: |
|
ref_notes_plus.append( |
|
Note(is_drum=note.is_drum, |
|
program=note.program, |
|
onset=note.onset, |
|
offset=note.offset, |
|
pitch=note.pitch + 12, |
|
velocity=note.velocity)) |
|
ref_notes_minus.append( |
|
Note(is_drum=note.is_drum, |
|
program=note.program, |
|
onset=note.onset, |
|
offset=note.offset, |
|
pitch=note.pitch - 12, |
|
velocity=note.velocity)) |
|
|
|
drum_metric_plus, non_drum_metric_plus, instr_metric_plus = compute_track_metrics( |
|
pred_notes, |
|
ref_notes_plus, |
|
eval_vocab=self.hparams.eval_vocab[dataloader_idx], |
|
eval_drum_vocab=self.hparams.eval_drum_vocab, |
|
onset_tolerance=self.hparams.onset_tolerance, |
|
add_pitch_class_metric=self.hparams.add_pitch_class_metric) |
|
drum_metric_minus, non_drum_metric_minus, instr_metric_minus = compute_track_metrics( |
|
                    pred_notes,
                    ref_notes_minus,
|
eval_vocab=self.hparams.eval_vocab[dataloader_idx], |
|
eval_drum_vocab=self.hparams.eval_drum_vocab, |
|
onset_tolerance=self.hparams.onset_tolerance, |
|
add_pitch_class_metric=self.hparams.add_pitch_class_metric) |
|
|
|
score.append(instr_metric_plus['onset_f_Bass']) |
|
score.append(instr_metric_minus['onset_f_Bass']) |
|
max_index = score.index(max(score)) |
|
if max_index == 0: |
|
print(f"ZERO: {track_info}, z/p/m: {score[0]:.2f}/{score[1]:.2f}/{score[2]:.2f}") |
|
elif max_index == 1: |
|
|
|
instr_metric['onset_f_Bass'] = instr_metric_plus['onset_f_Bass'] |
|
print(f"PLUS: {track_info}, z/p/m: {score[0]:.2f}/{score[1]:.2f}/{score[2]:.2f}") |
|
write_model_output_as_midi(ref_notes_plus, |
|
self.hparams.write_output_dir, |
|
track_info + '_ref_octave_plus', |
|
self.midi_output_inverse_vocab, |
|
output_dir_suffix=str(dataset_info) + '_' + |
|
str(self.hparams.eval_subtask_key)) |
|
else: |
|
|
|
instr_metric['onset_f_Bass'] = instr_metric_minus['onset_f_Bass'] |
|
print(f"MINUS: {track_info}, z/p/m: {score[0]:.2f}/{score[1]:.2f}/{score[2]:.2f}") |
|
write_model_output_as_midi(ref_notes_minus, |
|
self.hparams.write_output_dir, |
|
track_info + '_ref_octave_minus', |
|
                                               self.midi_output_inverse_vocab,
|
output_dir_suffix=str(dataset_info) + '_' + |
|
str(self.hparams.eval_subtask_key)) |
|
|
|
self.test_metrics[dataloader_idx].bulk_update(drum_metric) |
|
self.test_metrics[dataloader_idx].bulk_update(non_drum_metric) |
|
self.test_metrics[dataloader_idx].bulk_update(instr_metric) |
|
|
|
|
|
|
|
|
|
def on_test_epoch_end(self) -> None: |
|
|
|
for test_metrics in self.test_metrics: |
|
self.log_dict(test_metrics.bulk_compute(), sync_dist=True) |
|
test_metrics.bulk_reset() |
|
|
|
|
|
|
|
|
|
def test_case_forward_mt3(): |
|
import torch |
|
from config.config import audio_cfg, model_cfg, shared_cfg |
|
from model.ymt3 import YourMT3 |
|
model = YourMT3() |
|
model.eval() |
|
x = torch.randn(2, 1, 32767) |
|
labels = torch.randint(0, 596, (2, 1, 1024), requires_grad=False) |
|
    task_tokens = torch.LongTensor([])
    # forward() ignores task_tokens in this version; pass it as a keyword to avoid a positional-argument error.
    output = model.forward(x, labels, task_tokens=task_tokens)
|
logits, loss = output['logits'], output['loss'] |
|
assert logits.shape == (2, 1024, 596) |
|
|
|
|
|
def test_case_inference_mt3(): |
|
import torch |
|
from config.config import audio_cfg, model_cfg, shared_cfg |
|
from model.ymt3 import YourMT3 |
|
model_cfg["num_max_positions"] = 1024 + 3 + 1 |
|
model = YourMT3(model_cfg=model_cfg) |
|
model.eval() |
|
x = torch.randn(2, 1, 32767) |
|
task_tokens = torch.randint(0, 596, (2, 3), requires_grad=False) |
|
pred_ids = model.inference(x, task_tokens, max_token_length=10) |
|
|
|
|
|
|
|
def test_case_forward_enc_perceiver_tf_dec_t5(): |
|
import torch |
|
from model.ymt3 import YourMT3 |
|
from config.config import audio_cfg, model_cfg, shared_cfg |
|
model_cfg["encoder_type"] = "perceiver-tf" |
|
audio_cfg["codec"] = "spec" |
|
audio_cfg["hop_length"] = 300 |
|
|
|
model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg) |
|
model.eval() |
|
|
|
x = torch.randn(2, 1, 32767) |
|
labels = torch.randint(0, 596, (2, 1, 1024), requires_grad=False) |
|
|
|
|
|
output = model.forward(x, labels) |
|
logits, loss = output['logits'], output['loss'] |
|
|
|
|
|
pred_ids = model.inference(x, None, max_token_length=3) |
|
|
|
|
|
def test_case_forward_enc_conformer_dec_t5(): |
|
import torch |
|
from model.ymt3 import YourMT3 |
|
from config.config import audio_cfg, model_cfg, shared_cfg |
|
model_cfg["encoder_type"] = "conformer" |
|
audio_cfg["codec"] = "melspec" |
|
audio_cfg["hop_length"] = 128 |
|
model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg) |
|
model.eval() |
|
|
|
x = torch.randn(2, 1, 32767) |
|
labels = torch.randint(0, 596, (2, 1024), requires_grad=False) |
|
|
|
|
|
output = model.forward(x, labels) |
|
logits, loss = output['logits'], output['loss'] |
|
|
|
|
|
pred_ids = model.inference(x, None, 20) |
|
|
|
|
|
def test_case_enc_perceiver_tf_dec_multi_t5(): |
|
import torch |
|
from model.ymt3 import YourMT3 |
|
from config.config import audio_cfg, model_cfg, shared_cfg |
|
model_cfg["encoder_type"] = "perceiver-tf" |
|
model_cfg["decoder_type"] = "multi-t5" |
|
model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = True |
|
model_cfg["encoder"]["perceiver-tf"]["num_latents"] = 26 |
|
audio_cfg["codec"] = "spec" |
|
audio_cfg["hop_length"] = 300 |
|
model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg) |
|
model.eval() |
|
|
|
x = torch.randn(2, 1, 32767) |
|
labels = torch.randint(0, 596, (2, 13, 200), requires_grad=False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output = model.forward(x, labels) |
|
logits, loss = output['logits'], output['loss'] |
|
|
|
|
|
model.max_total_token_length = 123 |
|
pred_ids = model.inference(x, None) |
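

# A minimal smoke-test entry point (a sketch, not part of the original training pipeline).
# It simply runs the teacher-forced forward test defined above and assumes the default
# configs in config/config.py are importable from the current working directory.
if __name__ == "__main__":
    test_case_forward_mt3()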
|
|