|
import argparse |
|
import glob |
|
import os.path |
|
|
|
import gradio as gr |
|
|
|
import pickle |
|
import tqdm |
|
import json |
|
|
|
import MIDI |
|
from midi_synthesizer import synthesis |
|
|
|
import copy |
|
from collections import Counter |
|
import random |
|
import statistics |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
# True when the SYSTEM environment variable equals "spaces" (presumably set when
# the app is hosted on Hugging Face Spaces — confirm). Not read elsewhere in
# the visible code.
in_space = os.environ.get("SYSTEM") == "spaces"
|
|
|
|
|
|
|
def _transposed_pitch_counts(score_notes, transposition):
    """Return the pitch histogram of `score_notes` with melodic notes transposed.

    `score_notes` are MIDI.py score 'note' events
    (['note', start, duration, channel, pitch, velocity]). Every pitch is first
    reduced modulo 128. Notes on channel 9 are then moved up by 128 so they never
    collide with melodic pitches. All other notes are shifted by `transposition`
    semitones.

    Returns a list of [pitch, count] pairs sorted by pitch, highest first.
    """
    counts = Counter()
    for note in score_notes:
        pitch = note[4] % 128
        if note[3] == 9:
            counts[pitch + 128] += 1
        else:
            counts[pitch + transposition] += 1
    return sorted(([pitch, count] for pitch, count in counts.items()),
                  key=lambda pair: pair[0], reverse=True)


def _trim_counts(counts, cutoff_ratio):
    """Sort [pitch, count] pairs by count (highest first) and drop rare pitches.

    Pairs whose count is below `cutoff_ratio` times the highest count are
    removed; a ratio of 0 keeps everything. Returns a new list and leaves
    `counts` untouched. An empty input yields an empty list.
    """
    ordered = sorted(counts, key=lambda pair: pair[1], reverse=True)
    if not ordered:
        return []
    cutoff = ordered[0][1] * cutoff_ratio
    return [pair for pair in ordered if pair[1] >= cutoff]


def _mean(values):
    """Arithmetic mean of `values`, or 0 for an empty sequence."""
    return sum(values) / len(values) if values else 0


def _min_max_ratio(a, b):
    """Return min(a, b) / max(a, b): 1.0 for equal values, smaller as they diverge.

    Two zeros count as identical (1.0) instead of raising ZeroDivisionError.
    """
    largest = max(a, b)
    if largest == 0:
        return 1.0
    return min(a, b) / largest


def _pitch_counts_ratio(candidate_counts, source_counts, same_pitches):
    """Compare how often each shared pitch occurs, relative to each profile's total.

    For every pitch in `same_pitches`, the pitch's share of its profile's total
    count is computed on both sides, and the two shares are compared with
    `_min_max_ratio`. Returns the average of those ratios, or 0 when no pitch
    is shared.
    """
    if not same_pitches:
        return 0
    candidate_total = sum(pair[1] for pair in candidate_counts)
    source_total = sum(pair[1] for pair in source_counts)
    candidate_share = {pair[0]: pair[1] / candidate_total for pair in candidate_counts}
    source_share = {pair[0]: pair[1] / source_total for pair in source_counts}
    return _mean([_min_max_ratio(candidate_share[p], source_share[p])
                  for p in sorted(same_pitches, reverse=True)])


def _timing_stats(ms_notes):
    """Summarize onset gaps and note durations of time-sorted ms-score notes.

    Returns two [average, median, mode] lists, with averages and medians
    truncated to int:
    - one for the gaps between successive distinct onsets;
    - one for note durations.

    The gap list starts with a 0 for the first note. Notes that start together
    add no gap.
    """
    onsets = [note[1] for note in ms_notes]
    gaps = [0] + [later - earlier
                  for earlier, later in zip(onsets, onsets[1:]) if later != earlier]
    durations = [note[2] for note in ms_notes]
    return ([int(sum(gaps) / len(gaps)), int(statistics.median(gaps)), statistics.mode(gaps)],
            [int(sum(durations) / len(durations)), int(statistics.median(durations)),
             statistics.mode(durations)])


def match_midi(midi, progress=gr.Progress()):
    """Find the meta-data entry whose pitch profile best matches an uploaded MIDI file.

    This is the Gradio upload handler. The uploaded file's pitch histogram is
    compared against every entry of the module-level `meta_data` list.
    `meta_data` and `soundfont_path` are both defined in the ``__main__`` block,
    so this function only works when the module is run as a script. The
    best-scoring entry is written to ``MIDI-Match-Sample.mid``, rendered to
    audio and plotted.

    Args:
        midi: Raw bytes of the uploaded MIDI file (the input component uses
            ``type="binary"``).
        progress: Gradio progress tracker wrapped around the search loop.

    Yields:
        One tuple of:
        - the metadata text;
        - the path of the written MIDI file;
        - ``(44100, audio)`` from the synthesizer;
        - the ``matplotlib.pyplot`` module holding the scatter plot.

    Raises:
        ValueError: if the uploaded file contains no notes.
    """
    # Search settings (fixed here; not exposed in the UI).
    maximum_match_ratio_to_search_for = 1       # scores above this are discarded (set to 0)
    pitches_counts_cutoff_threshold_ratio = 0   # drop pitches rarer than this fraction of the top count
    search_transposed_pitches = False           # also try the source shifted by -6..+5 semitones
    skip_exact_matches = True                   # a pitch-set ratio of exactly 1 scores as 0
    add_pitches_counts_ratios = False
    add_timings_ratios = False
    add_durations_ratios = False

    print('=' * 70)
    print('Loading MIDI file...')

    # midi2score returns [ticks, track, track, ...]; gather the notes of every track.
    score = MIDI.midi2score(midi)
    score_notes = [event for track in score[1:] for event in track if event[0] == 'note']
    if not score_notes:
        raise ValueError('The uploaded MIDI file does not contain any notes.')

    # One histogram per transposition -6..+5; index 6 is the untransposed one.
    mult_pitches_counts = [_transposed_pitch_counts(score_notes, i) for i in range(-6, 6)]

    # Timing statistics are only needed by the optional timing/duration ratios.
    source_times = source_durs = None
    if add_timings_ratios or add_durations_ratios:
        ms_score = MIDI.midi2ms_score(midi)
        ms_notes = sorted((event for track in ms_score[1:] for event in track if event[0] == 'note'),
                          key=lambda event: event[1])
        source_times, source_durs = _timing_stats(ms_notes)

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    print('=' * 70)
    print('MIDI Pitches Search')
    print('=' * 70)

    source_profiles = mult_pitches_counts if search_transposed_pitches else [mult_pitches_counts[6]]
    # The source profiles do not depend on the candidate, so trim them only once.
    trimmed_sources = [_trim_counts(profile, pitches_counts_cutoff_threshold_ratio)
                       for profile in source_profiles]

    final_ratios = []

    for d in progress.tqdm(meta_data):

        # ASSUMPTION (inferred from the indexing only): d = [file name, fields].
        # In `fields`, [10][1] is the candidate's [pitch, count] list, [3][1] its
        # [avg, median, mode] onset gaps and [4][1] its [avg, median, mode]
        # durations. Confirm against the meta-data builder.
        trimmed_p_counts = _trim_counts(d[1][10][1], pitches_counts_cutoff_threshold_ratio)
        if not trimmed_p_counts:
            # Nothing to compare against; this entry can never be the best match.
            final_ratios.append(0)
            continue

        # These ratios depend only on the candidate `d`, not on the transposition.
        times_ratio = durs_ratio = None
        if add_timings_ratios:
            times_ratio = _mean([_min_max_ratio(s, c) for s, c in zip(source_times, d[1][3][1])])
        if add_durations_ratios:
            durs_ratio = _mean([_min_max_ratio(s, c) for s, c in zip(source_durs, d[1][4][1])])

        candidate_pitches = {pair[0] for pair in trimmed_p_counts}
        ratios_list = []

        for trimmed_pitches_counts in trimmed_sources:
            same_pitches = candidate_pitches & {pair[0] for pair in trimmed_pitches_counts}
            # Fraction of distinct pitches in common, relative to the larger profile.
            same_pitches_ratio = len(same_pitches) / max(len(trimmed_p_counts),
                                                         len(trimmed_pitches_counts))
            if skip_exact_matches and same_pitches_ratio == 1:
                same_pitches_ratio = 0

            r_list = [same_pitches_ratio]
            if add_pitches_counts_ratios:
                r_list.append(_pitch_counts_ratio(trimmed_p_counts, trimmed_pitches_counts,
                                                  same_pitches))
            if add_timings_ratios:
                r_list.append(times_ratio)
            if add_durations_ratios:
                r_list.append(durs_ratio)

            ratios_list.append(_mean(r_list))

        final_ratio = max(ratios_list)
        if final_ratio > maximum_match_ratio_to_search_for:
            final_ratio = 0
        final_ratios.append(final_ratio)

    max_ratio = max(final_ratios)
    max_ratio_index = final_ratios.index(max_ratio)
    md = meta_data[max_ratio_index]

    print('FOUND')
    print('=' * 70)
    print('Match ratio', max_ratio)
    print('MIDI file name', md[0])
    print('=' * 70)

    # ASSUMPTION (inferred from how the values are used below):
    # - md[1][:16] are [label, value] metadata pairs;
    # - md[1][16][1] is the score's ticks value;
    # - md[1][17:-1] are its MIDI.py score events.
    # Confirm against the meta-data builder.
    mid_seq = md[1][17:-1]
    mid_seq_ticks = md[1][16][1]
    txt_mdata = ''.join(str(m[0]) + ':' + str(m[1]) + chr(10) for m in md[1][:16])

    # One colour per MIDI channel (indexed by the note's channel, 0-15).
    colors = ['red', 'yellow', 'green', 'cyan',
              'blue', 'pink', 'orange', 'purple',
              'gray', 'white', 'gold', 'silver',
              'lightgreen', 'indigo', 'maroon', 'turquoise']

    notes = [event for event in mid_seq if event[0] == 'note']

    plt.close()
    plt.figure(figsize=(14, 5))
    ax = plt.axes(title='MIDI Search Plot')
    ax.set_facecolor('black')
    plt.scatter([n[1] for n in notes], [n[4] for n in notes], c=[colors[n[3]] for n in notes])
    plt.xlabel("Time")
    plt.ylabel("Pitch")

    # Write the match to the same path that is returned to the gr.File output.
    sample_path = "MIDI-Match-Sample.mid"
    with open(sample_path, 'wb') as f:
        f.write(MIDI.score2midi([mid_seq_ticks, mid_seq]))

    audio = synthesis(MIDI.score2opus([mid_seq_ticks, mid_seq]), soundfont_path)

    yield txt_mdata, sample_path, (44100, audio), plt
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line options for the Gradio server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    # NOTE(review): --max-gen is parsed but never read in this file.
    parser.add_argument("--max-gen", type=int, default=1024, help="max")

    opt = parser.parse_args()

    # match_midi() reads `soundfont_path` and `meta_data` as module globals,
    # so they must stay at module scope (not inside a main() function).
    soundfont_path = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
    meta_data_path = "meta-data/LAMD_META_10000.pickle"

    print('Loading meta-data...')
    # pickle.load can execute arbitrary code. It is acceptable here only because
    # the path is a hard-coded local asset, never user-supplied input.
    with open(meta_data_path, 'rb') as f:
        meta_data = pickle.load(f)
    print('Done!')

    app = gr.Blocks()
    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Match</h1>")
        # The subtitle is raw HTML, where a Markdown '#' would be rendered literally.
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Upload any MIDI file to find its closest match</h1>")

        gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.MIDI-Match&style=flat)\n\n"
                    "MIDI Match\n\n"
                    "Demo for [MIDI Match](https://github.com/asigalov61)\n\n"
                    "[Open In Colab]"
                    "(https://colab.research.google.com/github/asigalov61/MIDI-Match/blob/main/demo.ipynb)"
                    " for faster running and longer generation"
                    )

        # type="binary" hands match_midi() the raw file bytes.
        input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")

        gr.Markdown("# Match results")

        output_plot = gr.Plot(label="Output MIDI match sample plot")
        output_audio = gr.Audio(label="Output MIDI match sample audio", format="mp3", elem_id="midi_audio")
        output_midi = gr.File(label="Output MIDI match sample file", file_types=[".mid"])
        output_midi_seq = gr.Textbox(label="Output MIDI match metadata")

        # Output order must match the tuple yielded by match_midi():
        # (metadata text, MIDI file path, audio, plot).
        run_event = input_midi.upload(match_midi, [input_midi],
                                      [output_midi_seq, output_midi, output_audio, output_plot])

    app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)