import pretty_midi
from copy import deepcopy
import numpy as np
from miditok import CPWord, Structured
from miditoolkit import MidiFile
from src.music.config import CHUNK_SIZE  # MAX_EMBEDDING is redefined locally below
from src.music.utilities.chord_structured import ChordStructured

# code from https://github.com/jason9693/midi-neural-processor
RANGE_NOTE_ON = 128
RANGE_NOTE_OFF = 128
RANGE_VEL = 32
RANGE_TIME_SHIFT = 100
MAX_EMBEDDING = RANGE_VEL + RANGE_NOTE_OFF + RANGE_TIME_SHIFT + RANGE_NOTE_ON

START_IDX = {
    'note_on': 0,
    'note_off': RANGE_NOTE_ON,
    'time_shift': RANGE_NOTE_ON + RANGE_NOTE_OFF,
    'velocity': RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT
}

# Our parameters
pitch_range = range(21, 109)
beat_res = {(0, 4): 8, (4, 12): 4}
nb_velocities = 32
additional_tokens = {'Chord': True, 'Rest': True, 'Tempo': True,
                     'TimeSignature': False, 'Program': False,
                     'rest_range': (2, 8),  # (half note, 8 beats)
                     'nb_tempos': 32,  # nb of tempo bins
                     'tempo_range': (40, 250)}  # (min, max)

# Create the tokenizers. tokenizer_cp is required by the *_cp and miditok
# chunk functions below, so it is instantiated rather than left commented out.
tokenizer_cp = CPWord(pitch_range, beat_res, nb_velocities, additional_tokens)
tokenizer_structured = ChordStructured(pitch_range, beat_res, nb_velocities)


class SustainAdapter:
    def __init__(self, time, type):
        self.start = time
        self.type = type


class SustainDownManager:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.managed_notes = []
        self._note_dict = {}  # key: pitch, value: note.start

    def add_managed_note(self, note: pretty_midi.Note):
        self.managed_notes.append(note)

    def transposition_notes(self):
        # Walking backwards, a pedal-held note ends either when the same pitch
        # is struck again or when the pedal is released, whichever is known.
        for note in reversed(self.managed_notes):
            try:
                note.end = self._note_dict[note.pitch]
            except KeyError:
                note.end = max(self.end, note.end)
            self._note_dict[note.pitch] = note.start


# A note divided into its note_on / note_off events
class SplitNote:
    def __init__(self, type, time, value, velocity):
        # type: note_on, note_off
        self.type = type
        self.time = time
        self.velocity = velocity
        self.value = value

    def __repr__(self):
        return '<[SNote] time: {} type: {}, value: {}, velocity: {}>'\
            .format(self.time, self.type, self.value, self.velocity)


class Event:
    def __init__(self, event_type, value):
        self.type = event_type
        self.value = value

    def __repr__(self):
        return '<Event type: {}, value: {}>'.format(self.type, self.value)

    def to_int(self):
        return START_IDX[self.type] + self.value

    @staticmethod
    def from_int(int_value):
        info = Event._type_check(int_value)
        return Event(info['type'], info['value'])

    @staticmethod
    def _type_check(int_value):
        range_note_on = range(0, RANGE_NOTE_ON)
        range_note_off = range(RANGE_NOTE_ON, RANGE_NOTE_ON + RANGE_NOTE_OFF)
        range_time_shift = range(RANGE_NOTE_ON + RANGE_NOTE_OFF,
                                 RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)

        valid_value = int_value

        if int_value in range_note_on:
            return {'type': 'note_on', 'value': valid_value}
        elif int_value in range_note_off:
            valid_value -= RANGE_NOTE_ON
            return {'type': 'note_off', 'value': valid_value}
        elif int_value in range_time_shift:
            valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF)
            return {'type': 'time_shift', 'value': valid_value}
        else:
            valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)
            return {'type': 'velocity', 'value': valid_value}


def _divide_note(notes):
    result_array = []
    notes.sort(key=lambda x: x.start)

    for note in notes:
        on = SplitNote('note_on', note.start, note.pitch, note.velocity)
        off = SplitNote('note_off', note.end, note.pitch, None)
        result_array += [on, off]
    return result_array
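
# Illustrative sketch (not called by the pipeline): the event vocabulary packs
# the four token types into one integer range of size MAX_EMBEDDING = 388
# (note_on: 0-127, note_off: 128-255, time_shift: 256-355, velocity: 356-387);
# Event.to_int and Event.from_int are inverses over that range.
def _example_event_roundtrip():
    e = Event(event_type='velocity', value=5)
    assert e.to_int() == START_IDX['velocity'] + 5 == 361
    back = Event.from_int(361)
    assert back.type == 'velocity' and back.value == 5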
def _merge_note(snote_sequence):
    note_on_dict = {}
    result_array = []

    for snote in snote_sequence:
        if snote.type == 'note_on':
            note_on_dict[snote.value] = snote
        elif snote.type == 'note_off':
            try:
                on = note_on_dict[snote.value]
                off = snote
                if off.time - on.time == 0:
                    continue
                result = pretty_midi.Note(on.velocity, snote.value, on.time, off.time)
                result_array.append(result)
            except KeyError:
                # note_off without a matching note_on: drop it
                print('info removed pitch: {}'.format(snote.value))
    return result_array


def _snote2events(snote: SplitNote, prev_vel: int):
    result = []
    if snote.velocity is not None:
        modified_velocity = snote.velocity // 4
        if prev_vel != modified_velocity:
            result.append(Event(event_type='velocity', value=modified_velocity))
    result.append(Event(event_type=snote.type, value=snote.value))
    return result


def _event_seq2snote_seq(event_sequence):
    timeline = 0
    velocity = 0
    snote_seq = []

    for event in event_sequence:
        if event.type == 'time_shift':
            timeline += ((event.value + 1) / 100)
        elif event.type == 'velocity':
            velocity = event.value * 4
        else:
            snote = SplitNote(event.type, timeline, event.value, velocity)
            snote_seq.append(snote)
    return snote_seq


def _make_time_shift_events(prev_time, post_time):
    time_interval = int(round((post_time - prev_time) * 100))
    results = []
    while time_interval >= RANGE_TIME_SHIFT:
        results.append(Event(event_type='time_shift', value=RANGE_TIME_SHIFT - 1))
        time_interval -= RANGE_TIME_SHIFT
    if time_interval == 0:
        return results
    else:
        return results + [Event(event_type='time_shift', value=time_interval - 1)]


def _control_preprocess(ctrl_changes):
    sustains = []
    manager = None
    for ctrl in ctrl_changes:
        if ctrl.value >= 64 and manager is None:
            # sustain down
            manager = SustainDownManager(start=ctrl.time, end=None)
        elif ctrl.value < 64 and manager is not None:
            # sustain up
            manager.end = ctrl.time
            sustains.append(manager)
            manager = None
        elif ctrl.value < 64 and len(sustains) > 0:
            sustains[-1].end = ctrl.time
    return sustains


def _note_preprocess(sustains, notes):
    note_stream = []
    count_note_processed = 0
    if sustains:  # if the midi file has sustain controls
        for sustain in sustains:
            if len(notes) > 0:
                for note_idx, note in enumerate(notes):
                    if note.start < sustain.start:
                        note_stream.append(note)
                        last_counted = True
                    elif note.start > sustain.end:
                        last_counted = False
                        break
                    else:
                        sustain.add_managed_note(note)
                        last_counted = True
                    count_note_processed += 1
                sustain.transposition_notes()          # extend what is held by the sustain
                note_stream += sustain.managed_notes   # add to stream
                # remove notes that were already added to the stream
                last_idx = note_idx if not last_counted else note_idx + 1
                if last_idx < len(notes):
                    notes = notes[last_idx:]  # keep the next notes; previous ones are in note_stream
                else:
                    notes = []
        note_stream += notes
        count_note_processed += len(notes)
    else:  # else, just push everything into the note stream
        for note in notes:
            note_stream.append(note)
    note_stream.sort(key=lambda x: x.start)
    return note_stream


def midi_valid(midi) -> bool:
    # if any(ts.numerator != 4 or ts.denominator != 4 for ts in midi.time_signature_changes):
    #     return False  # time signature different from 4/4
    # if midi.max_tick < 10 * midi.ticks_per_beat:
    #     return False  # this MIDI is too short
    return True
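
# Illustrative sketch (not called by the pipeline): time gaps are encoded in
# hundredths of a second, and a gap of 1 s or more is split into several
# time_shift events of at most RANGE_TIME_SHIFT hundredths each; a value v
# stands for a shift of (v + 1) / 100 s.
def _example_time_shift_events():
    events = _make_time_shift_events(prev_time=0.0, post_time=1.3)  # 130 hundredths
    assert [e.value for e in events] == [99, 29]  # 1.0 s, then 0.3 s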
def encode_midi_structured(file_path, nb_aug, nb_noise):
    notes = []
    mid = MidiFile(file_path)
    assert midi_valid(mid)

    # preprocess sustain pedal (CC 64) and gather notes
    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number 64 is the sustain pedal; for the meaning of the other control numbers, see
        # https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)
    assert len(notes) == len(mid.instruments[0].notes)  # assumes a single-instrument MIDI

    # sort notes
    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])
    original_notes = deepcopy(notes)

    # convert notes to tokens
    encoded_main = tokenizer_structured.midi_to_tokens(mid)[0]
    min_pitch = np.min([n.pitch for n in notes])

    encoded_augmentations = []
    noise_shift = 6
    aug_shift = 3
    embedding_noise = None
    for i_aug in range(nb_aug):
        a_notes = alter_notes_exact_tick(original_notes, aug_shift, min_pitch)
        mid.instruments[0].notes = a_notes
        assert midi_valid(mid)
        embedding_aug = tokenizer_structured.midi_to_tokens(mid)[0]  # encode notes
        encoded_augmentations.append(embedding_aug)

    if nb_noise > 0:
        a_notes = alter_notes_exact_tick(original_notes, noise_shift, min_pitch)
        mid.instruments[0].notes = a_notes
        assert midi_valid(mid)
        embedding_noise = tokenizer_structured.midi_to_tokens(mid)[0]  # encode notes

    return encoded_main, encoded_augmentations, embedding_noise


def encode_midi_cp(file_path, nb_aug, nb_noise):
    notes = []
    mid = MidiFile(file_path)
    assert midi_valid(mid)

    # preprocess sustain pedal (CC 64) and gather notes
    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number 64 is the sustain pedal; for the meaning of the other control numbers, see
        # https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)
    assert len(notes) == len(mid.instruments[0].notes)  # assumes a single-instrument MIDI

    # sort notes
    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])
    original_notes = deepcopy(notes)

    # convert notes to tokens
    encoded_main = tokenizer_cp.midi_to_tokens(mid)[0]
    min_pitch = np.min([n.pitch for n in notes])

    encoded_augmentations = []
    noise_shift = 6
    aug_shift = 3
    embedding_noise = None
    for i_aug in range(nb_aug):
        a_notes = alter_notes_exact_tick(original_notes, aug_shift, min_pitch)
        mid.instruments[0].notes = a_notes
        assert midi_valid(mid)
        embedding_aug = tokenizer_cp.midi_to_tokens(mid)[0]  # encode notes
        encoded_augmentations.append(embedding_aug)

    if nb_noise > 0:
        a_notes = alter_notes_exact_tick(original_notes, noise_shift, min_pitch)
        mid.instruments[0].notes = a_notes
        assert midi_valid(mid)
        embedding_noise = tokenizer_cp.midi_to_tokens(mid)[0]  # encode notes

    return encoded_main, encoded_augmentations, embedding_noise
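
# Usage sketch for the two encoders above (the path is hypothetical): both
# return the main token sequence, nb_aug augmented variants, and one noise
# variant (None when nb_noise == 0).
def _example_encode_structured_usage():
    main, augs, noise = encode_midi_structured('path/to/piece.mid', nb_aug=2, nb_noise=1)
    assert len(augs) == 2 and noise is not None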
def alter_notes_exact_tick(notes, shift, min_pitch):
    # copy original notes
    a_notes = deepcopy(notes)
    # sample smart augmentation: a pitch shift within +/- shift semitones (never
    # below pitch 0) and a time scaling of +/- 2.5% or 5%; resample until at
    # least one of the two is non-zero
    pitch_shift, time_scaling = 0, 0
    while pitch_shift == 0 and time_scaling == 0:
        pitch_shift = np.random.choice(np.arange(max(-shift, -min_pitch), shift + 1))
        time_scaling = np.random.choice([-5, -2.5, 0, 2.5, 5])
    assert -shift <= pitch_shift <= shift
    # modify notes (tick times, so results are rounded to ints)
    for e in a_notes:
        e.start = int(e.start * (1. + time_scaling / 100))
        e.end = int(e.end * (1. + time_scaling / 100))
        new_pitch = max(e.pitch + pitch_shift, 0)
        e.pitch = new_pitch
    return a_notes


def alter_notes(notes, shift, min_pitch):
    # same as alter_notes_exact_tick, but keeps start/end times as floats (seconds)
    a_notes = deepcopy(notes)
    pitch_shift, time_scaling = 0, 0
    while pitch_shift == 0 and time_scaling == 0:
        pitch_shift = np.random.choice(np.arange(max(-shift, -min_pitch), shift + 1))
        time_scaling = np.random.choice([-5, -2.5, 0, 2.5, 5])
    assert -shift <= pitch_shift <= shift
    for e in a_notes:
        e.start = e.start * (1. + time_scaling / 100)
        e.end = e.end * (1. + time_scaling / 100)
        new_pitch = max(e.pitch + pitch_shift, 0)
        e.pitch = new_pitch
    return a_notes


def encode_midi(file_path, nb_aug, nb_noise):
    notes = []
    mid = pretty_midi.PrettyMIDI(midi_file=file_path)

    # preprocess sustain pedal (CC 64) and gather notes
    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number 64 is the sustain pedal; for the meaning of the other control numbers, see
        # https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)
    assert len(notes) == len(mid.instruments[0].notes)

    # sort notes
    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])

    # convert notes to ints
    encoded_main = convert_notes(notes)
    min_pitch = np.min([n.pitch for n in notes])

    encoded_augmentations = []
    noise_shift = 6
    aug_shift = 3
    embedding_noise = None
    for i_aug in range(nb_aug):
        a_notes = alter_notes(notes, aug_shift, min_pitch)
        embedding_group = convert_notes(a_notes)  # encode notes
        encoded_augmentations.append(embedding_group)

    if nb_noise > 0:
        a_notes = alter_notes(notes, noise_shift, min_pitch)
        embedding_noise = convert_notes(a_notes)  # encode notes

    return encoded_main, encoded_augmentations, embedding_noise


def chunk_notes(n_notes_per_chunk, notes):
    index = 0
    chunks = []
    for n in n_notes_per_chunk:
        chunks.append(notes[index:index + n])
        index += n
    return chunks


def chunk_first_embedding(chunk_size, embedding):
    # split an embedding into consecutive chunks of chunk_size tokens;
    # a trailing chunk is kept only if it is longer than chunk_size / 2
    chunks = []
    index = 0
    if len(embedding) < chunk_size:
        return [embedding]
    else:
        for i in range(chunk_size, len(embedding) + chunk_size, chunk_size):
            if (len(embedding) - index) > (chunk_size / 2):
                chunks.append(embedding[index:i])
            index = i
        return chunks
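
# Illustrative sketch (not called by the pipeline): chunk_first_embedding keeps
# consecutive full chunks and drops a short tail. With chunk_size 4, a 10-token
# sequence yields two chunks; the last 2 tokens are dropped since 2 <= 4 / 2.
def _example_chunk_first_embedding():
    chunks = chunk_first_embedding(4, list(range(10)))
    assert chunks == [[0, 1, 2, 3], [4, 5, 6, 7]]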
def encode_midi_in_chunks(file_path, n_aug, n_noise):
    n_noise = 0  # noise versions are not produced for chunked encodings
    notes = []
    mid = pretty_midi.PrettyMIDI(midi_file=file_path)

    # preprocess midi: sustain pedal (CC 64) and note gathering
    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number 64 is the sustain pedal; for the meaning of the other control numbers, see
        # https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)
    assert len(notes) == len(mid.instruments[0].notes)

    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])

    # convert notes to ints
    main_embedding = convert_notes(notes)

    # split the sequence of events in chunks
    assert np.max(main_embedding) < MAX_EMBEDDING and np.min(main_embedding) >= 0
    encoded_chunks = chunk_first_embedding(CHUNK_SIZE, main_embedding)

    # note_on tokens are the integers below 128, so counting them gives the
    # number of notes starting in each chunk
    n_notes_per_chunk = [np.argwhere(np.array(ec) < 128).flatten().size for ec in encoded_chunks]
    chunked_notes = chunk_notes(n_notes_per_chunk, notes)

    # re-encode each chunk after shifting its notes back to time zero
    encoded_chunks = []
    for note_group in chunked_notes:
        note_group = shift_notes(note_group)
        embedding_main = convert_notes(note_group)[:CHUNK_SIZE]
        encoded_chunks.append(embedding_main)

    min_pitches = [np.min([n.pitch for n in cn]) for cn in chunked_notes]
    encoded_augmentations = []
    aug_shift = 3
    for i_aug in range(n_aug):
        chunked_embedding_aug = []
        for note_group, min_pitch in zip(chunked_notes, min_pitches):
            a_notes = alter_notes(note_group, aug_shift, min_pitch)
            a_notes = shift_notes(a_notes)
            assert len(a_notes) == len(note_group)
            embedding_group = convert_notes(a_notes)[:CHUNK_SIZE]  # encode notes
            chunked_embedding_aug.append(embedding_group)
        encoded_augmentations += chunked_embedding_aug
    assert len(encoded_augmentations) == n_aug * len(encoded_chunks)
    return encoded_chunks, encoded_augmentations, []
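
# Usage sketch (hypothetical path): chunked encoders return one token list per
# chunk, n_aug augmented lists per chunk, and an empty list in place of noise.
def _example_encode_in_chunks_usage():
    chunks, augs, noise = encode_midi_in_chunks('path/to/piece.mid', n_aug=1, n_noise=0)
    assert len(augs) == 1 * len(chunks) and noise == []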
def encode_miditok_in_chunks(file_path, n_aug, n_noise):
    n_noise = 0  # noise versions are not produced for chunked encodings
    notes = []
    mid = MidiFile(file_path)
    assert midi_valid(mid)

    # preprocess sustain pedal (CC 64) and gather notes
    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number 64 is the sustain pedal; for the meaning of the other control numbers, see
        # https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)
    assert len(notes) == len(mid.instruments[0].notes)

    # sort notes
    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])

    # convert notes to tokens and chunk the token sequence
    encoded_main = tokenizer_cp.midi_to_tokens(mid)[0]
    encoded_chunks = chunk_first_embedding(CHUNK_SIZE, encoded_main)

    # count the notes in each chunk ('Family_Note' marks one note per CP token group)
    n_notes_per_chunk = [len([tokenizer_cp.vocab.token_to_event[e[0]]
                              for e in enc_chunk
                              if tokenizer_cp.vocab.token_to_event[e[0]] == 'Family_Note'])
                         for enc_chunk in encoded_chunks]
    chunked_notes = chunk_notes(n_notes_per_chunk, notes)

    # re-encode each chunk after shifting its notes back to time zero
    encoded_chunks = []
    for note_group in chunked_notes:
        mid.instruments[0].notes = note_group
        mid = shift_mid(mid)  # shift midi
        assert midi_valid(mid)
        embedding_main = tokenizer_cp.midi_to_tokens(mid)[0][:CHUNK_SIZE]  # tokenize midi
        encoded_chunks.append(embedding_main)

    min_pitch = np.min([n.pitch for n in notes])
    encoded_augmentations = []
    aug_shift = 3
    for i_aug in range(n_aug):
        chunked_embedding_aug = []
        for note_group in chunked_notes:
            a_notes = alter_notes_exact_tick(note_group, aug_shift, min_pitch)
            assert len(a_notes) == len(note_group)
            mid.instruments[0].notes = a_notes
            mid = shift_mid(mid)  # shift midi
            assert midi_valid(mid)
            embedding_aug = tokenizer_cp.midi_to_tokens(mid)[0][:CHUNK_SIZE]  # tokenize midi
            chunked_embedding_aug.append(embedding_aug)
        encoded_augmentations += chunked_embedding_aug
    assert len(encoded_augmentations) == n_aug * len(encoded_chunks)
    return encoded_chunks, encoded_augmentations, []
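
# Worked example for encode_midi_chunks_structured below. The CHUNK_SIZE value
# (512 here) and the four-tokens-per-note rationale (pitch, velocity, duration,
# time shift in the Structured encoding) are assumptions for illustration only.
def _example_structured_chunk_sizes():
    nb_notes = 512 // 4  # = 128 notes per chunk, assuming CHUNK_SIZE = 512
    n_notes_per_chunk = [nb_notes] * (1000 // nb_notes)  # 7 full chunks for 1000 notes
    if 1000 % nb_notes > nb_notes / 2:
        n_notes_per_chunk.append(1000 % nb_notes)  # trailing 104-note chunk is kept
    assert n_notes_per_chunk == [128] * 7 + [104]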
def encode_midi_chunks_structured(file_path, n_aug, n_noise):
    n_noise = 0  # noise versions are not produced for chunked encodings
    notes = []
    mid = MidiFile(file_path)
    assert midi_valid(mid)

    # preprocess sustain pedal (CC 64) and gather notes
    for inst in mid.instruments:
        inst_notes = inst.notes
        # ctrl.number 64 is the sustain pedal; for the meaning of the other control numbers, see
        # https://www.midi.org/specifications-old/item/table-3-control-change-messages-data-bytes-2
        ctrls = _control_preprocess([ctrl for ctrl in inst.control_changes if ctrl.number == 64])
        notes += _note_preprocess(ctrls, inst_notes)
    assert len(notes) == len(mid.instruments[0].notes)

    nb_notes = CHUNK_SIZE // 4
    notes = notes[:50 * nb_notes]  # limit to 50 chunks to speed up

    # sort notes
    arg_rank = np.argsort([n.start for n in notes])
    notes = list(np.array(notes)[arg_rank])

    assert (len(notes) // nb_notes) > 1  # assert at least 2 full chunks
    n_notes_per_chunk = [nb_notes for _ in range(len(notes) // nb_notes)]
    if len(notes) % nb_notes > nb_notes / 2:
        n_notes_per_chunk.append(len(notes) % nb_notes)
    chunked_notes = chunk_notes(n_notes_per_chunk, notes)

    # re-encode each chunk after shifting its notes back to time zero
    encoded_chunks = []
    for note_group in chunked_notes:
        mid.instruments[0].notes = note_group
        mid = shift_mid(mid)  # shift midi
        assert midi_valid(mid)
        embedding_main = tokenizer_structured.midi_to_tokens(mid)[0]  # tokenize midi
        encoded_chunks.append(embedding_main)

    min_pitch = np.min([n.pitch for n in notes])
    encoded_augmentations = []
    aug_shift = 3
    for i_aug in range(n_aug):
        chunked_embedding_aug = []
        for note_group in chunked_notes:
            a_notes = alter_notes_exact_tick(note_group, aug_shift, min_pitch)
            assert len(a_notes) == len(note_group)
            mid.instruments[0].notes = a_notes
            mid = shift_mid(mid)  # shift midi
            assert midi_valid(mid)
            embedding_aug = tokenizer_structured.midi_to_tokens(mid)[0]  # encode notes
            chunked_embedding_aug.append(embedding_aug)
        encoded_augmentations += chunked_embedding_aug
    assert len(encoded_augmentations) == n_aug * len(encoded_chunks)
    return encoded_chunks, encoded_augmentations, []


def shift_mid(mid):
    # shift all notes so that the first one starts at tick 0
    to_remove = mid.instruments[0].notes[0].start
    if to_remove > 0:
        for n in mid.instruments[0].notes:
            n.start -= to_remove
            n.end -= to_remove

        # for e in mid.tempo_changes:
        #     e.time = max(0, e.time - to_remove)
        #
        # for e in mid.time_signature_changes:
        #     e.time = max(0, e.time - to_remove)
        #
        # for e in mid.key_signature_changes:
        #     e.time = max(0, e.time - to_remove)
    return mid


def shift_notes(notes):
    to_remove = notes[0].start
    for n in notes:
        n.start -= to_remove
        n.end -= to_remove
    return notes


def convert_notes(notes):
    events = []
    dnotes = _divide_note(notes)  # split notes into note_on / note_off events
    dnotes.sort(key=lambda x: x.time)

    cur_time = 0
    cur_vel = 0
    for snote in dnotes:
        events += _make_time_shift_events(prev_time=cur_time, post_time=snote.time)
        events += _snote2events(snote=snote, prev_vel=cur_vel)
        cur_time = snote.time
        cur_vel = snote.velocity

    event_list = [e.to_int() for e in events]
    assert np.max(event_list) < MAX_EMBEDDING and np.min(event_list) >= 0, 'event outside vocabulary range'
    return event_list


def decode_midi_structured(encoding, file_path=None):
    mid = tokenizer_structured.tokens_to_midi([encoding])
    if file_path:
        mid.dump(file_path)
    return mid


def decode_midi_cp(encoding, file_path=None):
    mid = tokenizer_cp.tokens_to_midi([encoding])
    if file_path:
        mid.dump(file_path)
    return mid
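
# Illustrative sketch (not called by the pipeline): decoding event integers by
# hand with the helpers above, mirroring what decode_midi does below. The token
# values are exact for this vocabulary: a velocity bin, note_on(60), a 0.5 s
# time shift, then note_off(60) reconstruct a single half-second C4 note.
def _example_decode_sketch():
    tokens = [
        START_IDX['velocity'] + 16,    # velocity bin 16 -> MIDI velocity 64
        START_IDX['note_on'] + 60,     # C4 on
        START_IDX['time_shift'] + 49,  # shift of (49 + 1) / 100 = 0.5 s
        START_IDX['note_off'] + 60,    # C4 off
    ]
    events = [Event.from_int(t) for t in tokens]
    notes = _merge_note(_event_seq2snote_seq(events))
    assert len(notes) == 1
    assert notes[0].pitch == 60 and notes[0].velocity == 64
    assert abs((notes[0].end - notes[0].start) - 0.5) < 1e-9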
def decode_midi(idx_array, file_path=None):
    event_sequence = [Event.from_int(idx) for idx in idx_array]
    snote_seq = _event_seq2snote_seq(event_sequence)
    note_seq = _merge_note(snote_seq)
    note_seq.sort(key=lambda x: x.start)

    mid = pretty_midi.PrettyMIDI()
    # to change the instrument, see https://www.midi.org/specifications/item/gm-level-1-sound-set
    instrument = pretty_midi.Instrument(1, False, "Developed By Yang-Kichang")
    instrument.notes = note_seq
    mid.instruments.append(instrument)
    if file_path is not None:
        mid.write(file_path)
    return mid


if __name__ == '__main__':
    encoded, _, _ = encode_midi('bin/ADIG04.mid', nb_aug=0, nb_noise=0)
    print(encoded)
    decoded = decode_midi(encoded, file_path='bin/test.mid')

    ins = pretty_midi.PrettyMIDI('bin/ADIG04.mid')
    print(ins)
    print(ins.instruments[0])
    for i in ins.instruments:
        print(i.control_changes)
        print(i.notes)