|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Tuple |
|
from copy import deepcopy |
|
from collections import Counter |
|
import numpy as np |
|
import torch |
|
from utils.data_modules import AMTDataModule |
|
from utils.task_manager import TaskManager |
|
from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg |
|
from utils.augment import intra_stem_augment_processor |
|
|
|
|
|
def get_ds(data_preset_multi: Dict,
           task_name: str,
           train_num_samples_per_epoch: int = 90000,
           max_note_token_length_per_ch: int = 1024):
    """Build the training dataset for a multi-dataset preset.

    Creates a ``TaskManager`` for ``task_name``, overrides its per-channel
    note-token limit, sets up an ``AMTDataModule`` for the ``'fit'`` stage and
    returns the dataset that backs the first entry of the training
    dataloader's ``flattened`` collection.

    Args:
        data_preset_multi: Preset dictionary forwarded to ``AMTDataModule``.
        task_name: Task identifier forwarded to ``TaskManager``.
        train_num_samples_per_epoch: Forwarded unchanged to ``AMTDataModule``.
        max_note_token_length_per_ch: Value assigned to the task manager's
            ``max_note_token_length_per_ch`` attribute. Defaults to 1024.

    Returns:
        The ``dataset`` attribute of ``train_dataloader().flattened[0]``.
    """
    tm = TaskManager(task_name=task_name)
    # NOTE(review): presumably raised so that long sequences are not cut off
    # before their length is measured -- confirm against TaskManager.
    tm.max_note_token_length_per_ch = max_note_token_length_per_ch

    dm = AMTDataModule(
        data_preset_multi=data_preset_multi,
        task_manager=tm,
        train_num_samples_per_epoch=train_num_samples_per_epoch,
    )
    dm.setup('fit')

    # Only the first flattened loader's dataset is inspected.
    return dm.train_dataloader().flattened[0].dataset
|
|
|
|
|
# Dataset under analysis: preset "all_cross_v6" with task "mc13".
task_name = "mc13"
data_preset_multi = data_preset_multi_cfg["all_cross_v6"]
ds = get_ds(data_preset_multi, task_name=task_name)

# Override the dataset's amplitude range and stem cross-augmentation policy
# before any item is drawn from it.
ds.random_amp_range = [0.8, 1.1]
ds.stem_xaug_policy = dict(
    max_k=5,
    tau=0.3,
    alpha=1.0,
    max_subunit_stems=12,
    no_instr_overlap=True,
    no_drum_overlap=True,
    uhat_intra_stem_augment=True,
)
|
|
|
# Draw the first 40000 items and record, for each of them, the number of
# non-zero entries along the last axis of the note-token tensor.
# NOTE(review): presumably 0 is the padding token, which would make these
# counts the token-sequence lengths -- confirm against the tokenizer.
length_all = []
for idx in range(40000):
    # Lightweight progress indicator every 5000 items.
    if not idx % 5000:
        print(idx)
    # Only the second element (note tokens) is needed; the rest is ignored.
    _, note_tokens, _, _ = ds[idx]
    nonzero_counts = (note_tokens != 0).sum(dim=2)
    length_all += nonzero_counts.flatten().cpu().tolist()

# Convert to an array so the statistics can use boolean masks.
length_all = np.array(length_all)
|
|
|
|
|
def _percent(mask):
    """Return the share of True entries in the 1-D boolean ``mask``, in %."""
    return mask.sum() / len(mask) * 100


# Sequences with fewer than 3 non-zero entries are reported as "empty".
empty_sequence = _percent(length_all < 3)
print("empty_sequences:", f"{empty_sequence:.2f}", "%")

# Mean and median over the remaining (non-"empty") sequences only.
non_empty_lengths = length_all[length_all > 2]

mean_except_empty = non_empty_lengths.mean()
print("mean_except_empty:", mean_except_empty)

median_except_empty = np.median(non_empty_lengths)
print("median_except_empty:", median_except_empty)

ch_less_than_768 = _percent(length_all < 768)
print("ch_less_than_768:", f"{ch_less_than_768:.2f}", "%")

# Share of all sequences (empty ones included) exceeding each threshold.
for threshold in (512, 256, 128, 64):
    share = _percent(length_all > threshold)
    print(f"ch_larger_than_{threshold}:", f"{share:.6f}", "%")
|
|
|
# Regroup the flat lengths into rows of 13 consecutive values; each row is
# what the printed labels call a "song".
# NOTE(review): 13 presumably is the channel count of the "mc13" task (it
# also matches the number of instrument classes) -- confirm.
song_length_all = length_all.reshape(-1, 13)
num_songs = len(song_length_all)

# A row counts towards a threshold when at least one of its values exceeds it.
for threshold in (512, 256, 128, 64):
    exceeds_in_row = (song_length_all > threshold).any(axis=1)
    hits = int(np.count_nonzero(exceeds_in_row))
    print(f"song_larger_than_{threshold}:", f"{hits/num_songs*100:.4f}", "%")
|
|
|
# Display name for each position within a group of 13 consecutive lengths.
instr_dict = dict(enumerate((
    "Piano",
    "Chromatic Percussion",
    "Organ",
    "Guitar",
    "Bass",
    "Strings + Ensemble",
    "Brass",
    "Reed",
    "Pipe",
    "Synth Lead",
    "Synth Pad",
    "Singing",
    "Drums",
)))

# For each threshold, count how often every position (flat index modulo 13)
# exceeds it. Entries are printed in order of first occurrence, not sorted.
for threshold in (512, 256, 128):
    positions = np.flatnonzero(length_all > threshold) % 13
    per_instrument = Counter(positions)
    print(f"larger_than_{threshold}:")
    for position, count in per_instrument.items():
        print(f" - {instr_dict[position]}: {count}")
|
""" |
|
empty_sequences: 91.06 % |
|
mean_except_empty: 36.68976799156269 |
|
median_except_empty: 31.0 |
|
ch_less_than_768: 100.00 % |
|
ch_larger_than_512: 0.000158 % |
|
ch_larger_than_256: 0.015132 % |
|
ch_larger_than_128: 0.192061 % |
|
ch_larger_than_64: 0.661260 % |
|
song_larger_than_512: 0.0021 % |
|
song_larger_than_256: 0.1926 % |
|
song_larger_than_128: 2.2280 % |
|
song_larger_than_64: 6.1033 % |
|
|
|
larger_than_512: |
|
- Guitar: 7 |
|
- Strings + Ensemble: 3 |
|
larger_than_256: |
|
- Piano: 177 |
|
- Guitar: 680 |
|
- Strings + Ensemble: 79 |
|
- Organ: 2 |
|
- Chromatic Percussion: 11 |
|
- Bass: 1 |
|
- Synth Lead: 2 |
|
- Brass: 1 |
|
- Reed: 5 |
|
larger_than_128: |
|
- Guitar: 4711 |
|
- Strings + Ensemble: 1280 |
|
- Piano: 5548 |
|
- Bass: 211 |
|
- Synth Pad: 22 |
|
- Pipe: 18 |
|
- Chromatic Percussion: 55 |
|
- Synth Lead: 22 |
|
- Organ: 75 |
|
- Reed: 161 |
|
- Brass: 45 |
|
- Drums: 11 |
|
""" |
|
|