|
""" preprocess_enstdrums.py """ |
|
import os |
|
import re |
|
import glob |
|
import copy |
|
import json |
|
import numpy as np |
|
from typing import Dict |
|
from utils.note_event_dataclasses import Note |
|
from utils.audio import get_audio_file_info, load_audio_file, write_wav_file |
|
from utils.midi import note_event2midi |
|
from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes, mix_notes |
|
from config.vocabulary import ENST_DRUM_NOTES |
|
|
|
# Fixed length added to every annotated drum onset to obtain the note offset;
# the annotation files only supply an onset time and a drum name per hit.
# NOTE(review): presumably expressed in seconds, like the onsets — confirm.
DRUM_OFFSET = 0.01
|
|
|
|
|
def create_enst_audio_stem(drum_audio_file, accomp_audio_file, enst_id) -> Dict:
    """Stack a drum recording and its accompaniment into one stem dictionary.

    Both files are loaded as int16 samples, scaled by 1 / 2**15 and stored as
    float16. The shorter track is zero-padded at its end so both rows of the
    stem array share the length of the longer one.

    Args:
        drum_audio_file: path to the drum audio file.
        accomp_audio_file: path to the accompaniment audio file.
        enst_id: identifier stored under 'enstdrums_id'.

    Returns:
        Dict with the keys 'enstdrums_id', 'program' ([128, 129]), 'is_drum'
        ([1, 0]), 'n_frames' (length of the longer track) and 'audio_array'
        (float16 array with one row per track: row 0 drums, row 1
        accompaniment).
    """
    # Row order matters: index 0 is the drum track, index 1 the accompaniment.
    tracks = []
    for path in (drum_audio_file, accomp_audio_file):
        samples = load_audio_file(path, dtype=np.int16) / 2**15
        tracks.append(samples.astype(np.float16))

    n_frames = max(len(track) for track in tracks)

    # Zero-initialised so that the shorter track ends up padded with silence.
    stacked = np.zeros((len(tracks), n_frames), dtype=np.float16)
    for row, track in enumerate(tracks):
        stacked[row, :len(track)] = track

    return {
        'enstdrums_id': enst_id,
        'program': [128, 129],
        'is_drum': [1, 0],
        'n_frames': n_frames,
        'audio_array': stacked,
    }
|
|
|
|
|
def create_note_note_event_midi_from_enst_annotation(ann_file, enst_id):
    """Convert an ENST-drums annotation file into notes, note events and MIDI.

    Every non-blank line of the annotation file is read as
    '<onset time> <drum name> ...'. Non-alphabetic characters are stripped
    from the drum name before it is looked up in ENST_DRUM_NOTES. Each hit
    becomes a drum Note (program 128, velocity 1) whose offset is the onset
    plus DRUM_OFFSET. The notes are then sorted, validated and trimmed for
    overlaps before being converted to note events.

    Side effect: a MIDI file is written next to the annotation file, using
    the annotation path with '.txt' replaced by '.mid'.

    Args:
        ann_file: 'path/to/annotation.txt'
        enst_id: str

    Returns:
        A tuple (notes_dict, note_events_dict). Both dictionaries share the
        keys 'enstdrums_id', 'program' ([128]), 'is_drum' ([1]) and
        'duration_sec' (time of the last note event). The first one also
        holds 'notes' (the processed Note list), the second one
        'note_events' (the result of note2note_event).

    Raises:
        ValueError: if a drum name is not a key of ENST_DRUM_NOTES.
    """
    with open(ann_file, 'r') as f:
        # Blank lines (e.g. a stray newline at the end of the file) carry no
        # annotation and would otherwise fail at the field lookup below.
        rows = [line.split() for line in f if line.strip()]

    notes = []
    for fields in rows:
        onset = float(fields[0])
        drum_name = re.sub('[^a-zA-Z]', '', fields[1])
        if drum_name not in ENST_DRUM_NOTES:
            raise ValueError(f"Drum name {drum_name} is not in ENST_DRUM_NOTES")
        notes.append(
            Note(
                is_drum=True,
                program=128,
                onset=onset,
                offset=onset + DRUM_OFFSET,
                pitch=ENST_DRUM_NOTES[drum_name][0],
                velocity=1))

    notes = sort_notes(notes)
    notes = validate_notes(notes)
    notes = trim_overlapping_notes(notes)
    note_events = note2note_event(notes)

    # The MIDI file lives next to the annotation file.
    midi_file = ann_file.replace('.txt', '.mid')
    note_event2midi(note_events, midi_file)
    print(f"Created {midi_file}")

    program = [128]
    is_drum = [1]
    duration_sec = note_events[-1].time

    return {
        'enstdrums_id': enst_id,
        'program': program,
        'is_drum': is_drum,
        'duration_sec': duration_sec,
        'notes': notes,
    }, {
        'enstdrums_id': enst_id,
        'program': program,
        'is_drum': is_drum,
        'duration_sec': duration_sec,
        'note_events': note_events,
    }
|
|
|
|
|
def _write_accompaniment_mix(audio_dir, mix_name, wav_name, stem_audio, drum_gain, accomp_gain):
    """Blend drums and accompaniment, write the mix as a 16 kHz wav, return (path, audio)."""
    mix_dir = os.path.join(audio_dir, mix_name)
    os.makedirs(mix_dir, exist_ok=True)
    mix_file = os.path.join(mix_dir, wav_name)

    drums, accomp = stem_audio[0], stem_audio[1]
    # Each track is peak-normalised before weighting so the gains express the
    # actual drum/accompaniment ratio; the sum is peak-normalised once more.
    mix = drums / np.max(np.abs(drums)) * drum_gain + accomp / np.max(
        np.abs(accomp)) * accomp_gain
    mix = mix / np.max(np.abs(mix))

    write_wav_file(mix_file, mix, 16000)
    print(f"Created {mix_file}")
    return mix_file, mix


def _with_mix_dir(entry, mix_name):
    """Copy an index entry, drop its stem file and point its audio at `mix_name`."""
    variant = copy.deepcopy(entry)
    variant['stem_file'] = None
    variant['mix_audio_file'] = variant['mix_audio_file'].replace('accompaniment_mix_r2',
                                                                  mix_name)
    return variant


def _as_drum_only_entry(entry):
    """Copy an index entry so that it describes the drum-only ('wet_mix') audio."""
    variant = _with_mix_dir(entry, 'wet_mix')
    variant['program'] = [128]
    variant['is_drum'] = [1]
    return variant


def preprocess_enstdrums16k(data_home: os.PathLike, dataset_name='enstdrums') -> None:
    """
    Some tracks ('minus-one' in the file name) of ENST-drums contain accompaniments.
    'stem file' will contain these accompaniments. 'mix_audio_file' will contain the
    mix of the drums and accompaniments.

    For every annotation file found under
    {data_home}/{dataset_name}_yourmt3_16k/drummer_{1,2,3}/annotation/*.txt this
    function writes the notes / note events (.npy) and MIDI files next to the
    annotation. For 'minus-one' tracks it additionally writes a stem .npy file
    and the two mixes 'accompaniment_mix_r1' and 'accompaniment_mix_r2'.

    Splits:
    - drummer_1_dtm, drummer_2_dtm, all_dtm: every track of the drummer(s), using
      the r2 mix where an accompaniment exists.
    - drummer_1_dtp, drummer_2_dtp, all_dtp: the same tracks, drum audio only.
    - drummer_3_dtd: drummer 3 tracks without accompaniment.
    - drummer_3_dtp, drummer_3_dtm_r1, drummer_3_dtm_r2: drummer 3 tracks with
      accompaniment (for validation/test).

    DTD means drum track only, DTP means drum track plus percussions, DTM means
    drum track plus music.

    DTM r1 and r2 are two different versions of the mixing tracks derived from listening
    test in [Gillet 2008, Paulus 2009], and used in [Wu 2018].
    r1 uses 1:3 ratio of accompaniment to drums, and r2 uses 2:3 ratio.

    O. Gillet and G. Richard, "Transcription and separation of drum signals from polyphonic music,"
    IEEE Trans. Audio, Speech, Lang. Process., vol. 16, no. 3, pp. 529-540, Mar. 2008.
    J. Paulus and A. Klapuri, "Drum sound detection in polyphonic music with hidden Markov models,"
    EURASIP J. Audio, Speech, Music Process., vol. 2009, no. 14, 2009, Art. no. 14.
    C. -W. Wu et al., "A Review of Automatic Drum Transcription,"
    IEEE/ACM TASLP, vol. 26, no. 9, pp. 1457-1483, Sept. 2018,

    Writes:
    - {dataset_name}_{split}_file_list.json: a dictionary with the following keys:
        {
            index:
            {
                'enstdrums_id': {drummer_id}_{3-digit-track-id}
                'n_frames': (int),
                'stem_file': 'path/to/stem.npy' or None,
                'mix_audio_file': 'path/to/mix.wav',
                'notes_file': 'path/to/notes.npy',
                'note_events_file': 'path/to/note_events.npy',
                'midi_file': 'path/to/midi.mid',
                'program': List[int], # 128 for drums, 129 for unannotated (accompaniment)
                'is_drum': List[int], # 0 or 1
            }
        }

    Raises:
        AssertionError: if an expected input/output file is missing or a split
            does not contain the expected number of tracks.
    """
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Pass 1: per-track processing (notes, MIDI, stems, mixes).
    # ------------------------------------------------------------------
    enst_ids = []
    enst_info = {}
    for drummer in [1, 2, 3]:
        drummer_dir = os.path.join(base_dir, f'drummer_{drummer}')
        audio_dir = os.path.join(drummer_dir, 'audio')
        ann_files = sorted(glob.glob(os.path.join(drummer_dir, 'annotation/*.txt')))

        for ann_file in ann_files:
            file_name = os.path.basename(ann_file)
            wav_name = file_name.replace('.txt', '.wav')
            enst_id = f"{drummer}_{file_name.split('_')[0]}"
            enst_ids.append(enst_id)
            # Decide on the file name only, so that a parent directory whose
            # name happens to contain 'minus-one' cannot misclassify a track.
            has_accompaniment = 'minus-one' in file_name

            # Notes, note events and MIDI derived from the annotation.
            assert os.path.exists(ann_file), f'{ann_file} does not exist'
            notes, note_events = create_note_note_event_midi_from_enst_annotation(
                ann_file, enst_id)
            notes_file = ann_file.replace('.txt', '_notes.npy')
            note_events_file = ann_file.replace('.txt', '_note_events.npy')
            np.save(notes_file, notes, allow_pickle=True, fix_imports=False)
            print(f"Created {notes_file}")
            np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False)
            print(f"Created {note_events_file}")

            drum_audio_file = os.path.join(audio_dir, 'wet_mix', wav_name)
            assert os.path.exists(drum_audio_file), f'{drum_audio_file} does not exist'

            if has_accompaniment:
                # Unannotated accompaniment exists: build the stem and both mixes.
                accomp_audio_file = os.path.join(audio_dir, 'accompaniment', wav_name)
                assert os.path.exists(accomp_audio_file), f'{accomp_audio_file} does not exist'

                stem_dir = os.path.join(audio_dir, 'stem')
                os.makedirs(stem_dir, exist_ok=True)
                stem_file = os.path.join(stem_dir, file_name.replace('.txt', '_stem.npy'))
                stem_content = create_enst_audio_stem(drum_audio_file, accomp_audio_file,
                                                      enst_id)
                np.save(stem_file, stem_content, allow_pickle=True, fix_imports=False)
                print(f"Created {stem_file}")

                stem_audio = stem_content['audio_array']
                # r1: accompaniment to drums 1:3; r2: accompaniment to drums 2:3.
                _write_accompaniment_mix(audio_dir, 'accompaniment_mix_r1', wav_name,
                                         stem_audio, 0.75, 0.25)
                mix_file_r2, mix_audio_r2 = _write_accompaniment_mix(
                    audio_dir, 'accompaniment_mix_r2', wav_name, stem_audio, 0.6, 0.4)

                # The index is based on the DTM setup and references the r2 mix.
                mix_audio_file = mix_file_r2
                n_frames = len(mix_audio_r2)
                program = stem_content['program']
                is_drum = stem_content['is_drum']
            else:
                # No unannotated accompaniment: the drum recording is the mix.
                stem_file = None
                mix_audio_file = drum_audio_file
                n_frames = get_audio_file_info(drum_audio_file)[1]
                program = notes['program']
                is_drum = notes['is_drum']

            enst_info[enst_id] = {
                'enstdrums_id': enst_id,
                'n_frames': n_frames,
                'stem_file': stem_file,
                'mix_audio_file': mix_audio_file,
                'notes_file': notes_file,
                'note_events_file': note_events_file,
                'midi_file': ann_file.replace('.txt', '.mid'),
                'program': program,
                'is_drum': is_drum,
            }

    # ------------------------------------------------------------------
    # Pass 2: build and write one index file per split.
    # ------------------------------------------------------------------
    # (split, enst_id prefix, required number of programs or None,
    #  entry transform or None, expected number of tracks)
    split_specs = [
        ('drummer_1_dtm', '1_', None, None, 97),
        ('drummer_2_dtm', '2_', None, None, 105),
        ('all_dtm', '', None, None, 318),
        ('drummer_1_dtp', '1_', None, _as_drum_only_entry, 97),
        ('drummer_2_dtp', '2_', None, _as_drum_only_entry, 105),
        ('all_dtp', '', None, _as_drum_only_entry, 318),
        ('drummer_3_dtd', '3_', 1, None, 95),
        ('drummer_3_dtp', '3_', 2, _as_drum_only_entry, 21),
        ('drummer_3_dtm_r1', '3_', 2,
         lambda entry: _with_mix_dir(entry, 'accompaniment_mix_r1'), 21),
        ('drummer_3_dtm_r2', '3_', 2,
         lambda entry: _with_mix_dir(entry, 'accompaniment_mix_r2'), 21),
    ]

    for split, prefix, n_programs, transform, expected_size in split_specs:
        file_list = {}
        for enst_id in enst_ids:
            entry = enst_info[enst_id]
            if not enst_id.startswith(prefix):
                continue
            if n_programs is not None and len(entry['program']) != n_programs:
                continue
            if split == 'drummer_3_dtd':
                # Drum-only tracks must not carry a stem file.
                assert entry['stem_file'] is None
            file_list[str(len(file_list))] = entry if transform is None else transform(entry)
        assert len(file_list) == expected_size

        # Final check for file existence.
        for entry in file_list.values():
            if entry['stem_file'] is not None:
                assert os.path.exists(entry['stem_file'])
            assert os.path.exists(entry['mix_audio_file'])
            assert os.path.exists(entry['notes_file'])
            assert os.path.exists(entry['note_events_file'])
            assert os.path.exists(entry['midi_file'])

        output_index_file = os.path.join(output_index_dir,
                                         f'{dataset_name}_{split}_file_list.json')
        with open(output_index_file, 'w') as f:
            json.dump(file_list, f, indent=4)
        print(f"Created {output_index_file}")
|
|
|
|
|
def create_filelist_dtm_random_enstdrums16k(data_home: os.PathLike,
                                            dataset_name: str = 'enstdrums') -> None:
    """Derive fixed train/validation/test index files from the 'all_dtm' index.

    Reads {data_home}/yourmt3_indexes/{dataset_name}_all_dtm_file_list.json and
    separates its entries into tracks with accompaniment (129 in 'program')
    and drum-only tracks. The tracks with accompaniment are numbered in file
    order and assigned to train/validation/test by the hard-coded position
    lists below.

    Writes, into the same directory, {dataset_name}_{split}_file_list.json for:
    - train_dtm, validation_dtm, test_dtm: entries as stored in the source
      index (r2 mix). validation_dtm and test_dtm additionally contain, after
      all r2 entries, a copy of each entry pointing at the r1 mix.
    - train_dtp, validation_dtp, test_dtp: the same tracks with
      'stem_file' None, 'mix_audio_file' redirected to 'wet_mix',
      'program' [128] and 'is_drum' [1].
    - all_dtd: every drum-only entry of the source index.

    Args:
        data_home: directory that contains the 'yourmt3_indexes' folder.
        dataset_name: prefix of the index file names.
    """
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    source_path = os.path.join(output_index_dir, f'{dataset_name}_all_dtm_file_list.json')
    with open(source_path, 'r') as f:
        all_entries = list(json.load(f).values())

    # Position within each list is the track number the split tables refer to.
    dtm_entries = [v for v in all_entries if 129 in v['program']]
    dtd_entries = [v for v in all_entries if 129 not in v['program']]

    # Fixed assignment of the tracks with accompaniment to the three subsets.
    idx = {}
    idx['train_dtm'] = frozenset([
        47, 58, 14, 48, 60, 44, 34, 31, 5, 62, 46, 12, 9, 26, 57, 11, 16, 22, 33, 3, 6, 55, 50, 32,
        52, 53, 10, 28, 24, 41, 63, 51, 43, 49, 54, 15, 20, 1, 27, 2, 23, 45, 38, 37
    ])
    idx['validation_dtm'] = frozenset([39, 4, 19, 59, 61, 17, 56, 36, 29, 0])
    idx['test_dtm'] = frozenset([18, 7, 42, 25, 40, 8, 30, 21, 13, 35])
    idx['train_dtp'] = idx['train_dtm']
    idx['validation_dtp'] = idx['validation_dtm']
    idx['test_dtp'] = idx['test_dtm']

    for split in [
            'train_dtm',
            'validation_dtm',
            'test_dtm',
            'train_dtp',
            'validation_dtp',
            'test_dtp',
            'all_dtd',
    ]:
        entries = []
        if 'dtm' in split:
            selected = [v for pos, v in enumerate(dtm_entries) if pos in idx[split]]
            entries.extend(copy.deepcopy(v) for v in selected)
            if split == 'test_dtm' or split == 'validation_dtm':
                # Evaluation splits also get the r1 version of every track.
                for v in selected:
                    r1_entry = copy.deepcopy(v)
                    r1_entry['mix_audio_file'] = r1_entry['mix_audio_file'].replace(
                        'accompaniment_mix_r2', 'accompaniment_mix_r1')
                    entries.append(r1_entry)
        elif 'dtp' in split:
            for pos, v in enumerate(dtm_entries):
                if pos not in idx[split]:
                    continue
                drum_only = copy.deepcopy(v)
                drum_only['stem_file'] = None
                drum_only['mix_audio_file'] = drum_only['mix_audio_file'].replace(
                    'accompaniment_mix_r2', 'wet_mix')
                drum_only['program'] = [128]
                drum_only['is_drum'] = [1]
                entries.append(drum_only)
        elif 'dtd' in split:
            entries = [copy.deepcopy(v) for v in dtd_entries]
        else:
            raise ValueError(f'Unknown split: {split}')

        # Integer keys are written by json.dump as the strings "0", "1", ...
        file_list = dict(enumerate(entries))

        output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json')
        with open(output_file, 'w') as f:
            json.dump(file_list, f, indent=4)
        print(f'Created {output_file}')
|
|