|
"""preprocess_rwc_pop.py""" |
|
import os |
|
import json |
|
import csv |
|
from typing import Dict, List, Tuple |
|
import numpy as np |
|
from utils.audio import get_audio_file_info, load_audio_file |
|
from utils.midi import midi2note, note_event2midi |
|
from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes, extract_program_from_notes |
|
from utils.event2note import event2note_event |
|
from utils.note_event_dataclasses import Note, NoteEvent |
|
from utils.utils import note_event2token2note_event_sanity_check |
|
from mido import Message, MidiFile |
|
|
|
ID_NO_BASS = ['071', '072', '073', '074', '075', '076', '077', '078', '079', '080'] |
|
|
|
|
|
def check_file_existence(file: str) -> bool: |
|
"""Checks if file exists.""" |
|
res = True |
|
if not os.path.exists(file): |
|
res = False |
|
elif get_audio_file_info(file)[1] < 10 * 16000: |
|
print(f'File {file} is too short.') |
|
res = False |
|
return res |
|
|
|
|
|
def create_note_event_and_note_from_midi(mid_file: str, |
|
id: str, |
|
ignore_pedal: bool = True) -> Tuple[Dict, Dict]: |
|
"""Extracts note or note_event and metadata from midi: |
|
|
|
Returns: |
|
notes (dict): note events and metadata. |
|
note_events (dict): note events and metadata. |
|
""" |
|
notes, dur_sec, programs = midi2note( |
|
mid_file, |
|
binary_velocity=True, |
|
ch_9_as_drum=True, |
|
trim_overlap=True, |
|
fix_offset=True, |
|
quantize=True, |
|
verbose=0, |
|
minimum_offset_sec=0.01, |
|
drum_offset_sec=0.01, |
|
ignore_pedal=ignore_pedal, |
|
return_programs=True) |
|
|
|
|
|
has_drum = False |
|
for note in notes: |
|
if note.is_drum: |
|
has_drum = True |
|
is_drum = [0] * len(programs) |
|
if has_drum: |
|
is_drum[9] = 1 |
|
|
|
return { |
|
'rwc_pop_id': id, |
|
'program': programs, |
|
'is_drum': is_drum, |
|
'duration_sec': dur_sec, |
|
'notes': notes, |
|
}, { |
|
'rwc_pop_id': id, |
|
'program': programs, |
|
'is_drum': is_drum, |
|
'duration_sec': dur_sec, |
|
'note_events': note2note_event(notes), |
|
} |
|
|
|
|
|
def preprocess_rwc_pop16k(data_home=os.PathLike, dataset_name='rwc_pop') -> None: |
|
|
|
base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') |
|
output_index_dir = os.path.join(data_home, 'yourmt3_indexes') |
|
os.makedirs(output_index_dir, exist_ok=True) |
|
|
|
|
|
csv_file = os.path.join(base_dir, 'wav_to_midi_filename_mapping.csv') |
|
rwc_bass = {} |
|
with open(csv_file, 'r') as f: |
|
reader = csv.reader(f) |
|
headers = next(reader) |
|
|
|
for row in reader: |
|
id = row[2] |
|
|
|
|
|
|
|
|
|
|
|
|
|
mix_audio_file = os.path.join(base_dir, headers[0] + row[0], |
|
row[1] + ' ' + headers[1] + '.wav') |
|
assert check_file_existence(mix_audio_file) |
|
|
|
mid_file = os.path.join(base_dir, 'MIDI-Bass-Octave-fixed-v2', id + '_bass.mid') |
|
|
|
if not os.path.exists(mid_file): |
|
print(mid_file, "does not exist") |
|
continue |
|
|
|
notes_file = mid_file.replace('.mid', '_notes.npy') |
|
note_events_file = mid_file.replace('.mid', '_note_events.npy') |
|
|
|
rwc_bass[id] = { |
|
'rwc_pop_id': id, |
|
'n_frames': get_audio_file_info(mix_audio_file)[1], |
|
'mix_audio_file': mix_audio_file, |
|
'notes_file': notes_file, |
|
'note_events_file': note_events_file, |
|
'midi_file': mid_file, |
|
'program': None, |
|
'is_drum': None, |
|
} |
|
assert len(rwc_bass) == 90 |
|
|
|
|
|
for id in rwc_bass.keys(): |
|
midi_file = rwc_bass[id]['midi_file'] |
|
notes_file = rwc_bass[id]['notes_file'] |
|
note_events_file = rwc_bass[id]['note_events_file'] |
|
|
|
|
|
notes, note_events = create_note_event_and_note_from_midi(midi_file, id, ignore_pedal=True) |
|
|
|
|
|
rwc_bass[id]['program'] = notes['program'] |
|
rwc_bass[id]['is_drum'] = notes['is_drum'] |
|
|
|
|
|
np.save(notes_file, notes, allow_pickle=True, fix_imports=False) |
|
print(f'Created {notes_file}') |
|
np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) |
|
print(f'Created {note_events_file}') |
|
|
|
|
|
bpm120_midi_file = midi_file.replace('.mid', '_bpm120.mid') |
|
note_event2midi(note_events['note_events'], bpm120_midi_file) |
|
print(f'Created {bpm120_midi_file}') |
|
|
|
|
|
split = 'bass' |
|
output_index_file = os.path.join(output_index_dir, f'rwc_pop_{split}_file_list.json') |
|
|
|
file_list = {} |
|
for i, id in enumerate(rwc_bass.keys()): |
|
file_list[i] = rwc_bass[id] |
|
|
|
with open(output_index_file, 'w') as f: |
|
json.dump(file_list, f, indent=4) |
|
print(f'Created {output_index_file}') |
|
|