Spaces:
Build error
Build error
""" preprocess_guitarset.py """ | |
import os | |
import glob | |
import copy | |
import json | |
from typing import Dict, List, Tuple, Optional | |
import numpy as np | |
import jams | |
from utils.note_event_dataclasses import Note, NoteEvent | |
from utils.audio import get_audio_file_info, pitch_shift_audio | |
from utils.midi import note_event2midi, pitch_shift_midi | |
from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes | |
def create_note_event_and_note_from_jam(jam_file: str, id: str) -> Tuple[Dict, Dict]: | |
jam = jams.load(jam_file) | |
notes = [] | |
for ann in jam.annotations: | |
for obs in ann.data: | |
if isinstance(obs.value, float): | |
if obs.confidence == None: | |
note = Note(is_drum=False, | |
program=24, | |
onset=obs.time, | |
offset=obs.time + obs.duration, | |
pitch=round(obs.value), | |
velocity=1) | |
notes.append(note) | |
# Sort, validate, and trim notes | |
notes = sort_notes(notes) | |
notes = validate_notes(notes) | |
notes = trim_overlapping_notes(notes) | |
return { # notes | |
'guitarset_id': id, | |
'program': [24], | |
'is_drum': [0], | |
'duration_sec': jam.file_metadata.duration, | |
'notes': notes, | |
}, { # note_events | |
'guitarset_id': id, | |
'program': [24], | |
'is_drum': [0], | |
'duration_sec': jam.file_metadata.duration, | |
'note_events': note2note_event(notes), | |
} | |
def generate_pitch_shifted_wav_and_midi(file_list: Dict, min_pitch_shift: int = -5, max_pitch_shift: int = 6): | |
for key in file_list.keys(): | |
midi_file = file_list[key]['midi_file'] | |
audio_file = file_list[key]['mix_audio_file'] | |
# Write midi, notes, and note_events with pitch shift | |
pitch_shift_midi(src_midi_file=midi_file, | |
min_pitch_shift=min_pitch_shift, | |
max_pitch_shift=max_pitch_shift, | |
write_midi_file=True, | |
write_notes_file=True, | |
write_note_events_file=True) | |
# Write wav with pitch shift | |
pitch_shift_audio(src_audio_file=audio_file, | |
min_pitch_shift=min_pitch_shift, | |
max_pitch_shift=max_pitch_shift, | |
random_microshift_range=(-10, 11)) | |
def preprocess_guitarset16k(data_home: os.PathLike, | |
dataset_name: str = 'guitarset', | |
pitch_shift_range: Optional[Tuple[int, int]] = (-5, 6)) -> None: | |
""" | |
Splits: | |
- progression_1, progression_2, progression_3 | |
- train, validation, test (by random selection [4,1,1] player) | |
Writes: | |
- {dataset_name}_{split}_file_list.json: a dictionary with the following keys: | |
{ | |
index: | |
{ | |
'guitarset_id': guitarset_id, | |
'n_frames': (int), | |
'mix_audio_file': 'path/to/mix.wav', | |
'notes_file': 'path/to/notes.npy', | |
'note_events_file': 'path/to/note_events.npy', | |
'midi_file': 'path/to/midi.mid', | |
'program': List[int], | |
'is_drum': List[int], # 0 or 1 | |
} | |
} | |
""" | |
# Directory and file paths | |
base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') | |
output_index_dir = os.path.join(data_home, 'yourmt3_indexes') | |
os.makedirs(output_index_dir, exist_ok=True) | |
# Process annotations | |
all_ann_files = glob.glob(os.path.join(base_dir, 'annotation/*.jams'), recursive=True) | |
assert len(all_ann_files) == 360 | |
notes_files = {} | |
note_events_files = {} | |
midi_files = {} | |
for ann_file in all_ann_files: | |
# Convert all annotations to notes and note events | |
guitarset_id = os.path.basename(ann_file).split('.')[0] | |
notes, note_events = create_note_event_and_note_from_jam(ann_file, guitarset_id) | |
notes_file = ann_file.replace('.jams', '_notes.npy') | |
np.save(notes_file, notes, allow_pickle=True, fix_imports=False) | |
print(f'Created {notes_file}') | |
note_events_file = ann_file.replace('.jams', '_note_events.npy') | |
np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) | |
print(f'Created {note_events_file}') | |
# Create a midi file from the note_events | |
midi_file = ann_file.replace('.jams', '.mid') | |
note_event2midi(note_events=note_events['note_events'], output_file=midi_file) | |
print(f'Created {midi_file}') | |
notes_files[guitarset_id] = notes_file | |
note_events_files[guitarset_id] = note_events_file | |
midi_files[guitarset_id] = midi_file | |
# Process audio files | |
pass | |
# Create file_list.json | |
guitarset_ids_by_split = { | |
'progression_1': [], | |
'progression_2': [], | |
'progression_3': [], | |
'player_0': [], | |
'player_1': [], | |
'player_2': [], | |
'player_3': [], | |
'player_4': [], | |
'player_5': [], | |
'train': [], # random selection of 4 players for each style | |
'validation': [], # random selection of 1 player for each style | |
'test': [], # random selection of 1 player for each style | |
'all': [], | |
} | |
# by progressions, players and all | |
for ann_file in all_ann_files: | |
guitarset_id = os.path.basename(ann_file).split('.')[0] | |
progression = int(guitarset_id.split('_')[1].split('-')[0][-1]) | |
player = int(guitarset_id.split('_')[0]) | |
# all | |
guitarset_ids_by_split['all'].append(guitarset_id) | |
# progression | |
if progression == 1: | |
guitarset_ids_by_split['progression_1'].append(guitarset_id) | |
elif progression == 2: | |
guitarset_ids_by_split['progression_2'].append(guitarset_id) | |
elif progression == 3: | |
guitarset_ids_by_split['progression_3'].append(guitarset_id) | |
else: | |
raise ValueError(f'Invalid progression: {guitarset_id}') | |
# player | |
if player == 0: | |
guitarset_ids_by_split['player_0'].append(guitarset_id) | |
elif player == 1: | |
guitarset_ids_by_split['player_1'].append(guitarset_id) | |
elif player == 2: | |
guitarset_ids_by_split['player_2'].append(guitarset_id) | |
elif player == 3: | |
guitarset_ids_by_split['player_3'].append(guitarset_id) | |
elif player == 4: | |
guitarset_ids_by_split['player_4'].append(guitarset_id) | |
elif player == 5: | |
guitarset_ids_by_split['player_5'].append(guitarset_id) | |
else: | |
raise ValueError(f'Invalid player: {guitarset_id}') | |
# sort | |
for key in guitarset_ids_by_split.keys(): | |
guitarset_ids_by_split[key] = sorted(guitarset_ids_by_split[key]) | |
for i in range(6): | |
assert len(guitarset_ids_by_split[f'player_{i}']) == 60 | |
# train/valid/test by random player | |
for i in range(60): | |
rand_sel = np.random.choice(6, size=6, replace=False) | |
player_train = rand_sel[:4] | |
player_valid = rand_sel[4] | |
player_test = rand_sel[5] | |
for player in player_train: | |
guitarset_ids_by_split['train'].append(guitarset_ids_by_split[f'player_{player}'][i]) | |
guitarset_ids_by_split['validation'].append(guitarset_ids_by_split[f'player_{player_valid}'][i]) | |
guitarset_ids_by_split['test'].append(guitarset_ids_by_split[f'player_{player_test}'][i]) | |
assert len(guitarset_ids_by_split['train']) == 240 | |
assert len(guitarset_ids_by_split['validation']) == 60 | |
assert len(guitarset_ids_by_split['test']) == 60 | |
# Create file_list.json | |
for split in ['progression_1', 'progression_2', 'progression_3', 'train', 'validation', 'test', 'all']: | |
file_list = {} | |
for i, gid in enumerate(guitarset_ids_by_split[split]): | |
# Check if wav files exist for the 4 versions | |
wav_file = {} | |
wav_file['hex'] = os.path.join(base_dir, 'audio_hex-pickup_original', gid + '_' + 'hex' + '.wav') | |
wav_file['hex_cln'] = os.path.join(base_dir, 'audio_hex-pickup_debleeded', gid + '_' + 'hex_cln' + '.wav') | |
wav_file['mic'] = os.path.join(base_dir, 'audio_mono-mic', gid + '_' + 'mic' + '.wav') | |
wav_file['mix'] = os.path.join(base_dir, 'audio_mono-pickup_mix', gid + '_' + 'mix' + '.wav') | |
for ver in wav_file: | |
assert os.path.exists(wav_file[ver]) | |
for ver in ['mic', 'mix']: #'hex', 'hex_cln', | |
file_list[i, ver] = { | |
'guitarset_id': gid + '_' + ver, | |
'n_frames': get_audio_file_info(wav_file[ver])[1], | |
'mix_audio_file': wav_file[ver], | |
'notes_file': notes_files[gid], | |
'note_events_file': note_events_files[gid], | |
'midi_file': midi_files[gid], | |
'program': [24], | |
'is_drum': [0], | |
} | |
# Reindexing file_list | |
_file_list = {} | |
for i, v in enumerate(file_list.values()): | |
_file_list[i] = v | |
file_list = _file_list | |
# Write json | |
output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') | |
with open(output_file, 'w') as f: | |
json.dump(file_list, f, indent=4) | |
print(f'Created {output_file}') | |
if pitch_shift_range == None: | |
return | |
else: | |
min_pitch_shift, max_pitch_shift = pitch_shift_range | |
# Generate pitch shifted wav and MIDI | |
file_list_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_file_list.json') | |
with open(file_list_all_path, 'r') as f: | |
fl = json.load(f) | |
file_list_all = {int(key): value for key, value in fl.items()} | |
generate_pitch_shifted_wav_and_midi(file_list_all, min_pitch_shift=min_pitch_shift, max_pitch_shift=max_pitch_shift) | |
# Create file_list.json for pitch shifted data | |
for split in ['progression_1', 'progression_2', 'progression_3', 'train', 'all']: | |
src_file_list_path = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') | |
with open(src_file_list_path, 'r') as f: | |
fl = json.load(f) | |
src_file_list = {int(key): value for key, value in fl.items()} | |
file_list = {} | |
for k, v in src_file_list.items(): | |
for pitch_shift in range(min_pitch_shift, max_pitch_shift): | |
if pitch_shift == 0: | |
file_list[k, 0] = copy.deepcopy(v) | |
else: | |
file_list[k, pitch_shift] = copy.deepcopy(v) | |
shifted_audio_file = v['mix_audio_file'].replace('.wav', f'_pshift{pitch_shift}.wav') | |
assert os.path.isfile(shifted_audio_file) == True | |
file_list[k, pitch_shift]['mix_audio_file'] = shifted_audio_file | |
file_list[k, pitch_shift]['n_frames'] = get_audio_file_info(shifted_audio_file)[1] | |
file_list[k, pitch_shift]['pitch_shift'] = pitch_shift | |
shifted_midi_file = v['midi_file'].replace('.mid', f'_pshift{pitch_shift}.mid') | |
shifted_notes_file = v['notes_file'].replace('_notes', f'_pshift{pitch_shift}_notes') | |
shifted_note_events_file = v['note_events_file'].replace('_note_events', | |
f'_pshift{pitch_shift}_note_events') | |
assert os.path.isfile(shifted_midi_file) == True | |
assert os.path.isfile(shifted_notes_file) == True | |
assert os.path.isfile(shifted_note_events_file) == True | |
file_list[k, pitch_shift]['midi_file'] = shifted_midi_file | |
file_list[k, pitch_shift]['notes_file'] = shifted_notes_file | |
file_list[k, pitch_shift]['note_events_file'] = shifted_note_events_file | |
assert len(file_list) == len(src_file_list) * (max_pitch_shift - min_pitch_shift) | |
# Reindexing file_list | |
_file_list = {} | |
for i, v in enumerate(file_list.values()): | |
_file_list[i] = v | |
file_list = _file_list | |
# Write json | |
output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_pshift_file_list.json') | |
with open(output_file, 'w') as f: | |
json.dump(file_list, f, indent=4) | |
print(f'Created {output_file}') | |
def create_filelist_by_style_guitarset16k(data_home: os.PathLike, dataset_name: str = 'guitarset') -> None: | |
# Directory and file paths | |
base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') | |
output_index_dir = os.path.join(data_home, 'yourmt3_indexes') | |
os.makedirs(output_index_dir, exist_ok=True) | |
# Load filelist, pshift_all for train | |
file_list_pshift_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_pshift_file_list.json') | |
with open(file_list_pshift_all_path, 'r') as f: | |
fl_pshift = json.load(f) | |
assert len(fl_pshift) == 7920 | |
# Load filelist, all for test | |
file_list_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_file_list.json') | |
with open(file_list_all_path, 'r') as f: | |
fl = json.load(f) | |
assert len(fl) == 720 | |
# Create file_list.json for training each style using pitch shifted data | |
styles = ['BN', 'Funk', 'SS', 'Jazz', 'Rock'] | |
for style in styles: | |
# Create and write pshift file list | |
train_file_list = {} | |
i = 0 | |
for v in fl_pshift.values(): | |
if style in v['guitarset_id']: | |
train_file_list[i] = copy.deepcopy(v) | |
i += 1 | |
output_file = os.path.join(output_index_dir, f'{dataset_name}_{style}_pshift_file_list.json') | |
with open(output_file, 'w') as f: | |
json.dump(train_file_list, f, indent=4) | |
print(f'Created {output_file}') | |
test_file_list = {} | |
i = 0 | |
for v in fl.values(): | |
if style in v['guitarset_id']: | |
test_file_list[i] = copy.deepcopy(v) | |
i += 1 | |
output_file = os.path.join(output_index_dir, f'{dataset_name}_{style}_file_list.json') | |
with open(output_file, 'w') as f: | |
json.dump(test_file_list, f, indent=4) | |
print(f'Created {output_file}') | |
# BASIC_PITCH_VALIDATION_IDS = [ | |
# "05_Funk2-108-Eb_comp", "04_BN2-131-B_comp", "04_Jazz3-150-C_solo", "05_Rock2-85-F_solo", | |
# "05_Funk3-98-A_comp", "05_BN3-119-G_comp", "02_SS2-107-Ab_solo", "01_BN2-131-B_solo", | |
# "00_BN2-166-Ab_comp", "04_SS1-100-C#_solo", "01_BN2-166-Ab_solo", "01_Rock1-130-A_solo", | |
# "04_Funk2-119-G_solo", "01_SS2-107-Ab_comp", "05_Funk3-98-A_solo", "05_Funk1-114-Ab_comp", | |
# "05_Jazz2-187-F#_solo", "05_SS1-100-C#_comp", "00_Rock3-148-C_solo", "02_Rock3-117-Bb_comp", | |
# "01_BN1-147-Gb_solo", "01_Rock1-90-C#_solo", "01_SS2-107-Ab_solo", "02_Jazz3-150-C_solo", | |
# "00_Funk1-97-C_solo", "05_SS3-98-C_solo", "03_Rock3-148-C_comp", "03_Rock3-117-Bb_solo", | |
# "04_Jazz2-187-F#_solo", "05_Jazz2-187-F#_comp", "02_SS1-68-E_solo", "04_SS2-88-F_solo", | |
# "04_BN2-131-B_solo", "04_Jazz3-137-Eb_comp", "00_SS2-107-Ab_comp", "01_Rock1-130-A_comp", | |
# "00_Jazz1-130-D_comp", "04_Funk2-108-Eb_comp", "05_BN2-166-Ab_comp" | |
# ] | |
# BASIC_PITCH_TEST_IDS = [ | |
# "04_SS3-84-Bb_solo", "02_Funk1-114-Ab_solo", "05_Funk1-114-Ab_solo", "05_Funk1-97-C_solo", | |
# "00_Rock3-148-C_comp", "00_Jazz3-137-Eb_comp", "00_Jazz1-200-B_comp", "03_SS3-98-C_solo", | |
# "05_Jazz1-130-D_comp", "00_Jazz2-110-Bb_comp", "02_Funk3-98-A_comp", "04_Rock1-130-A_comp", | |
# "03_BN1-129-Eb_comp", "03_Funk2-119-G_comp", "05_BN1-147-Gb_comp", "02_Rock1-90-C#_comp", | |
# "00_Funk3-98-A_solo", "01_SS1-100-C#_comp", "00_Funk3-98-A_comp", "02_BN3-154-E_comp", | |
# "01_Jazz3-137-Eb_comp", "00_BN2-131-B_comp", "04_SS1-68-E_solo", "05_Funk1-97-C_comp", | |
# "04_Jazz3-137-Eb_solo", "05_Rock2-142-D_solo", "02_BN3-119-G_solo", "02_Rock2-142-D_solo", | |
# "01_BN1-129-Eb_solo", "00_Rock2-85-F_comp", "00_Rock1-130-A_solo" | |
# ] | |
# def create_filelist_for_basic_pitch_benchmark_guitarset16k(data_home: os.PathLike, | |
# dataset_name: str = 'guitarset') -> None: | |
# # Directory and file paths | |
# base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') | |
# output_index_dir = os.path.join(data_home, 'yourmt3_indexes') | |
# os.makedirs(output_index_dir, exist_ok=True) | |
# # Load filelist, pshift_all for train | |
# file_list_pshift_all_path = os.path.join(output_index_dir, | |
# f'{dataset_name}_all_pshift_file_list.json') | |
# with open(file_list_pshift_all_path, 'r') as f: | |
# fl_pshift = json.load(f) | |
# assert len(fl_pshift) == 7920 | |
# # Load filelist, all without pshift | |
# file_list_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_file_list.json') | |
# with open(file_list_all_path, 'r') as f: | |
# fl = json.load(f) | |
# assert len(fl) == 720 | |
# # This is abandoned, because the split is not official one. | |