|
""" preprocess_mtrack_slakh.py |
|
|
|
""" |
|
import os |
|
import time |
|
import json |
|
from typing import Dict, List, Tuple |
|
import numpy as np |
|
from utils.audio import get_audio_file_info, load_audio_file |
|
from utils.midi import midi2note |
|
from utils.note2event import note2note_event, mix_notes |
|
import mirdata |
|
from utils.mirdata_dev.datasets import slakh16k |
|
|
|
|
|
def create_audio_stem_from_mtrack(ds: mirdata.core.Dataset,
                                  mtrack_id: str,
                                  delete_source_files: bool = False) -> Dict:
    """Extracts audio stems and metadata from a multitrack.

    Args:
        ds (mirdata.core.Dataset): Slakh dataset.
        mtrack_id (str): multitrack id.
        delete_source_files (bool): if True, deletes each stem's source
            audio file after it has been read. Default is False.

    Returns:
        Dict with keys:
            'mtrack_id' (str), 'program' (int64 array), 'is_drum'
            (int64 array of 0/1), 'n_frames' (int, longest stem length),
            'audio_array' (float16 array, shape (n_tracks, n_frames),
            shorter stems zero-padded at the end).
    """
    mtrack = ds.multitrack(mtrack_id)
    track_ids = mtrack.track_ids
    max_length = 0
    program_numbers = []
    is_drum = []
    audio_tracks = []

    for track_id in track_ids:
        track = ds.track(track_id)
        audio_file = track.audio_path
        program_numbers.append(track.program_number)
        # Append the conditional value; the original used a conditional
        # expression purely for its side effects.
        is_drum.append(1 if track.is_drum else 0)

        fs, n_frames, n_channels = get_audio_file_info(audio_file)
        # Preprocessed Slakh stems are expected to be 16 kHz mono.
        assert fs == 16000 and n_channels == 1
        max_length = max(max_length, n_frames)
        # int16 PCM -> float16 in [-1, 1) to halve memory versus float32.
        audio = load_audio_file(audio_file, dtype=np.int16)
        audio = (audio / 2**15).astype(np.float16)
        audio_tracks.append(audio)
        if delete_source_files:
            print(f'๐๏ธ Deleting {audio_file} ...')
            os.remove(audio_file)

    # Zero-pad all stems to the length of the longest one.
    n_tracks = len(track_ids)
    audio_array = np.zeros((n_tracks, max_length), dtype=np.float16)
    for j, audio in enumerate(audio_tracks):
        audio_array[j, :len(audio)] = audio

    return {
        'mtrack_id': mtrack_id,
        'program': np.array(program_numbers, dtype=np.int64),
        'is_drum': np.array(is_drum, dtype=np.int64),
        'n_frames': max_length,
        'audio_array': audio_array
    }
|
|
|
|
|
def create_note_event_and_note_from_mtrack_mirdata(
        ds: mirdata.core.Dataset,
        mtrack_id: str,
        fix_bass_octave: bool = True) -> Tuple[Dict, Dict]:
    """Extracts note or note_event and metadata from a multitrack:

    Args:
        ds (mirdata.core.Dataset): Slakh dataset.
        mtrack_id (str): multitrack id.
        fix_bass_octave (bool): if True, shifts bass-program stems down one
            octave (Slakh bass is annotated +1 octave), except for the
            slap-bass plugin, which is left untouched. Default is True.

    Returns:
        notes (dict): mixed notes and metadata.
        note_events (dict): note events and metadata.
    """
    mtrack = ds.multitrack(mtrack_id)
    track_ids = mtrack.track_ids
    program_numbers = []
    is_drum = []
    mixed_notes = []
    duration_sec = 0.

    for track_id in track_ids:
        track = ds.track(track_id)
        stem_midi_file = track.midi_path
        notes, dur_sec = midi2note(
            stem_midi_file,
            binary_velocity=True,
            ch_9_as_drum=False,
            force_all_drum=bool(track.is_drum),
            force_all_program_to=None,
            trim_overlap=True,
            fix_offset=True,
            quantize=True,
            verbose=0,
            minimum_offset_sec=0.01,
            drum_offset_sec=0.01)

        # GM programs 32-39 are bass instruments. Plain integer range check
        # instead of membership in an np.arange array; truthiness instead of
        # `== True`; inverted condition removes the empty `pass` branch.
        if fix_bass_octave and 32 <= track.program_number < 40:
            if track.plugin_name != 'scarbee_jay_bass_slap_both.nkm':
                for note in notes:
                    note.pitch -= 12
                print("Fixed bass octave for track", track_id)

        mixed_notes = mix_notes((mixed_notes, notes), True, True, True)
        program_numbers.append(track.program_number)
        is_drum.append(1 if track.is_drum else 0)
        duration_sec = max(duration_sec, dur_sec)

    mixed_note_events = note2note_event(mixed_notes, sort=True, return_activity=True)
    return {
        'mtrack_id': mtrack_id,
        'program': np.array(program_numbers, dtype=np.int64),
        'is_drum': np.array(is_drum, dtype=np.int64),
        'duration_sec': duration_sec,
        'notes': mixed_notes
    }, {
        'mtrack_id': mtrack_id,
        'program': np.array(program_numbers, dtype=np.int64),
        'is_drum': np.array(is_drum, dtype=np.int64),
        'duration_sec': duration_sec,
        'note_events': mixed_note_events
    }
|
|
|
|
|
def preprocess_slakh16k(data_home: str,
                        run_checksum: bool = False,
                        delete_source_files: bool = False,
                        fix_bass_octave: bool = True) -> None:
    """
    Processes the Slakh dataset and extracts stems for each multitrack.

    Args:
        data_home (str): path to the Slakh data.
        run_checksum (bool): if True, validates the dataset using its checksum. Default is False.
        delete_source_files (bool): if True, deletes original audio files. Default is False.
        fix_bass_octave (bool): if True, fixes the bass to be -1 octave. Slakh bass is annotated as +1 octave. Default is True.

    Writes:
        - slakh_{split}_file_list.json for each split, a dict of entries:
        {
            'mtrack_id': mtrack_id,
            'n_frames': n of audio frames,
            'stem_file': path to the saved stem .npy file,
            'mix_audio_file': mtrack.mix_path,
            'notes_file': path to the saved notes .npy file,
            'note_events_file': path to the saved note events .npy file,
            'midi_file': mtrack.midi_path
        }
        (notes and note_events files are written for every split.)
    """
    start_time = time.time()

    ds = slakh16k.Dataset(data_home=data_home, version='2100-yourmt3-16k')
    if run_checksum:
        print('Checksum for slakh dataset...')
        ds.validate()
    print('Preprocessing slakh dataset...')

    mtrack_split_dict = ds.get_mtrack_splits()
    for split in ['train', 'validation', 'test']:
        file_list = {}
        mtrack_ids = mtrack_split_dict[split]

        for i, mtrack_id in enumerate(mtrack_ids):
            print(f'๐๐ปโโ๏ธ: processing {mtrack_id} ({i+1}/{len(mtrack_ids)} in {split})')
            mtrack = ds.multitrack(mtrack_id)
            # Outputs live next to the multitrack's mix file.
            output_dir = os.path.dirname(mtrack.mix_path)

            # Audio: get stems (as one padded array) and metadata.
            stem_content = create_audio_stem_from_mtrack(ds, mtrack_id, delete_source_files)
            stem_file = os.path.join(output_dir, mtrack_id + '_stem.npy')
            np.save(stem_file, stem_content)
            print(f'๐ฟ Created {stem_file}')

            # MIDI: pre-process and get notes / note events from the multitrack.
            notes, note_events = create_note_event_and_note_from_mtrack_mirdata(
                ds, mtrack_id, fix_bass_octave=fix_bass_octave)

            notes_file = os.path.join(output_dir, mtrack_id + '_notes.npy')
            np.save(notes_file, notes, allow_pickle=True, fix_imports=False)
            print(f'๐น Created {notes_file}')

            note_events_file = os.path.join(output_dir, mtrack_id + '_note_events.npy')
            np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False)
            print(f'๐น Created {note_events_file}')

            file_list[i] = {
                'mtrack_id': mtrack_id,
                'n_frames': stem_content['n_frames'],
                'stem_file': stem_file,
                'mix_audio_file': mtrack.mix_path,
                'notes_file': notes_file,
                'note_events_file': note_events_file,
                'midi_file': mtrack.midi_path
            }

        # Write the per-split index of everything produced above.
        summary_dir = os.path.join(data_home, 'yourmt3_indexes')
        os.makedirs(summary_dir, exist_ok=True)
        summary_file = os.path.join(summary_dir, f'slakh_{split}_file_list.json')
        with open(summary_file, 'w') as f:
            json.dump(file_list, f, indent=4)
        print(f'๐พ Created {summary_file}')

    elapsed_time = time.time() - start_time
    print(
        f"โฐ: {int(elapsed_time // 3600):02d}h {int(elapsed_time % 3600 // 60):02d}m {elapsed_time % 60:.2f}s"
    )
|
|
|
|
|
def add_program_and_is_drum_info_to_file_list(data_home: str):
    """Augment each split's file-list JSON with instrument metadata.

    For every entry in slakh_{split}_file_list.json, loads the referenced
    stem .npy file and copies its 'program' and 'is_drum' arrays into the
    entry (as plain lists), then rewrites the JSON in place.
    """
    index_dir = os.path.join(data_home, 'yourmt3_indexes')
    for split in ['train', 'validation', 'test']:
        index_path = os.path.join(index_dir, f'slakh_{split}_file_list.json')
        with open(index_path, 'r') as fp:
            entries = json.load(fp)

        for entry in entries.values():
            stem_content = np.load(entry['stem_file'], allow_pickle=True).item()
            entry['program'] = stem_content['program'].tolist()
            entry['is_drum'] = stem_content['is_drum'].tolist()

        with open(index_path, 'w') as fp:
            json.dump(entries, fp, indent=4)
        print(f'๐พ Added program and drum info to {index_path}')
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: read the dataset root from the shared project
    # config and run the full preprocessing pass, keeping source audio.
    from config.config import shared_cfg
    data_home = shared_cfg['PATH']['data_home']
    preprocess_slakh16k(data_home=data_home, delete_source_files=False)