# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
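"""Demo script for stem cross-augmentation (x-aug).

Builds the training dataset from a multi data preset, attaches a stem
cross-augmentation policy, and writes one augmented example to
'xaug_demo_audio.wav' for listening.
"""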
from typing import Dict, Tuple
from copy import deepcopy
import soundfile as sf
import torch
from utils.data_modules import AMTDataModule
from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg
from utils.augment import intra_stem_augment_processor


def get_ds(data_preset_multi: Dict, train_num_samples_per_epoch: int = 90000):
    """Build the training dataset defined by a multi data preset."""
    dm = AMTDataModule(data_preset_multi=data_preset_multi, train_num_samples_per_epoch=train_num_samples_per_epoch)
    dm.setup('fit')
    dl = dm.train_dataloader()
    ds = dl.flattened[0].dataset  # underlying dataset of the first flattened train dataloader
    return ds


def debug_func(num_segments: int = 10):
    """Interactive helper: sample cached segments and run intra-stem augmentation."""
    sampled_data, sampled_ids = ds._get_rand_segments_from_cache(num_segments)
    # NOTE: `ux_count_sum` is not defined in this script; set it before calling.
    ux_sampled_data, _ = ds._get_rand_segments_from_cache(ux_count_sum, False, sampled_ids)
    s = deepcopy(sampled_data)  # keep a pre-augmentation copy for comparison
    intra_stem_augment_processor(sampled_data, submix_audio=False)


def gen_audio(index: int = 0):
    # audio_arr: (b, 1, nframe), note_token_arr: (b, l), task_token_arr: (b, task_l)
    audio_arr, note_token_arr, task_token_arr = ds[index]

    # concatenate the b segments into a single 1-D waveform:
    # (b, 1, nframe) -> (b, nframe, 1) -> (b * nframe,)
    audio = audio_arr.permute(0, 2, 1).reshape(-1).numpy()

    # save the result as a 16 kHz, 16-bit PCM wav
    sf.write('xaug_demo_audio.wav', audio, 16000, subtype='PCM_16')


# Build the dataset from the multi data preset and attach a stem x-aug policy.
data_preset_multi = data_preset_multi_cfg["all_cross_rebal5"]
ds = get_ds(data_preset_multi)
ds.random_amp_range = [0.8, 1.1]  # random amplitude scaling range for stems
ds.stem_xaug_policy = {
    "max_k": 5,
    "tau": 0.3,
    "alpha": 1.0,
    "max_subunit_stems": 12,
    "no_instr_overlap": True,
    "no_drum_overlap": True,
    "uhat_intra_stem_augment": True,
}
gen_audio(3)  # write an augmented example to 'xaug_demo_audio.wav'

# Optional cache-inspection snippet (requires `import numpy as np`):
# for k in ds.cache.keys():
#     arr = ds.cache[k]['audio_array']
#     arr = np.sum(arr, axis=1).reshape(-1)
#     # sf.write(f'xxx/{k}.wav', arr, 16000, subtype='PCM_16')
#     if np.min(arr) > -0.5:
#         print(k)

# Dump individual stems of cache entry 52 (expects an existing 'xxx52/' directory):
# arr = ds.cache[52]['audio_array']
# for i in range(arr.shape[1]):
#     a = arr[:, i, :].reshape(-1)
#     sf.write(f'xxx52/52_{i}.wav', a, 16000, subtype='PCM_16')