Spaces:
Build error
Build error
# Copyright 2024 The YourMT3 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Please see the details in the LICENSE file. | |
from typing import Dict, Tuple | |
from copy import deepcopy | |
from collections import Counter | |
import numpy as np | |
import torch | |
from utils.data_modules import AMTDataModule | |
from utils.task_manager import TaskManager | |
from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg | |
from utils.augment import intra_stem_augment_processor | |
def get_ds(data_preset_multi: Dict, task_name: str, train_num_samples_per_epoch: int = 90000):
    """Return the dataset behind the first training dataloader for ``task_name``.

    Builds a ``TaskManager`` and an ``AMTDataModule`` from ``data_preset_multi``,
    calls ``setup('fit')``, and unwraps the first loader of the (flattened)
    training dataloader.

    Args:
        data_preset_multi: Multi-dataset preset (e.g. an entry of
            ``data_preset_multi_cfg``).
        task_name: Name passed to ``TaskManager``.
        train_num_samples_per_epoch: Forwarded to ``AMTDataModule``.

    Returns:
        The ``.dataset`` attribute of the first flattened training dataloader.
    """
    manager = TaskManager(task_name=task_name)
    # Set to 1024 only to check the max length (per the original note).
    manager.max_note_token_length_per_ch = 1024
    data_module = AMTDataModule(
        data_preset_multi=data_preset_multi,
        task_manager=manager,
        train_num_samples_per_epoch=train_num_samples_per_epoch,
    )
    data_module.setup('fit')
    first_loader = data_module.train_dataloader().flattened[0]
    return first_loader.dataset
# Choose the task and multi-dataset preset, then build the training dataset.
task_name = "mc13"  # "mt3_full_plus"
data_preset_multi = data_preset_multi_cfg["all_cross_v6"]
ds = get_ds(data_preset_multi, task_name=task_name)

# Override augmentation settings directly on the dataset instance before sampling.
ds.random_amp_range = [0.8, 1.1]
ds.stem_xaug_policy = dict(
    max_k=5,
    tau=0.3,
    alpha=1.0,
    max_subunit_stems=12,
    no_instr_overlap=True,
    no_drum_overlap=True,
    uhat_intra_stem_augment=True,
)
# Collect the number of non-zero tokens in every channel sequence of every
# dataset item, flattened into one 1-D array (item-major order).
# NOTE(review): assumes token id 0 is padding and that note_token_arr is
# (batch, channel, token) so that dim=2 is the token axis — TODO confirm.
num_check_samples = 40000  # number of dataset items to inspect
progress_every = 5000      # print a progress index every N items
length_all = []
for i in range(num_check_samples):
    if i % progress_every == 0:
        print(i)
    # Only the note tokens are needed; audio, task tokens and pitch-shift steps are ignored.
    _, note_token_arr, _, _ = ds[i]
    lengths = torch.count_nonzero(note_token_arr, dim=2).flatten().cpu().tolist()
    length_all.extend(lengths)
# Explicit dtype so an empty collection is still an integer array.
length_all = np.asarray(length_all, dtype=np.int64)
# stats
# Per-channel statistics over every collected channel sequence.
# A channel with fewer than 3 non-zero tokens is counted as "empty".
# NOTE(review): presumably the <=2 remaining tokens are special tokens
# (e.g. EOS) — confirm against the tokenizer.
def _percent_of(mask: np.ndarray) -> float:
    """Return the share of True values in ``mask`` in percent (0.0 for an empty mask)."""
    return np.count_nonzero(mask) / mask.size * 100 if mask.size else 0.0


non_empty_lengths = length_all[length_all > 2]
print("empty_sequences:", f"{_percent_of(length_all < 3):.2f}", "%")
# np.mean/np.median of an empty array warn and return nan; return nan without the warning.
mean_except_empty = np.mean(non_empty_lengths) if non_empty_lengths.size else float("nan")
print("mean_except_empty:", mean_except_empty)
median_except_empty = np.median(non_empty_lengths) if non_empty_lengths.size else float("nan")
print("median_except_empty:", median_except_empty)
print("ch_less_than_768:", f"{_percent_of(length_all < 768):.2f}", "%")
for threshold in (512, 256, 128, 64):
    print(f"ch_larger_than_{threshold}:", f"{_percent_of(length_all > threshold):.6f}", "%")
# Per-item ("song") statistics: an item counts toward a threshold if ANY of
# its channels exceeds it.
# NOTE(review): each row is one group of `num_channels` consecutive lengths,
# presumably one training segment rather than a whole song — TODO confirm.
# The "song_" labels are kept so the output matches earlier recorded results.
num_channels = 13  # assumes the mc13 layout; must match note_token_arr.shape[1] — TODO confirm
song_length_all = length_all.reshape(-1, num_channels)
num_songs = len(song_length_all)
for threshold in (512, 256, 128, 64):
    num_over = int(np.any(song_length_all > threshold, axis=1).sum())
    # Guard against ZeroDivisionError when nothing was collected.
    share = num_over / num_songs * 100 if num_songs else 0.0
    print(f"song_larger_than_{threshold}:", f"{share:.4f}", "%")
# Channel index -> instrument-group name used in the per-instrument report
# (13 entries, indexed by channel position).
instr_dict = dict(enumerate([
    "Piano",
    "Chromatic Percussion",
    "Organ",
    "Guitar",
    "Bass",
    "Strings + Ensemble",
    "Brass",
    "Reed",
    "Pipe",
    "Synth Lead",
    "Synth Pad",
    "Singing",
    "Drums",
]))
# Per-instrument breakdown: for each threshold, count the channel sequences
# longer than it, grouped by instrument. length_all holds one entry per
# channel, item after item, so an entry's channel index is its position
# modulo the channel count (taken from instr_dict instead of a literal 13).
# Instruments are printed in channel order, so the report is stable across runs.
_num_instr = len(instr_dict)
for threshold in (512, 256, 128):
    channel_idx = np.flatnonzero(length_all > threshold) % _num_instr
    counts = Counter(channel_idx.tolist())
    print(f"larger_than_{threshold}:")
    for k in sorted(counts):
        print(f" - {instr_dict[k]}: {counts[k]}")
""" | |
empty_sequences: 91.06 % | |
mean_except_empty: 36.68976799156269 | |
median_except_empty: 31.0 | |
ch_less_than_768: 100.00 % | |
ch_larger_than_512: 0.000158 % | |
ch_larger_than_256: 0.015132 % | |
ch_larger_than_128: 0.192061 % | |
ch_larger_than_64: 0.661260 % | |
song_larger_than_512: 0.0021 % | |
song_larger_than_256: 0.1926 % | |
song_larger_than_128: 2.2280 % | |
song_larger_than_64: 6.1033 % | |
larger_than_512: | |
- Guitar: 7 | |
- Strings + Ensemble: 3 | |
larger_than_256: | |
- Piano: 177 | |
- Guitar: 680 | |
- Strings + Ensemble: 79 | |
- Organ: 2 | |
- Chromatic Percussion: 11 | |
- Bass: 1 | |
- Synth Lead: 2 | |
- Brass: 1 | |
- Reed: 5 | |
larger_than_128: | |
- Guitar: 4711 | |
- Strings + Ensemble: 1280 | |
- Piano: 5548 | |
- Bass: 211 | |
- Synth Pad: 22 | |
- Pipe: 18 | |
- Chromatic Percussion: 55 | |
- Synth Lead: 22 | |
- Organ: 75 | |
- Reed: 161 | |
- Brass: 45 | |
- Drums: 11 | |
""" | |