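# Source: YourMT3, amt/src/utils/preprocess/preprocess_rwc_pop_full.py
# (last commit a03c9b4 by mimbres, 12.8 kB)
"""preprocess_rwc_pop_full.py

Preprocess the 16 kHz YourMT3 copy of the RWC pop dataset: write
program-corrected MIDI files, note / note-event .npy files, and a 'full'
split file-list index.
"""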
import os
import glob
import re
import json
import csv
from typing import Dict, List, Any, Tuple
import numpy as np
from utils.audio import get_audio_file_info, load_audio_file
from utils.midi import midi2note, note_event2midi
from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes, extract_program_from_notes
from utils.event2note import event2note_event
from utils.note_event_dataclasses import Note, NoteEvent
from utils.utils import note_event2token2note_event_sanity_check
from mido import MetaMessage, Message, MidiFile, MidiTrack
# Song IDs removed from the index (compared after str()) before note files are
# written; each must be present in the CSV mapping, otherwise rwc_all.pop()
# raises KeyError. The commented-out lists are earlier exclusion sets kept for
# reference only.
# UNUSED_IDS = ["010", "071", "099", "023", "034", "036", "038", "049", "060", "062"]
# UNUSED_IDS = ["071", "099", "049", "060", "062"]
UNUSED_IDS = []
# MIDI channel for drums. NOTE(review): the original comment read "all drums are
# in channel 9 in geerdes dataset", which looks copied from another preprocessor.
# This constant is not referenced by the code visible in this file; drum tracks
# get their output channel from PRG2CH[DRUM_PROGRAM] (also 9) — confirm before
# relying on it.
DRUM_CHANNEL = 9
# Pseudo program number tagging drum tracks (outside the 0-127 MIDI program range).
DRUM_PROGRAM = 128
# Program numbers this file assigns to lead and chorus singing-voice tracks.
SINGING_VOICE_PROGRAM = 100
SINGING_VOICE_CHORUS_PROGRAM = 101
# Track name -> program, passed to midi2note() as `track_name_to_program` when the
# corrected MIDI files are read back. The keys equal the track names PRG2CH gives
# these three programs. NOTE(review): the original comment said names are
# "compared by exact match of lowercase"; the comparison happens inside
# midi2note (not visible here) and the keys are mixed-case — confirm.
TRACK_NAME_TO_PROGRAM_MAP = {
    "Singing Voice": SINGING_VOICE_PROGRAM,
    "Singing Voice (Chorus)": SINGING_VOICE_CHORUS_PROGRAM,
    "Drums": DRUM_PROGRAM,
}
# yapf: disable
# Rules that classify an original RWC track by its (upper-cased) track name; used
# via find_matching_filters() by generate_corrected_midi().
#   "include": matches if any keyword is a substring of the name.
#   "exact":   (optional) matches if the name equals a keyword.
#   "exclude": vetoes the match if any keyword is a substring (even an exact hit).
# Keys: an int is the target program (a PRG2CH key); None means "drop this track";
# "USE RWC PROGRAM MAP" means "take the track's program_change through
# RWC_PROGRAM_MAP".
# ORDER MATTERS: every matching key is collected in dict order, but only the
# first one decides the track's program.
# NOTE(review): because 33 (bass) precedes 64 and has no excludes, any name
# containing "BASS" resolves to 33, so "BASS" under 64 can never win — likely
# unintended, confirm. Likewise "LEAD" under 30 is tried before 80, so a name such
# as "SYNTH LEAD" (not "SYN.LEAD") resolves to 30.
TRACK_NAME_FILTERS = {
    SINGING_VOICE_PROGRAM: {"include": ["MELO", "VOCAL"], "exclude": ["SUB", "GT"]},
    SINGING_VOICE_CHORUS_PROGRAM: {"include": ["CHORUS", "SUB VOCAL", "SUB MELO"],
                                   "exclude": ["/", "GT"]},
    # "ATOMOS" is excluded because it contains "TOM".
    DRUM_PROGRAM: {"include": ["DRUMS", "DR", "HIHAT", "BD&SD", "TOM", "KICK"],
                   "exclude": ["ATOMOS"], "exact": ["DS"]},
    0: {"include": ["P.F.", "PF", "PIANO", "A.P", "CLAV", "CEMBAL", "HARPSI"], "exclude": ["E.PIANO", "MARIMBA"]},
    2: {"include": ["E.P"], "exclude": []},
    8: {"include": ["GLOCKEN", "VIBRA", "VIBE", "MARIMBA", "BELL", "CHIME", "CHAIM", "KALIMB", "CHIMRE", "MALLET"],
        "exclude": []},
    16: {"include": ["ORG", "HAMO", "HARMONICA", "ACCORD"], "exclude": []},
    24: {"include": ["MANDORIN", "AG", "NYLON", "AC.G", "GUITAR", "A.G", "E.G", "GT", "G. SOLO", "CLEAN LEAD", "SITAR", "ATOMOS", "ATMOS",
                     "CLEAN"],
         "exclude": ["DIST", "DIS.", "D.", "E.G SOLO", "E.G.SOLO"]},
    30: {"include": ["OD L", "OD R", "DIS.", "DIST GT", "D.G", "DIST", "DIS.SOLO", "E.GUITAR (SOLO)", "E.G SOLO", "LEAD", "E.G.SOLO", "EG", "GT MELO"],
         "exclude": ["PAD","SYN.LEAD"]},
    33: {"include": ["BASS"], "exclude": []},
    48: {"include": ["OR 2", "ST", "STR", "ORCH", "PIZZ", "HIT", "TIMPANI", "VIORA", "VIOLA", "VIOLIN", "VN", "VA", "VC", "HARP", "LO FI", "CHO", "VLN", "CELLO"],
         "exclude": ["CHORUS", "HARPSI", "STEEL", "GUITAR", "PAD", "BRASS", "GT", "HORN"],
         "exact": ["OR"]},
    56: {"include": ["BRAS", "TRUMP", "TP", "TB", "TROM", "HORN", "FLUGEL"], "exclude": []},
    64: {"include": ["SAX", "OBOE", "BASS"], "exclude": ["SYNSAX"]},
    72: {"include": ["FLUTE", "PICO", "BOTTLE", "GAYA"], "exclude": []},
    80: {"include": ["S SOLO", "SYN SOLO", "SOLO SYNTH", "SYNTH SOLO", "SYN.LEAD", "SYNTH(SEQ)", "PORTASYN", "SQ", "SEQ", "VOICE"], "exclude": []},
    88: {"include": ["SYNTH", "SYN", "PAD", "FANTASIA", "BRIGHTNESS", "FANTASY"], "exclude": ["SYNBELL", "PORTA", "SOLO", "SEQ", "LEAD", "ORGAN", "BRAS", "BASS", "TROM"]},
    # Tracks whose first match is None are dropped from the corrected MIDI.
    None: {"include": ["INTRO SE", "WOW", "PERC", "EXC", "REVERSE", "GONG", "PER.", "RAP", "REV", "S.E", "LASER",
                       "LESER", "TAMBOURINE", "KANE", "PER", "SHAKER", "RWC-MDB"],
           "exclude": [],
           "exact": ["SE", "EX", "808", "ICERAIN"]},
    "USE RWC PROGRAM MAP": {"include": ["KIRA", "KILA", "ETHNIC&GK"], "exclude": [], "exact": ["FUE", "OU-01A"]},
}
# yapf: enable
# Remaps a raw program number read from a track's program_change message onto the
# program numbers used as PRG2CH keys. Used only for tracks whose first matching
# TRACK_NAME_FILTERS key is "USE RWC PROGRAM MAP". Programs not listed pass
# through unchanged, so they must already be PRG2CH keys or the prg2ch lookup in
# generate_corrected_midi fails.
RWC_PROGRAM_MAP = {
    9: 8,
    11: 8,
    74: 72,
    94: 80,
    98: 88,
    100: 88,
}
# Program number -> (output MIDI channel, output track name) for the corrected MIDI
# files. Each program class gets its own channel; the drum class uses channel 9.
PRG2CH = {
    0: (0, "Acoustic Piano"),
    2: (1, "Electric Piano"),
    8: (2, "Chromatic Percussion"),
    16: (3, "Organ"),
    24: (4, "Guitar (clean)"),
    30: (5, "Guitar (distortion)"),
    33: (6, "Bass"),
    48: (7, "Strings"),
    56: (8, "Brass"),
    DRUM_PROGRAM: (9, "Drums"),
    64: (10, "Reed"),
    72: (11, "Pipe"),
    80: (12, "Synth Lead"),
    88: (13, "Synth Pad"),
    SINGING_VOICE_PROGRAM: (14, "Singing Voice"),
    SINGING_VOICE_CHORUS_PROGRAM: (15, "Singing Voice (Chorus)"),
}
def find_matching_filters(input_text, filters):
    """Return the keys of every filter whose keyword rules match *input_text*.

    The text is upper-cased before comparison. A filter (a dict with "include",
    "exclude" and optionally "exact" keyword lists) matches when the text equals
    one of its "exact" keywords or contains one of its "include" keywords, and
    contains none of its "exclude" keywords. An exclude hit vetoes even an exact
    match. Keys are returned in the iteration order of *filters*.
    """
    upper_text = input_text.upper()

    def _matches(rule):
        exact_hit = "exact" in rule and any(
            word == upper_text for word in rule["exact"])
        include_hit = any(word in upper_text for word in rule["include"])
        vetoed = any(word in upper_text for word in rule["exclude"])
        return (exact_hit or include_hit) and not vetoed

    return [key for key, rule in filters.items() if _matches(rule)]
def _copy_messages_keep_timing(messages, keep):
    """Copy the messages accepted by *keep* without shifting any later event."""
    # In a mido track each message's `time` is a delta (in ticks) from the previous
    # message, so the delta of every dropped message is carried into the next kept
    # one. Simply skipping messages would move all later events earlier.
    copied = []
    carried = 0
    for msg in messages:
        if keep(msg):
            copied.append(msg.copy(time=msg.time + carried))
            carried = 0
        else:
            carried += msg.time
    return copied


def _resolve_track_program(track, track_name, filters):
    """Return the target program of *track*, or None if the track is to be dropped."""
    matching_filters = find_matching_filters(track_name, filters)
    if not matching_filters:
        raise ValueError(f'No track-name filter matches track {track_name!r}')
    # Every matching key is collected in filter order; only the first one decides.
    first_match = matching_filters[0]
    if first_match is None:
        return None
    if isinstance(first_match, int):
        return first_match
    if first_match == "USE RWC PROGRAM MAP":
        for msg in track:
            if msg.type == 'program_change':
                return RWC_PROGRAM_MAP.get(msg.program, msg.program)
        # Previously this left `program` unbound or stale from the previous track.
        raise ValueError(f'Track {track_name!r} must be classified by its '
                         f'program_change message, but it has none')
    raise ValueError(
        f'Unsupported filter key {first_match!r} for track {track_name!r}')


def generate_corrected_midi(org_mid_file: os.PathLike,
                            new_mid_file: os.PathLike,
                            filters: Dict[Any, Dict[str, List]],
                            prg2ch: Dict[int, Tuple[int, str]]):
    """Write *new_mid_file*: a copy of *org_mid_file* with tracks re-labelled by name.

    The meta messages of the first track are kept as a global track. Every other
    track is classified by its track name via *filters* (see
    find_matching_filters; only the first matching key is used):
      * an int key is used directly as the program number;
      * "USE RWC PROGRAM MAP" takes the track's first program_change message,
        mapped through RWC_PROGRAM_MAP;
      * None drops the track.
    The program is looked up in *prg2ch* for the output channel and track name.
    The original track_name / instrument_name / program_change messages are
    replaced by a new track_name and a program_change (program 0 for
    DRUM_PROGRAM), and every channel message is moved to the new channel. Delta
    times of removed messages are preserved, so event timing is unchanged.

    Raises:
        ValueError: if the file has no tracks, a track has no name, no filter
            matches a track name, an RWC-mapped track has no program_change, or
            the resolved program has no entry in *prg2ch*.
    """
    org_mid = MidiFile(org_mid_file)
    if not org_mid.tracks:
        raise ValueError(f'No tracks found in midi file {org_mid_file}')
    new_mid = MidiFile(ticks_per_beat=org_mid.ticks_per_beat)

    # Global track: only the meta messages of the first (usually master) track.
    new_mid.tracks.append(
        MidiTrack(
            _copy_messages_keep_timing(org_mid.tracks[0],
                                       lambda m: m.is_meta)))

    replaced_types = {'track_name', 'instrument_name', 'program_change'}
    for track in org_mid.tracks[1:]:
        track_name = next(
            (msg.name for msg in track if msg.type == 'track_name'), None)
        if track_name is None:
            raise ValueError('track name not found in midi file')

        program = _resolve_track_program(track, track_name, filters)
        if program is None:
            continue  # the first matching filter key was None: drop the track

        try:
            ch, new_track_name = prg2ch[program]
        except KeyError as err:
            raise ValueError(f'Program {program} of track {track_name!r} has '
                             f'no entry in prg2ch') from err

        new_track = MidiTrack()
        new_track.append(
            MetaMessage('track_name', name=new_track_name, time=0))
        # Drums get program 0; the program change goes to the same channel as
        # the track's notes (channel 9 for drums with the default PRG2CH).
        new_program = 0 if program == DRUM_PROGRAM else program
        new_track.append(
            Message('program_change', program=new_program, time=0, channel=ch))
        new_mid.tracks.append(new_track)

        for new_msg in _copy_messages_keep_timing(
                track, lambda m: m.type not in replaced_types):
            if hasattr(new_msg, 'channel'):
                new_msg.channel = ch
            new_track.append(new_msg)

    new_mid.save(new_mid_file)
    print(f'Created {new_mid_file}')
def check_file_existence(file: str, min_frames: int = 10 * 16000) -> bool:
    """Return True if *file* exists and is long enough to be used.

    Args:
        file: Path to an audio file.
        min_frames: Minimum length, compared against the second element returned
            by ``get_audio_file_info(file)`` (the value this module stores as
            ``n_frames``). The default of 10 * 16000 presumably means 10 s of
            16 kHz audio — confirm against ``get_audio_file_info``.

    Returns:
        False if the file is missing, or if it is shorter than *min_frames*
        (a message is printed in that case); True otherwise.
    """
    if not os.path.exists(file):
        return False
    if get_audio_file_info(file)[1] < min_frames:
        print(f'File {file} is too short.')
        return False
    return True
def create_note_event_and_note_from_midi(
        mid_file: str,
        id: str,
        ch_9_as_drum: bool = False,
        track_name_to_program: Dict = None,
        ignore_pedal: bool = False) -> Tuple[Dict, Dict]:
    """Read *mid_file* and package its notes and note events for saving.

    midi2note is called with ``binary_velocity=True`` and ``return_programs=True``;
    the remaining options are forwarded unchanged.

    Returns:
        ``(notes_dict, note_events_dict)``. Both carry the song id under
        ``'rwc_pop_id'``, the de-duplicated non-None programs reported by
        midi2note (in set-iteration order) under ``'program'``, matching 0/1
        ``'is_drum'`` flags (1 for DRUM_PROGRAM), and ``'duration_sec'``. The
        first adds ``'notes'``; the second adds ``'note_events'`` produced by
        note2note_event from those notes.
    """
    notes, dur_sec, raw_programs = midi2note(
        mid_file,
        ch_9_as_drum=ch_9_as_drum,
        track_name_to_program=track_name_to_program,
        binary_velocity=True,
        ignore_pedal=ignore_pedal,
        return_programs=True)
    # Drop duplicates and None entries (tracks without a resolved program).
    programs = [prg for prg in set(raw_programs) if prg is not None]

    def _drum_flags():
        return [int(prg == DRUM_PROGRAM) for prg in programs]

    note_dict = {
        'rwc_pop_id': id,
        'program': programs,
        'is_drum': _drum_flags(),
        'duration_sec': dur_sec,
        'notes': notes,
    }
    note_event_dict = {
        'rwc_pop_id': id,
        'program': programs,
        'is_drum': _drum_flags(),
        'duration_sec': dur_sec,
        'note_events': note2note_event(notes),
    }
    return note_dict, note_event_dict
def _note_file_paths(mid_file: str) -> Tuple[str, str]:
    """Return the (notes, note_events) .npy paths stored next to *mid_file*."""
    # splitext only strips the final extension, unlike str.replace('.mid', ...),
    # which would also rewrite any '.mid' occurring elsewhere in the path.
    stem = os.path.splitext(mid_file)[0]
    return stem + '_notes.npy', stem + '_note_events.npy'


def _read_rwc_pop_mapping(base_dir: str) -> Dict[str, Dict[str, Any]]:
    """Build the per-song info dict from wav_to_midi_filename_mapping.csv."""
    csv_file = os.path.join(base_dir, 'wav_to_midi_filename_mapping.csv')
    rwc_all = {}
    with open(csv_file, 'r', newline='') as f:  # newline='' as the csv docs require
        reader = csv.reader(f)
        headers = next(reader)
        for row in reader:
            song_id = row[2]
            mix_audio_file = os.path.join(base_dir, headers[0] + row[0],
                                          row[1] + ' ' + headers[1] + '.wav')
            # Explicit checks instead of `assert`, which `python -O` strips
            # (together with the check_file_existence call inside it).
            if not check_file_existence(mix_audio_file):
                raise ValueError(
                    f'Audio file missing or too short: {mix_audio_file}')
            mid_file = os.path.join(base_dir, 'MIDI', song_id + '.mid')
            if not os.path.exists(mid_file):
                raise FileNotFoundError(f'MIDI file not found: {mid_file}')
            notes_file, note_events_file = _note_file_paths(mid_file)
            rwc_all[song_id] = {
                'rwc_pop_id': song_id,
                'n_frames': get_audio_file_info(mix_audio_file)[1],
                'mix_audio_file': mix_audio_file,
                'notes_file': notes_file,
                'note_events_file': note_events_file,
                'midi_file': mid_file,
                'program': None,
                'is_drum': None,
            }
    return rwc_all


def _write_corrected_midis(base_dir: str,
                           rwc_all: Dict[str, Dict[str, Any]]) -> None:
    """Write program-corrected MIDI files and repoint each song's paths at them."""
    corrected_dir = os.path.join(base_dir, 'MIDI_full_corrected')
    os.makedirs(corrected_dir, exist_ok=True)
    for info in rwc_all.values():
        org_mid_file = info['midi_file']
        # Build the path with os.path instead of replacing '/MIDI/': that string
        # replace fails with Windows separators (the original file would then be
        # overwritten) and also rewrites any '/MIDI/' inside base_dir.
        new_mid_file = os.path.join(corrected_dir,
                                    os.path.basename(org_mid_file))
        generate_corrected_midi(org_mid_file,
                                new_mid_file,
                                filters=TRACK_NAME_FILTERS,
                                prg2ch=PRG2CH)
        info['midi_file'] = new_mid_file
        info['notes_file'], info['note_events_file'] = _note_file_paths(
            new_mid_file)


def preprocess_rwc_pop_full16k(data_home='../../data',
                               dataset_name='rwc_pop') -> None:
    """Build notes, note events and the 'full' file-list index for RWC pop.

    Steps:
      1. Read ``wav_to_midi_filename_mapping.csv`` under
         ``{data_home}/{dataset_name}_yourmt3_16k`` and pair each song's mix
         audio with ``MIDI/{id}.mid``.
      2. Write program-corrected MIDI files to ``MIDI_full_corrected/``.
      3. Drop the songs listed in UNUSED_IDS.
      4. Save ``*_notes.npy`` and ``*_note_events.npy`` next to each corrected
         MIDI file (pickled dicts, ``allow_pickle=True``).
      5. Write ``{data_home}/yourmt3_indexes/rwc_pop_full_file_list.json``.

    Raises:
        ValueError: if an audio file is missing or too short, or the CSV does
            not list exactly 100 songs.
        FileNotFoundError: if a song's MIDI file is missing.
        KeyError: if an UNUSED_IDS entry is not among the loaded songs.
    """
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    rwc_all = _read_rwc_pop_mapping(base_dir)
    expected_n_songs = 100
    if len(rwc_all) != expected_n_songs:
        raise ValueError(f'Expected {expected_n_songs} songs in the RWC pop '
                         f'mapping, found {len(rwc_all)}')

    _write_corrected_midis(base_dir, rwc_all)

    for unused_id in UNUSED_IDS:
        rwc_all.pop(str(unused_id))
    print(f'Number of used IDs: {len(rwc_all)}, Unused ids: {UNUSED_IDS}')

    for song_id, info in rwc_all.items():
        notes, note_events = create_note_event_and_note_from_midi(
            info['midi_file'],
            song_id,
            ch_9_as_drum=False,  # we will use track_name_to_program instead
            track_name_to_program=TRACK_NAME_TO_PROGRAM_MAP,
            ignore_pedal=False)
        info['program'] = notes['program']
        info['is_drum'] = notes['is_drum']
        np.save(info['notes_file'], notes, allow_pickle=True, fix_imports=False)
        print(f'Created {info["notes_file"]}')
        np.save(info['note_events_file'],
                note_events,
                allow_pickle=True,
                fix_imports=False)
        print(f'Created {info["note_events_file"]}')

    split = 'full'
    output_index_file = os.path.join(output_index_dir,
                                     f'rwc_pop_{split}_file_list.json')
    # Index entries are keyed 0..N-1 in song order (json turns the keys into strings).
    file_list = dict(enumerate(rwc_all.values()))
    with open(output_index_file, 'w') as f:
        json.dump(file_list, f, indent=4)
    print(f'Created {output_index_file}')