|
"""preprocess_geerdes.py""" |
|
import os |
|
import glob |
|
import re |
|
import json |
|
import csv |
|
import logging |
|
import random |
|
from typing import Dict, List, Tuple |
|
from copy import deepcopy |
|
|
|
import numpy as np |
|
from utils.audio import get_audio_file_info, load_audio_file |
|
from utils.midi import midi2note, note_event2midi |
|
from utils.note2event import (note2note_event, sort_notes, validate_notes, trim_overlapping_notes, |
|
extract_program_from_notes, extract_notes_selected_by_programs) |
|
from utils.event2note import event2note_event |
|
from utils.note_event_dataclasses import Note, NoteEvent |
|
from utils.utils import note_event2token2note_event_sanity_check, create_inverse_vocab |
|
from config.vocabulary import MT3_FULL_PLUS |
|
|
|
# Name of the split-info CSV; preprocess_geerdes16k reads it from
# <data_home>/<dataset_name>_yourmt3_16k/.
GEERDES_DATA_CSV_FILENAME = 'geerdes_data_final.csv'

# 0-based MIDI channel conventionally reserved for percussion (General MIDI
# channel 10). Not referenced elsewhere in this file; drum handling is driven by
# the ch_9_as_drum flag passed to midi2note.
DRUM_CHANNEL = 9

# Program number used to mark drum notes (outside the 0-127 MIDI program range);
# the is_drum flags are built by comparing each program against this value.
DRUM_PROGRAM = 128

# Program numbers assigned to lead vocals and to backing/chorus vocals. The
# preprocessing splits notes into a vocal subset (these two programs) and an
# accompaniment subset (all other programs).
SINGING_VOICE_PROGRAM = 100

SINGING_VOICE_CHORUS_PROGRAM = 101

# Track-name -> program overrides passed to midi2note(track_name_to_program=...).
# NOTE(review): presumably keyed by MIDI track names; confirm the matching rule
# (exact vs. case-insensitive) in utils.midi.midi2note.
TRACK_NAME_TO_PROGRAM_MAP = {
    "vocal": SINGING_VOICE_PROGRAM,
    "vocalist": SINGING_VOICE_PROGRAM,
    "2nd Vocals/backings/harmony": SINGING_VOICE_CHORUS_PROGRAM,
    "backvocals": SINGING_VOICE_CHORUS_PROGRAM,
}
|
|
|
|
|
def format_number(n, width=5):
    """
    Render ``n`` as an integer string left-padded with zeros to ``width`` characters.

    ``n`` is first coerced with ``int()``, so floats are truncated and numeric
    strings are accepted. Values already wider than ``width`` are returned
    unchanged; a minus sign counts toward the width.

    Parameters:
    - n (int): The number to be formatted.
    - width (int, optional): Minimum length of the result. Default is 5.

    Returns:
    - str: The zero-padded string.

    Example:
    >>> format_number(123)
    '00123'
    >>> format_number(7, 3)
    '007'
    """
    spec = f"0{width}"
    return format(int(n), spec)
|
|
|
|
|
# Matches every character that is not an ASCII letter or digit.
_NON_ALNUM_RE = re.compile(r'[^a-zA-Z0-9]')


def find_index_with_key(lst, key):
    """
    Return the index of the single element of ``lst`` that contains ``key``.

    Matching is a substring test after removing every character that is not an
    ASCII letter or digit and lower-casing. The key is filtered then
    lower-cased, while each candidate is lower-cased then filtered, exactly as
    before. A key that filters down to the empty string matches every element.

    Parameters:
    - lst (list of str): Candidate strings, e.g. song titles.
    - key (str): Text to search for.

    Returns:
    - int or None: Index of the unique match, or None when nothing matches.

    Raises:
    - ValueError: If more than one element matches.
    """
    filtered_key = _NON_ALNUM_RE.sub('', key).lower()
    indices = [
        index for index, value in enumerate(lst)
        if filtered_key in _NON_ALNUM_RE.sub('', value.lower())
    ]

    if len(indices) > 1:
        # Fires on two or more matches, so the message must say "more than one".
        raise ValueError(f"'{key}' has more than one matching song title.")
    return indices[0] if indices else None
|
|
|
|
|
# NOTE(review): no CSV-generation code follows this string in the file; the
# functions below only read "geerdes_data_final.csv" (GEERDES_DATA_CSV_FILENAME).
# Confirm whether the generator code was intentionally removed.
"""Code below was used to generate the "geerdes_data_final.csv" file for the Geerdes dataset split info."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_note_event_and_note_from_midi(mid_file: str,
                                         id: str,
                                         ch_9_as_drum: bool = True,
                                         track_name_to_program: Dict = None,
                                         ignore_pedal: bool = False) -> Tuple[Dict, Dict]:
    """Create note_events and notes from a MIDI file.

    Args:
        mid_file: Path to the MIDI file.
        id: Track identifier stored under 'geerdes_id' in both outputs. The
            parameter keeps the name ``id`` (shadowing the builtin) because
            callers pass it by keyword.
        ch_9_as_drum: Forwarded to midi2note.
        track_name_to_program: Optional track-name -> program mapping,
            forwarded to midi2note.
        ignore_pedal: Forwarded to midi2note.

    Returns:
        A ``(notes_dict, note_events_dict)`` tuple. Both dicts contain:
        - 'geerdes_id'
        - 'program': de-duplicated, None-free and sorted.
        - 'is_drum': aligned with 'program'; 1 where the program equals DRUM_PROGRAM.
        - 'duration_sec'
        notes_dict also holds 'notes'. note_events_dict also holds
        'note_events', built with note2note_event(notes).
    """
    notes, dur_sec, program = midi2note(
        mid_file,
        ch_9_as_drum=ch_9_as_drum,
        track_name_to_program=track_name_to_program,
        binary_velocity=True,
        ignore_pedal=ignore_pedal,
        return_programs=True)

    # Sorting makes 'program' (and the aligned 'is_drum', which end up in the
    # index JSON files) reproducible instead of depending on set iteration order.
    program = sorted({p for p in program if p is not None})
    is_drum = [1 if p == DRUM_PROGRAM else 0 for p in program]

    notes_dict = {
        'geerdes_id': id,
        'program': program,
        'is_drum': is_drum,
        'duration_sec': dur_sec,
        'notes': notes,
    }
    note_events_dict = {
        'geerdes_id': id,
        'program': program,
        'is_drum': list(is_drum),  # separate list object, as before
        'duration_sec': dur_sec,
        'note_events': note2note_event(notes),
    }
    return notes_dict, note_events_dict
|
|
|
|
|
def _setup_logger(log_file: str) -> logging.Logger:
    """Return the shared 'my_logger' logger, attaching file and console handlers on first use."""
    logger = logging.getLogger('my_logger')
    logger.setLevel(logging.DEBUG)
    # Build the FileHandler only when it will actually be attached: constructing
    # it opens log_file immediately, so doing it unconditionally leaked a file
    # handle on every call after the first.
    if not logger.handlers:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s:%(message)s'))
        logger.addHandler(file_handler)

        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.DEBUG)
        console_handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))
        logger.addHandler(console_handler)
    return logger


def _load_geerdes_tracks(csv_file: str, base_dir: str) -> Dict[str, Dict]:
    """Read the split-info CSV into {id: row}, resolving 'audio_file'/'midi_file' against base_dir."""
    tracks = {}
    with open(csv_file, mode='r', encoding='utf-8') as file:
        for row in csv.DictReader(file):
            tracks[row['id']] = row  # a later duplicate id overwrites an earlier one
    for row in tracks.values():
        row['audio_file'] = os.path.join(base_dir, row['audio_file'])
        row['midi_file'] = os.path.join(base_dir, row['midi_file'])
    return tracks


def _save_npy(path: str, obj) -> None:
    """Pickle a Python object into a .npy file (np.load needs allow_pickle=True to read it)."""
    np.save(path, obj, allow_pickle=True, fix_imports=False)


def _subset_by_programs(notes: Dict, note_events: Dict, programs: List[int]) -> Tuple[Dict, Dict]:
    """Return deep copies of notes/note_events restricted (via extract_notes_selected_by_programs) to `programs`."""
    sub_notes = deepcopy(notes)
    sub_notes['notes'] = extract_notes_selected_by_programs(notes['notes'], programs)
    sub_notes['program'] = list(extract_program_from_notes(sub_notes['notes']))
    sub_notes['is_drum'] = [1 if p == DRUM_PROGRAM else 0 for p in sub_notes['program']]

    sub_note_events = deepcopy(note_events)
    sub_note_events['note_events'] = note2note_event(sub_notes['notes'])
    sub_note_events['program'] = deepcopy(sub_notes['program'])
    sub_note_events['is_drum'] = deepcopy(sub_notes['is_drum'])
    return sub_notes, sub_note_events


def _file_list_entry(geerdes_id: str, track: Dict, suffix: str = '') -> Dict:
    """Build one file-list record; suffix '' selects the full mix, '_voc'/'_acc' a stem."""
    return {
        'geerdes_id': geerdes_id,
        'n_frames': track['n_frames'],
        'mix_audio_file': track['mix_audio_file' + suffix],
        'notes_file': track['notes_file' + suffix],
        'note_events_file': track['note_events_file' + suffix],
        'midi_file': track['midi_file'],
        'program': track['program' + suffix],
        'is_drum': track['is_drum' + suffix],
    }


def preprocess_geerdes16k(data_home: os.PathLike,
                          dataset_name: str = 'geerdes',
                          sanity_check: bool = False) -> None:
    """Preprocess the Geerdes dataset (16 kHz YourMT3 layout).

    Steps:
    1. Read ``<data_home>/<dataset_name>_yourmt3_16k/geerdes_data_final.csv``.
       The columns used are ``id``, ``audio_file``, ``midi_file`` and
       ``split_half``.
    2. Convert each track's MIDI into notes and note_events for three subsets:
       the full mix, the vocal programs only, and all remaining
       (accompaniment) programs.
    3. Save the results as pickled ``.npy`` files under ``note_processed/``,
       together with a reconstructed MIDI per track.
    4. Write JSON file lists named
       ``<dataset_name>_<split><task_suffix>_file_list.json`` to
       ``<data_home>/yourmt3_indexes``. Splits are ``train``, ``validation``
       and ``all``. Task suffix ``''`` lists the full mix; ``'_sep'`` lists
       the vocal and accompaniment stems as separate entries.

    Args:
        data_home: Root data directory.
        dataset_name: Dataset folder name prefix and file-list prefix.
        sanity_check: If True, run note_event2token2note_event_sanity_check on
            each track and log a warning when it reports errors.

    Raises:
        FileNotFoundError: If a mix, ``*_vocals`` or ``*_accompaniment`` audio
            file referenced by the CSV does not exist.
    """
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    logger = _setup_logger(os.path.join(base_dir, 'log.txt'))

    # --- Load track metadata -------------------------------------------------
    ymt3_geerdes_csv_file = os.path.join(base_dir, GEERDES_DATA_CSV_FILENAME)
    tracks_all = _load_geerdes_tracks(ymt3_geerdes_csv_file, base_dir)
    logger.info(f'Loaded {len(tracks_all)} tracks from {ymt3_geerdes_csv_file}.')

    # --- MIDI -> notes / note_events (mix, vocals, accompaniment) -------------
    note_processed_dir = os.path.join(base_dir, 'note_processed')
    os.makedirs(note_processed_dir, exist_ok=True)
    inverse_vocab = create_inverse_vocab(MT3_FULL_PLUS)  # loop-invariant
    vocal_programs = [SINGING_VOICE_PROGRAM, SINGING_VOICE_CHORUS_PROGRAM]

    for geerdes_id, track in tracks_all.items():
        notes, note_events = create_note_event_and_note_from_midi(
            mid_file=track['midi_file'],
            id=geerdes_id,
            ch_9_as_drum=True,
            track_name_to_program=TRACK_NAME_TO_PROGRAM_MAP,
            ignore_pedal=False)

        if sanity_check:
            err_cnt = note_event2token2note_event_sanity_check(note_events['note_events'],
                                                               notes['notes'])
            if len(err_cnt) > 0:
                logger.warning(f'Found {err_cnt} errors in {geerdes_id}.')

        notes_file = os.path.join(note_processed_dir, geerdes_id + '_notes.npy')
        _save_npy(notes_file, notes)
        logger.info(f'Created {notes_file}.')

        note_events_file = os.path.join(note_processed_dir, geerdes_id + '_note_events.npy')
        _save_npy(note_events_file, note_events)
        logger.info(f'Created {note_events_file}.')

        recon_midi_file = os.path.join(note_processed_dir, geerdes_id + '_recon.mid')
        note_event2midi(
            note_events['note_events'], recon_midi_file, output_inverse_vocab=inverse_vocab)
        logger.info(f'Created {recon_midi_file}.')

        track['notes_file'] = notes_file
        track['note_events_file'] = note_events_file
        track['recon_midi_file'] = recon_midi_file
        track['program'] = notes['program']
        track['is_drum'] = notes['is_drum']

        # Vocal stem = vocal programs; accompaniment stem = every other program.
        acc_programs = [p for p in notes['program'] if p not in vocal_programs]
        stems = (
            ('voc',) + _subset_by_programs(notes, note_events, vocal_programs),
            ('acc',) + _subset_by_programs(notes, note_events, acc_programs),
        )
        for stem, stem_notes, stem_note_events in stems:
            stem_notes_file = os.path.join(note_processed_dir, f'{geerdes_id}_notes_{stem}.npy')
            _save_npy(stem_notes_file, stem_notes)
            stem_note_events_file = os.path.join(note_processed_dir,
                                                 f'{geerdes_id}_note_events_{stem}.npy')
            _save_npy(stem_note_events_file, stem_note_events)

            track[f'notes_file_{stem}'] = stem_notes_file
            track[f'note_events_file_{stem}'] = stem_note_events_file
            track[f'program_{stem}'] = stem_notes['program']
            track[f'is_drum_{stem}'] = stem_notes['is_drum']

    # --- Audio files ---------------------------------------------------------
    for track in tracks_all.values():
        # splitext only touches the extension; str.replace('.wav', ...) would also
        # rewrite '.wav' occurring elsewhere in the path and silently no-op for
        # other extensions.
        root, ext = os.path.splitext(track['audio_file'])
        track['mix_audio_file'] = track['audio_file']
        track['mix_audio_file_voc'] = f'{root}_vocals{ext}'
        track['mix_audio_file_acc'] = f'{root}_accompaniment{ext}'
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        for key in ('mix_audio_file', 'mix_audio_file_voc', 'mix_audio_file_acc'):
            if not os.path.exists(track[key]):
                raise FileNotFoundError(f'Missing audio file: {track[key]}')
        track['n_frames'] = get_audio_file_info(track['mix_audio_file'])[1]
    logger.info('Checked audio files. All audio files exist.')

    # --- File lists ----------------------------------------------------------
    splits = ['train', 'validation', 'all']
    task_suffixes = ['', '_sep']
    for task_suffix in task_suffixes:
        for split in splits:
            entries = []
            for geerdes_id, track in tracks_all.items():
                if split != 'all' and track['split_half'] != split:
                    continue
                if task_suffix == '':
                    entries.append(_file_list_entry(geerdes_id, track))
                else:  # '_sep': vocal and accompaniment stems as separate entries
                    entries.append(_file_list_entry(geerdes_id, track, '_voc'))
                    entries.append(_file_list_entry(geerdes_id, track, '_acc'))
            file_list = dict(enumerate(entries))

            file_list_file = os.path.join(output_index_dir,
                                          f'{dataset_name}_{split}{task_suffix}_file_list.json')
            with open(file_list_file, 'w', encoding='utf-8') as f:
                json.dump(file_list, f, indent=4)
            logger.info(f'Created {file_list_file}.')