""" Structured MIDI encoding method as using in the Piano Inpainting Application
https://arxiv.org/abs/2107.05944

"""

from typing import List, Tuple, Dict, Optional

import numpy as np
from miditoolkit import Instrument, Note, TempoChange
from miditok import Structured
from miditok.midi_tokenizer_base import MIDITokenizer, Vocabulary, Event
from miditok.constants import *
from itertools import combinations
# MIDI pitches of the C note in each octave (C0 = 12 up to C8 = 108), used to
# decompose a pitch into (octave index, relative pitch) coordinates
Cs = np.array([60 + octave for octave in range(-12 * 4, 12 * 5, 12)])

def get_chord_map():
    my_chord_map = {#'octave': (0, 12),
                     #'power': (0, 7),
                     #'power_inv_1': (0, 5),
                     'min': (0, 3, 7),
                     'maj': (0, 4, 7),
                     'dim': (0, 3, 6),
                     'aug': (0, 4, 8),
                     'sus2': (0, 2, 7),
                     'sus4': (0, 5, 7),
                     '7dom': (0, 4, 7, 10),
                     '7min': (0, 3, 7, 10),
                     '7maj': (0, 4, 7, 11),
                     '7halfdim': (0, 3, 6, 10),
                     '7dim': (0, 3, 6, 9),
                     '7aug': (0, 4, 8, 11),
                     '9maj': (0, 4, 7, 10, 14),
                     '9min': (0, 4, 7, 10, 13)}

    # Add the inversions of every chord with more than 2 notes, skipping chords
    # whose inversions would duplicate an existing pattern (symmetric or sus chords)
    for k in list(my_chord_map.keys()):
        n_notes = len(my_chord_map[k])
        if n_notes > 2:
            if k not in ['7dim', 'aug', 'sus2', 'sus4']:
                if '9' in k:
                    nb_invs = 3
                else:
                    nb_invs = n_notes
                for i_inv in range(1, nb_invs):
                    shift = np.array([my_chord_map[k][(i + i_inv) % n_notes] for i in range(n_notes)])
                    shift[-i_inv:] += 12
                    pattern = [0]
                    for i in range(1, len(shift)):
                        pattern.append(shift[i] - shift[0])
                    my_chord_map[k + f'_inv_{i_inv}'] = tuple(pattern)
    # check that all patterns are unique
    known = set()
    for k in my_chord_map.keys():
        assert my_chord_map[k] not in known
        known.add(my_chord_map[k])
    inverted_chord_map = dict()
    for k, v in my_chord_map.items():
        inverted_chord_map[v] = k
    return my_chord_map, inverted_chord_map
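
# Example (a hedged sketch): the first inversion of a major triad spells
# (E, G, C), i.e. intervals (0, 3, 8) above its lowest note:
#   chord_map, inverted_map = get_chord_map()
#   assert chord_map['maj_inv_1'] == (0, 3, 8)
#   assert inverted_map[(0, 3, 8)] == 'maj_inv_1'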

def find_sub_pattern(pattern, candidate_patterns):
    # Try subsets of decreasing size, always keeping the first note, and return
    # the first subset whose sorted interval pattern matches a candidate pattern
    for i in np.arange(len(pattern) - 1, 0, -1):
        patt_indexes = [(0,) + c for c in combinations(range(1, len(pattern)), i)]
        for p_ind in patt_indexes:
            sorted_pattern = np.sort(np.array(pattern)[np.array(p_ind)])
            sorted_pattern = tuple(sorted_pattern - sorted_pattern[0])
            if sorted_pattern in candidate_patterns:
                return True, sorted_pattern, np.array(p_ind)
    return False, None, None
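
# Example (a hedged sketch): in the 4-note set below the full set matches no known
# chord, so the search falls back to 3-note subsets containing the first note and
# finds a sus2 chord on notes 0, 1 and 3:
#   _, inverted_map = get_chord_map()
#   found, patt, idx = find_sub_pattern((60, 62, 64, 67), list(inverted_map.keys()))
#   assert found and patt == (0, 2, 7) and list(idx) == [0, 1, 3]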


def filter_notes_find_chord_and_root(chord, inverted_chord_map):
    known_chords = list(inverted_chord_map.keys())
    found, chord_pattern, chord_indexes = find_sub_pattern(tuple(chord), known_chords)
    if found:
        chord_id = inverted_chord_map[chord_pattern].split('_')[0]
    else:
        return False, None, None, None

    # find root now :)
    if 'inv' not in inverted_chord_map[chord_pattern]:
        root_id = 0
    else:
        inv_id = int(inverted_chord_map[chord_pattern].split('_')[-1])
        n_notes = len(chord_pattern)
        root_id = n_notes - inv_id

    return True, chord_id, root_id, chord_indexes
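
# Example (hedged): (E4, G4, C5) = (64, 67, 72) matches 'maj_inv_1', so the quality
# is 'maj' and the root is the last of the three sorted notes:
#   _, inverted_map = get_chord_map()
#   found, chord_id, root_id, idx = filter_notes_find_chord_and_root((64, 67, 72), inverted_map)
#   assert found and chord_id == 'maj' and root_id == 2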


class ChordStructured(MIDITokenizer):
    """ Structured MIDI encoding method as using in the Piano Inpainting Application
    https://arxiv.org/abs/2107.05944
    The token types follows the specific pattern:
    Pitch -> Velocity -> Duration -> Time Shift -> back to Pitch ...
    NOTE: this encoding uses only "Time Shifts" events to move in the time, and only
    from one note to another. Hence it is suitable to encode continuous sequences of
    notes without long periods of silence. If your dataset contains music with long
    pauses, you might handle them with an appropriate "time shift" dictionary
    (which values are made from the beat_res dict) or with a different encoding.

    :param pitch_range: range of used MIDI pitches
    :param beat_res: beat resolutions, with the form:
            {(beat_x1, beat_x2): beat_res_1, (beat_x2, beat_x3): beat_res_2, ...}
            The keys of the dict are tuples indicating a range of beats, ex 0 to 3 for the first bar
            The values are the resolution, in samples per beat, of the given range, ex 8
    :param nb_velocities: number of velocity bins
    :param program_tokens: will add entries for MIDI programs in the dictionary, to use
            in the case of multitrack generation for instance
    :param sos_eos_tokens: Adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary
    :param params: can be a path to the parameter (json encoded) file or a dictionary
    """
    def __init__(self, pitch_range: range = PITCH_RANGE, beat_res: Dict[Tuple[int, int], int] = BEAT_RES,
                 nb_velocities: int = NB_VELOCITIES, program_tokens: bool = ADDITIONAL_TOKENS['Program'],
                 sos_eos_tokens: bool = False, params=None):
        # Enable Chord tokens (this class handles chords natively); no other additional tokens
        additional_tokens = {'Chord': True, 'Rest': False, 'Tempo': False, 'TimeSignature': False, 'Program': program_tokens}
        self.pitch2octave_relative = dict()
        self.octave_relative2pitch = dict()
        for p in pitch_range:
            self.pitch2octave_relative[p] = self.get_octave_and_relative(p)
            self.octave_relative2pitch[self.pitch2octave_relative[p]] = p
        self.chord_maps, self.inverted_chord_map = get_chord_map()
        super().__init__(pitch_range, beat_res, nb_velocities, additional_tokens, sos_eos_tokens, params)

    def get_octave_and_relative(self, pitch):
        # index of the closest C at or below the pitch, and the offset (in semitones) above it
        octave = np.argwhere(pitch - Cs >= 0).flatten()[-1]
        relative = pitch - Cs[octave]
        return octave, relative
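
    # Example (hedged): A4 (MIDI 69) lies 9 semitones above C4 (MIDI 60 == Cs[4]),
    # so get_octave_and_relative(69) returns (4, 9).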

    def get_note_events(self, note, dur_bins, next_note_start):
        events = []
        if isinstance(note.pitch, str):  # it's a chord
            chord_id = '_'.join(note.pitch.split('_')[:-1])
            pitch = int(note.pitch.split('_')[-1])
        else:  # it's a note
            chord_id = 'note'
            pitch = note.pitch
        # get octave and relative position of the pitch (root pitch for a chord)
        octave, relative = self.pitch2octave_relative[pitch]
        # Add chord/note event. A note is defined as Chord_note
        events.append(Event(type_='Chord', time=note.start, value=chord_id, desc=note.pitch))
        # Add octave of the root
        events.append(Event(type_='OctavePitch', time=note.start, value=octave, desc=note.pitch))
        # Add octave relative pitch of the root
        events.append(Event(type_='RelativePitch', time=note.start, value=relative, desc=note.pitch))
        # Velocity
        events.append(Event(type_='Velocity', time=note.start, value=note.velocity, desc=f'{note.velocity}'))
        # Duration
        duration = note.end - note.start
        index = np.argmin(np.abs(dur_bins - duration))
        events.append(Event(type_='Duration', time=note.start, value='.'.join(map(str, self.durations[index])), desc=f'{duration} ticks'))
        # Time-Shift
        time_shift = next_note_start - note.start
        assert time_shift >= 0  # this asserts that events are sorted
        index = np.argmin(np.abs(dur_bins - time_shift))
        events.append(Event(type_='Time-Shift', time=note.start, desc=f'{time_shift} ticks',
                            value='.'.join(map(str, self.durations[index])) if time_shift != 0 else '0.0.1'))
        return events, time_shift
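
    # Example (a hedged sketch, assuming a 384 ticks/beat MIDI and the dur_bins
    # computed in track_to_tokens below): a one-beat A4 note produces
    #   events, shift = self.get_note_events(Note(100, 69, 0, 384), dur_bins, next_note_start=384)
    #   [e.type for e in events]
    #   -> ['Chord', 'OctavePitch', 'RelativePitch', 'Velocity', 'Duration', 'Time-Shift']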

    def track_to_tokens(self, track: Instrument) -> List[int]:
        """ Converts a track (miditoolkit.Instrument object) into a sequence of tokens

        :param track: MIDI track to convert
        :return: sequence of corresponding tokens
        """
        # Make sure the notes are sorted first by their onset (start) times, second by pitch
        # notes.sort(key=lambda x: (x.start, x.pitch))  # done in midi_to_tokens
        events = []

        dur_bins = self.durations_ticks[self.current_midi_metadata['time_division']]

        # assume first note is the beginning of the song, no time shift at first.

        # Track chords. For each chord, insert a fake note that contains its info so that it can be converted to the proper event
        if self.additional_tokens['Chord'] and not track.is_drum:
            notes_and_chords = self.detect_chords(track.notes, self.current_midi_metadata['time_division'], self._first_beat_res)
        else:
            notes_and_chords = track.notes

        sum_shifts = 0
        # Creates the Pitch, Velocity, Duration and Time Shift events
        for n, note in enumerate(notes_and_chords):
            if n == len(notes_and_chords) - 1:
                next_note_start = note.start  # add zero time shift at the end
            else:
                next_note_start = notes_and_chords[n + 1].start
            new_events, time_shift = self.get_note_events(note, dur_bins, next_note_start=next_note_start)
            events += new_events
            sum_shifts += time_shift
        assert len(events) // 6 == len(notes_and_chords)  # each note or chord yields exactly 6 events

        return self.events_to_tokens(events)

    def tokens_to_track(self, tokens: List[int], time_division: Optional[int] = TIME_DIVISION,
                        program: Optional[Tuple[int, bool]] = (0, False)) -> Tuple[Instrument, List[TempoChange]]:
        """ Converts a sequence of tokens into a track object

        :param tokens: sequence of tokens to convert
        :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI to create)
        :param program: the MIDI program of the produced track and whether it is a drum track (default (0, False), piano)
        :return: the miditoolkit instrument object and a "Dummy" tempo change
        """
        events = self.tokens_to_events(tokens)
        instrument = Instrument(program[0], is_drum=False, name=MIDI_INSTRUMENTS[program[0]]['name'])
        current_tick = 0
        count = 0
        # start at first chord event
        while count < len(events) and events[count].type != 'Chord':
            count += 1

        while count < len(events):
            if events[count].type == 'Chord':
                note_chord_events = [events[c] for c in range(count, count + 6)]
                events_types = [c.type for c in note_chord_events]
                if events_types[1:] == ['OctavePitch', 'RelativePitch', 'Velocity', 'Duration', 'Time-Shift']:
                    octave, relative = int(note_chord_events[1].value), int(note_chord_events[2].value)
                    duration = self._token_duration_to_ticks(note_chord_events[4].value, time_division)
                    vel = int(note_chord_events[3].value)
                    root_pitch = self.octave_relative2pitch[(octave, relative)]
                    if note_chord_events[0].value == "note":
                        instrument.notes.append(Note(vel, root_pitch, current_tick, current_tick + duration))
                    else:
                        pitches = self.find_chord_pitches(root_pitch, note_chord_events[0].value)
                        for p in pitches:
                            instrument.notes.append(Note(vel, p, current_tick, current_tick + duration))

                    beat, pos, res = map(int, note_chord_events[5].value.split('.'))
                    current_tick += (beat * res + pos) * time_division // res  # time shift
                    count += 6
                else:
                    count += 1
            else:
                count += 1

        return instrument, [TempoChange(TEMPO, 0)]

    def find_chord_pitches(self, root_pitch, chord_name):
        chord_map = self.chord_maps[chord_name]
        if 'inv' not in chord_name:
            root_position = 0
        else:
            # for an inversion, the root is the (n_notes - inv_id)-th note of the pattern
            inv_id = int(chord_name.split('_')[-1])
            n_notes = len(chord_map)
            root_position = n_notes - inv_id
        # shift the whole pattern so that the root note lands on root_pitch
        deltas = np.array(chord_map) - chord_map[root_position]
        pitches = [root_pitch + d for d in deltas]
        return pitches
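
    # Example (a hedged sketch): reconstructing a first-inversion C major triad
    # from its root C5 gives back E4, G4, C5:
    #   self.find_chord_pitches(72, 'maj_inv_1')  # -> [64, 67, 72]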

    def _create_vocabulary(self, sos_eos_tokens: bool = False) -> Vocabulary:
        """ Creates the Vocabulary object of the tokenizer.
        See the docstring of the Vocabulary class for more details about how to use it.
        NOTE: token index 0 is often used as a padding index during training

        :param sos_eos_tokens: will include Start Of Sequence (SOS) and End Of Sequence (EOS) tokens
        :return: the vocabulary object
        """
        vocab = Vocabulary({'PAD_None': 0})

        if self.additional_tokens['Chord']:
            vocab.add_event(f'Chord_{chord_quality}' for chord_quality in CHORD_MAPS)

        # PITCH
        vocab.add_event('Chord_note')
        vocab.add_event(f'OctavePitch_{i}' for i in range(8))
        vocab.add_event(f'RelativePitch_{i}' for i in range(12))
        # vocab.add_event(f'Pitch_{i}' for i in self.pitch_range)

        # VELOCITY
        vocab.add_event(f'Velocity_{i}' for i in self.velocities)

        # DURATION
        vocab.add_event(f'Duration_{".".join(map(str, duration))}' for duration in self.durations)

        # TIME SHIFT (same as durations)
        vocab.add_event('Time-Shift_0.0.1')  # for a time shift of 0
        vocab.add_event(f'Time-Shift_{".".join(map(str, duration))}' for duration in self.durations)

        # PROGRAM
        if self.additional_tokens['Program']:
            vocab.add_event(f'Program_{program}' for program in range(-1, 128))

        # SOS & EOS
        if sos_eos_tokens:
            vocab.add_sos_eos_to_vocab()

        return vocab

    def _create_token_types_graph(self) -> Dict[str, List[str]]:
        """ Returns a graph (as a dictionary) of the possible token
        types successions.
        NOTE: Program type is not referenced here, you can add it manually by
        modifying the tokens_types_graph class attribute following your strategy.

        :return: the token types transitions dictionary
        """
        dic = {'Chord': ['OctavePitch'], 'OctavePitch': ['RelativePitch'], 'RelativePitch': ['Velocity'],
               'Velocity': ['Duration'], 'Duration': ['Time-Shift'], 'Time-Shift': ['Chord']}
        self._add_pad_type_to_graph(dic)
        return dic

    def token_types_errors(self, tokens: List[int], consider_pad: bool = False) -> float:
        """ Checks if a sequence of tokens is constituted of good token types
        successions and returns the error ratio (lower is better).
        The pitch values are also analyzed:
            - a pitch (an OctavePitch / RelativePitch pair) should not be present if the
              same pitch is already played at the time

        :param tokens: sequence of tokens to check
        :param consider_pad: if True will continue the error detection after the first PAD token (default: False)
        :return: the error ratio (lower is better)
        """
        err = 0
        previous_type = self.vocab.token_type(tokens[0])
        current_pitches = []
        current_octave = 0

        def check(tok: int):
            nonlocal err
            nonlocal previous_type
            nonlocal current_pitches
            nonlocal current_octave
            token_type, token_value = self.vocab.token_to_event[tok].split('_')

            # Good token type
            if token_type in self.tokens_types_graph[previous_type]:
                if token_type == 'OctavePitch':
                    current_octave = int(token_value)
                elif token_type == 'RelativePitch':
                    # reconstruct the pitch from the (octave, relative) pair
                    pitch = self.octave_relative2pitch.get((current_octave, int(token_value)))
                    if pitch in current_pitches:
                        err += 1  # pitch already played at current position
                    else:
                        current_pitches.append(pitch)
                elif token_type == 'Time-Shift':
                    if self._token_duration_to_ticks(token_value, 48) > 0:
                        current_pitches = []  # moving in time, list reset
            # Bad token type
            else:
                err += 1
            previous_type = token_type

        if consider_pad:
            for token in tokens[1:]:
                check(token)
        else:
            for token in tokens[1:]:
                if previous_type == 'PAD':
                    break
                check(token)
        return err / len(tokens)

    def detect_chords(self, list_notes: List[Note], time_division: int, beat_res: int = 4, onset_offset: int = 1,
                      only_known_chord: bool = False, simul_notes_limit: int = 20, verbose=False) -> List[Note]:
        """ Chord detection method.
        NOTE: make sure to sort the notes by start time, then pitch, beforehand:
        notes.sort(key=lambda x: (x.start, x.pitch))
        NOTE2: on very large tracks with high note density this method can be very slow!
        If you plan to use it with the Maestro or GiantMIDI datasets, it can take up to
        hundreds of seconds per MIDI depending on your CPU.
        One time step at a time, it analyses the notes played together
        and detects possible chords.

        :param list_notes: notes to analyse (sorted by start time, then pitch)
        :param time_division: MIDI time division / resolution, in ticks/beat (of the MIDI being parsed)
        :param beat_res: beat resolution, i.e. nb of samples per beat (default 4)
        :param onset_offset: maximum offset (in samples) ∈ N separating note starts for them to be
                                considered as starting at the same time / onset (default is 1)
        :param only_known_chord: will select only known chords. If set to False, non-recognized chords of
                                n notes will give a chord_n event (default False)
        :param simul_notes_limit: nb of simultaneous notes to process when looking for a chord;
                this parameter allows to speed up the chord detection (default 20)
        :return: the new list of notes, in which the notes of each detected chord are replaced by a
                single pseudo-note carrying the chord quality and root pitch
        """
        assert simul_notes_limit >= 5, 'simul_notes_limit must be at least 5, chords can contain up to 5 notes'
        tuples = []
        for note in list_notes:
            tuples.append((note.pitch, int(note.start), int(note.end), int(note.velocity)))
        notes = np.asarray(tuples)

        time_div_half = time_division // 2
        onset_offset = time_division * onset_offset / beat_res

        count = 0
        previous_tick = -1
        detected_chords = []
        note_belong_to_chord_id = dict()
        while count < len(notes):
            # Checks we moved in time after last step, otherwise discard this tick
            if notes[count, 1] == previous_tick:
                count += 1
                continue

            # Gathers the notes around the same time step
            # Reduce the scope of the search
            notes_to_consider = notes[count:count + simul_notes_limit].copy()
            old_true_notes_indexes = np.arange(count, count + simul_notes_limit)  # keep track of true note indexes
            # Take notes within onset_offset samples of the first note
            indexes_valid = np.where(notes_to_consider[:, 1] <= notes_to_consider[0, 1] + onset_offset)
            true_notes_indexes = old_true_notes_indexes[indexes_valid]
            onset_notes = notes_to_consider[indexes_valid]
            # Take notes that end close to the first note's end
            indexes_valid = np.where(np.abs(onset_notes[:, 2] - onset_notes[0, 2]) < time_div_half)
            true_notes_indexes = true_notes_indexes[indexes_valid]
            onset_notes = onset_notes[indexes_valid]

            # if there are at least 3 notes, try to find the chord
            if len(onset_notes) >= 3:
                found, chord_name, root_id, chord_notes_indexes = filter_notes_find_chord_and_root(onset_notes[:, 0], self.inverted_chord_map)
                # if found:
                #     found, chord_name, root_id, chord_notes_indexes = filter_notes_find_chord_and_root(notes_to_consider[:, 0], self.inverted_chord_map)

                if found:
                    detected_chord_id = len(detected_chords)
                    # get the indexes of the notes in the chord wrt the onset_notes array
                    relative_indexes_chord_notes_in_onset_notes = np.array(chord_notes_indexes)
                    # get true indexes of the notes in the chord (indexes of the note stream)
                    true_indexes = true_notes_indexes[relative_indexes_chord_notes_in_onset_notes]
                    # for each note, track the chords it belongs to in note_belong_to_chord_id
                    for i in true_indexes:
                        if i not in note_belong_to_chord_id.keys():
                            note_belong_to_chord_id[i] = [detected_chord_id]
                        else:
                            note_belong_to_chord_id[i].append(detected_chord_id)
                    # save the info of the detected chord
                    root_position_in_sorted_onset = chord_notes_indexes[root_id]
                    root_pitch = onset_notes[root_position_in_sorted_onset, 0]
                    onset = np.min([notes[i, 1] for i in true_indexes])
                    offset = int(np.mean([notes[i, 2] for i in true_indexes]))
                    # quantize the mean velocity of the chord notes to the nearest velocity bin
                    mean_velocity = int(np.mean([notes[i, 3] for i in true_indexes]))
                    velocity = self.velocities[int(np.argmin(np.abs(self.velocities - mean_velocity)))]
                    detected_chords.append((chord_name, true_indexes, root_pitch, onset, offset, velocity))
                    if verbose: print(f'New chord detected: {chord_name}, root {root_pitch} with notes: {true_indexes}, onset: {onset}, offset: {offset}, velocity: {velocity}')

            count += 1

        # now remove redundant detected chords so that each note belongs to at most one chord
        indexes_chords_to_remove = []

        for note, chord_ids in note_belong_to_chord_id.copy().items():
            # remove chords that were already filtered
            chord_ids = sorted(set(chord_ids) - set(indexes_chords_to_remove))
            if len(chord_ids) == 0:  # if no chord remains, drop the note entry
                del note_belong_to_chord_id[note]
            else:
                note_belong_to_chord_id[note] = chord_ids  # update the chord_ids
                if len(chord_ids) > 1:  # if several, keep only the chord with the most notes
                    chords = [detected_chords[i] for i in chord_ids]
                    selected_chord = np.argmax([len(c[1]) for c in chords])
                    note_belong_to_chord_id[note] = [chord_ids[selected_chord]]
                    for i_c, c in enumerate(chord_ids):
                        if i_c != selected_chord:
                            indexes_chords_to_remove.append(c)
        for note, chord_ids in note_belong_to_chord_id.copy().items():
            chord_ids = sorted(set(chord_ids) - set(indexes_chords_to_remove))
            if len(chord_ids) == 0:  # if no chord remains, drop the note entry
                del note_belong_to_chord_id[note]
            else:
                note_belong_to_chord_id[note] = chord_ids  # update the chord_ids
        selected_chords = [detected_chords[i] for i in range(len(detected_chords)) if i not in indexes_chords_to_remove]
        selected_chords_ids = [i for i in range(len(detected_chords)) if i not in indexes_chords_to_remove]
        # check that all notes are used just once
        all_chord_notes = []
        for c in selected_chords:
            all_chord_notes += list(c[1])
        assert len(all_chord_notes) == len(set(all_chord_notes))

        # format the new stream of notes: remove the notes that belong to chords and insert
        # one pseudo-note per chord so that time shifts can still be computed
        new_list_notes = []
        note_dict_keys = list(note_belong_to_chord_id.keys())
        inserted_chords = []
        count_added = 0
        for i in range(len(list_notes)):
            if i not in note_dict_keys:
                new_list_notes.append(list_notes[i])
            else:
                assert len(note_belong_to_chord_id[i]) == 1
                chord_id = note_belong_to_chord_id[i][0]
                if chord_id not in inserted_chords:
                    inserted_chords.append(chord_id)
                    count_added += 1
                    chord_name, _, root_pitch, onset, offset, velocity = detected_chords[chord_id]
                    # the chord is encoded as a pseudo-note whose pitch is the string '<quality>_<root pitch>'
                    new_list_notes.append(Note(velocity=velocity, start=onset, end=offset, pitch=chord_name + '_' + str(root_pitch)))
        # check the new count of notes (all previous notes - the number of notes in the chords + the number of chords)
        assert len(new_list_notes) == (len(list_notes) - len(all_chord_notes) + len(selected_chords))
        return new_list_notes
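
    # Hedged illustration: if C4, E4 and G4 (MIDI 60, 64, 67) are struck together,
    # the three Note objects are replaced in the returned list by a single pseudo-note
    # whose pitch is the string 'maj_60' (quality + root pitch), carrying the earliest
    # onset, the mean offset and the quantized mean velocity of the three notes.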


if __name__ == '__main__':
    from miditoolkit import MidiFile

    pitch_range = range(21, 109)
    beat_res = {(0, 4): 8, (4, 12): 4}
    nb_velocities = 32
    tokenizer_structured = ChordStructured(pitch_range, beat_res, nb_velocities)
    # tokenizer_structured = Structured(pitch_range, beat_res, nb_velocities)

    path = '/home/cedric/Documents/pianocktail/data/music/processed/vkgoeswild_processed/ac_dc_hells_bells_vkgoeswild_piano_cover_processed.mid'
    midi = MidiFile(path)
    tokens = tokenizer_structured.midi_to_tokens(midi)
    midi = tokenizer_structured.tokens_to_midi(tokens)
    midi.dump("/home/cedric/Desktop/tes/transcribed.mid")