File size: 14,989 Bytes
a03c9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
""" test.py """
import os
import pprint
import argparse
import torch

from utils.data_modules import AMTDataModule
from utils.task_manager import TaskManager
from model.init_train import initialize_trainer, update_config
from model.ymt3 import YourMT3
from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg
from config.vocabulary import drum_vocab_presets
from utils.utils import str2bool

# yapf: disable
parser = argparse.ArgumentParser(description="YourMT3")
# General
parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
parser.add_argument('-p', '--project', type=str, default='ymt3', help='project name')
parser.add_argument('-d', '--data-preset', type=str, default='musicnet_thickstun_ext_em', help='dataset preset (default=musicnet_thickstun_ext_em). See config/data.py for more options.')
# Audio configurations
parser.add_argument('-ac', '--audio-codec', type=str, default=None, help='audio codec (default=None). {"spec", "melspec"}. If None, default value defined in config.py will be used.')
parser.add_argument('-hop', '--hop-length', type=int, default=None, help='hop length in frames (default=None). {128, 300} 128 for MT3, 300 for PerceiverTFIf None, default value defined in config.py will be used.')
parser.add_argument('-nmel', '--n-mels', type=int, default=None, help='number of mel bins (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-if', '--input-frames', type=int, default=None, help='number of audio frames for input segment (default=None). If None, default value defined in config.py will be used.')
# Model configurations
parser.add_argument('-sqr', '--sca-use-query-residual', type=str2bool, default=None, help='sca use query residual flag. Default follows config.py')
parser.add_argument('-enc', '--encoder-type', type=str, default=None, help="Encoder type. 't5' or 'perceiver-tf' or 'conformer'. Default is 't5', following config.py.")
parser.add_argument('-dec', '--decoder-type', type=str, default=None, help="Decoder type. 't5' or 'multi-t5'. Default is 't5', following config.py.")
parser.add_argument('-preenc', '--pre-encoder-type', type=str, default='default', help="Pre-encoder type. None or 'conv' or 'default'. By default, t5_enc:None, perceiver_tf_enc:conv, conformer:None")
parser.add_argument('-predec', '--pre-decoder-type', type=str, default='default', help="Pre-decoder type. {None, 'linear', 'conv1', 'mlp', 'group_linear'} or 'default'. Default is {'t5': None, 'perceiver-tf': 'linear', 'conformer': None}.")
parser.add_argument('-cout', '--conv-out-channels', type=int, default=None, help='Number of filters for pre-encoder conv layer. Default follows "model_cfg" of config.py.')
parser.add_argument('-tenc', '--task-cond-encoder', type=str2bool, default=True, help='task conditional encoder (default=True). True or False')
parser.add_argument('-tdec', '--task-cond-decoder', type=str2bool, default=True, help='task conditional decoder (default=True). True or False')
parser.add_argument('-df', '--d-feat', type=int, default=None, help='Audio feature will be projected to this dimension for Q,K,V of T5 or K,V of Perceiver (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-pt', '--pretrained', type=str2bool, default=False, help='pretrained T5(default=False). True or False')
parser.add_argument('-b', '--base-name', type=str, default="google/t5-v1_1-small", help='base model name (default="google/t5-v1_1-small")')
parser.add_argument('-epe', '--encoder-position-encoding-type', type=str, default='default', help="Positional encoding type of encoder. By default, pre-defined PE for T5 or Perceiver-TF encoder in config.py. For T5: {'sinusoidal', 'trainable'}, conformer: {'rotary', 'trainable'}, Perceiver-TF: {'trainable', 'rope', 'alibi', 'alibit', 'None', '0', 'none', 'tkd', 'td', 'tk', 'kdt'}.")
parser.add_argument('-dpe', '--decoder-position-encoding-type', type=str, default='default', help="Positional encoding type of decoder. By default, pre-defined PE for T5 in config.py. {'sinusoidal', 'trainable'}.")
parser.add_argument('-twe', '--tie-word-embedding', type=str2bool, default=None, help='tie word embedding (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-el', '--event-length', type=int, default=None, help='event length (default=None). If None, default value defined in model cfg of config.py will be used.')
# Perceiver-TF configurations
parser.add_argument('-dl', '--d-latent', type=int, default=None, help='Latent dimension of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-nl', '--num-latents', type=int, default=None, help='Number of latents of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-dpm', '--perceiver-tf-d-model', type=int, default=None, help='Perceiver-TF d_model (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-npb', '--num-perceiver-tf-blocks', type=int, default=None, help='Number of blocks of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py.')
parser.add_argument('-npl', '--num-perceiver-tf-local-transformers-per-block', type=int, default=None, help='Number of local layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-npt', '--num-perceiver-tf-temporal-transformers-per-block', type=int, default=None, help='Number of temporal layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-atc', '--attention-to-channel', type=str2bool, default=None, help='Attention to channel flag of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-ln', '--layer-norm-type', type=str, default=None, help='Layer normalization type (default=None). {"layer_norm", "rms_norm"}. If None, default value defined in config.py will be used.')
parser.add_argument('-ff', '--ff-layer-type', type=str, default=None, help='Feed forward layer type (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
parser.add_argument('-wf', '--ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-nmoe', '--moe-num-experts', type=int, default=None, help='Number of experts for MoE (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-kmoe', '--moe-topk', type=int, default=None, help='Top-k for MoE (default=None). If None, default value defined in config.py will be used.')
parser.add_argument('-act', '--hidden-act', type=str, default=None, help='Hidden activation function (default=None). {"gelu", "silu", "relu", "tanh"}. If None, default value defined in config.py will be used.')
parser.add_argument('-rt', '--rotary-type', type=str, default=None, help='Rotary embedding type expressed in three letters. e.g. ppl: "pixel" for SCA and latents, "lang" for temporal transformer. If None, use config.')
parser.add_argument('-rk', '--rope-apply-to-keys', type=str2bool, default=None, help='Apply rope to keys (default=None). If None, use config.')
parser.add_argument('-rp', '--rope-partial-pe', type=str2bool, default=None, help='Whether to apply RoPE to partial positions (default=None). If None, use config.')
# Decoder configurations
parser.add_argument('-dff', '--decoder-ff-layer-type', type=str, default=None, help='Feed forward layer type of decoder (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
parser.add_argument('-dwf', '--decoder-ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for decoder MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
# Task and Evaluation configurations
parser.add_argument('-tk', '--task', type=str, default='mt3_full_plus', help='tokenizer type (default=mt3_full_plus). See config/task.py for more options.')
parser.add_argument('-epv', '--eval-program-vocab', type=str, default=None, help='evaluation vocabulary (default=None). If None, default vocabulary of the data preset will be used.')
parser.add_argument('-edv', '--eval-drum-vocab', type=str, default=None, help='evaluation vocabulary for drum (default=None). If None, default vocabulary of the data preset will be used.')
parser.add_argument('-etk', '--eval-subtask-key', type=str, default='default', help='evaluation subtask key (default=default). See config/task.py for more options.')
parser.add_argument('-t', '--onset-tolerance', type=float, default=0.05, help='onset tolerance (default=0.05).')
parser.add_argument('-os', '--test-octave-shift', type=str2bool, default=False, help='test optimal octave shift (default=False). True or False')
parser.add_argument('-w', '--write-model-output', type=str2bool, default=False, help='write model test output to file (default=False). True or False')
# Trainer configurations
parser.add_argument('-pr','--precision', type=str, default="bf16-mixed", help='precision (default="bf16-mixed") {32, 16, bf16, bf16-mixed}')
parser.add_argument('-st', '--strategy', type=str, default='auto', help='strategy (default=auto). auto or deepspeed or ddp')
parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes (default=1)')
parser.add_argument('-g', '--num-gpus', type=str, default='auto', help='number of gpus (default="auto")')
parser.add_argument('-wb', '--wandb-mode', type=str, default=None, help='wandb mode for logging (default=None). "disabled" or "online" or "offline". If None, default value defined in config.py will be used.')
# Debug
parser.add_argument('-debug', '--debug-mode', type=str2bool, default=False, help='debug mode (default=False). True or False')
parser.add_argument('-tps', '--test-pitch-shift', type=int, default=None, help='use pitch shift when testing. debug-purpose only. (default=None). semitone in int.')
args = parser.parse_args()
# yapf: enable
if torch.__version__ >= "1.13":
    torch.set_float32_matmul_precision("high")
args.epochs = None

# Initialize trainer
trainer, wandb_logger, dir_info, shared_cfg = initialize_trainer(args, stage='test')

# Update config with args, including augmentation settings
shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='test')


def main():
    # Data preset
    if args.data_preset in data_preset_single_cfg:
        # convert single preset into multi preset format
        data_preset = {
            "presets": [args.data_preset],
            "eval_vocab": data_preset_single_cfg[args.data_preset]["eval_vocab"],
        }
        for k in data_preset_single_cfg[args.data_preset].keys():
            if k in ["eval_drum_vocab", "add_pitch_class_metric"]:
                data_preset[k] = data_preset_single_cfg[args.data_preset][k]
    elif args.data_preset in data_preset_multi_cfg:
        data_preset = data_preset_multi_cfg[args.data_preset]
    else:
        raise ValueError("Invalid data preset")
    eval_drum_vocab = data_preset.get("eval_drum_vocab", None)

    if args.eval_drum_vocab != None:  # override eval_drum_vocab
        eval_drum_vocab = drum_vocab_presets[args.eval_drum_vocab]

    # Task manager
    tm = TaskManager(task_name=args.task,
                     max_shift_steps=int(shared_cfg["TOKENIZER"]["max_shift_steps"]),
                     debug_mode=args.debug_mode)
    print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")

    results = []
    for i, preset in enumerate(data_preset["presets"]):
        # sdp: unpacking multi preset as a list of single presets
        sdp = {
            "presets": [preset],
            "eval_vocab": [data_preset["eval_vocab"][i]],
            "eval_drum_vocab": eval_drum_vocab,
        }
        for k in data_preset.keys():
            if k not in ["presets", "eval_vocab"]:
                sdp[k] = data_preset[k]

        dm = AMTDataModule(data_preset_multi=sdp, task_manager=tm, audio_cfg=audio_cfg)

        model = YourMT3(
            audio_cfg=audio_cfg,
            model_cfg=model_cfg,
            shared_cfg=shared_cfg,
            optimizer=None,
            task_manager=tm,  # tokenizer is a member of task_manager
            eval_subtask_key=args.eval_subtask_key,
            eval_vocab=args.eval_program_vocab if args.eval_program_vocab != None else sdp["eval_vocab"],
            eval_drum_vocab=sdp["eval_drum_vocab"],
            write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None,
            onset_tolerance=float(args.onset_tolerance),
            add_pitch_class_metric=sdp.get("add_pitch_class_metric", None),
            test_optimal_octave_shift=args.test_octave_shift,
            test_pitch_shift_layer=args.test_pitch_shift)

        # load checkpoint & drop pitchshift from state_dict
        checkpoint = torch.load(dir_info["last_ckpt_path"])
        state_dict = checkpoint['state_dict']
        new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
        model.load_state_dict(new_state_dict, strict=False)
        # if args.test_pitch_shift is None:
        #     new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
        #     model.load_state_dict(new_state_dict, strict=False)
        # else:
        #     model.load_state_dict(state_dict, strict=False)

        results.append("-----------------------------------------------------------------")
        results.append(sdp)
        results.append(trainer.test(model, datamodule=dm))
        # TODO: directly load checkpoint including hyperparmeters https://lightning.ai/docs/pytorch/1.6.2/common/hyperparameters.html

    # save result
    pp = pprint.PrettyPrinter(indent=4)
    results_str = pp.pformat(results)
    result_file = os.path.join(dir_info["lightning_dir"],
                               f"result_{args.task}_{args.eval_subtask_key}_{args.data_preset}.json")
    with open(result_file, 'w') as f:
        f.write(results_str)
    print(f"Result is saved to {result_file}")


if __name__ == "__main__":
    main()