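# NOTE: the following header is file-viewer residue from the hosting page, not program code.
# Author: asigalov61
# Commit: a8bfec5 (verified) - "Update app.py"
# File size: 17.1 kB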
# =================================================================================================
# https://huggingface.co/spaces/asigalov61/Melody-Harmonizer-Transformer
# =================================================================================================
import os
import time as reqtime
import datetime
from pytz import timezone
import gradio as gr
import spaces
import os
from tqdm import tqdm
import numpy as np
import torch
from x_transformer_1_23_2 import *
import random
import TMIDIX
from midi_to_colab_audio import midi_to_colab_audio
# =================================================================================================
# @spaces.GPU
def Harmonize_Melody(input_src_midi,
                     source_melody_transpose_value,
                     model_top_k_sampling_value,
                     texture_harmonized_chords,
                     melody_MIDI_patch_number,
                     harmonized_accompaniment_MIDI_patch_number,
                     base_MIDI_patch_number
                     ):

    """Harmonize a melody MIDI with the Melody Harmonizer Transformer and render it.

    The melody is reduced to pitch classes (tokens 0..11), fed to the model in
    overlapping 24-note windows, and one chord token is sampled per melody note.
    Chord tokens 12..140 index ``TMIDIX.ALL_CHORDS_FILTERED`` (offset by 12).
    The chords are optionally "textured" with voicings looked up in the
    pitches/chords pairs data, then written out as a MIDI file, rendered to
    audio and plotted.

    Relies on module-level names set up in the ``__main__`` section of this
    file: ``PDT``, ``soundfont``, ``src_chunks`` and ``all_chords_ptcs_chunks``.
    The model is rebuilt and its weights are reloaded on every call.

    Args:
        input_src_midi: Uploaded file object; its ``.name`` is the MIDI path.
        source_melody_transpose_value: Semitones to transpose the melody by.
        model_top_k_sampling_value: ``k`` passed to the top-k logits filter.
        texture_harmonized_chords: If true, replace plain chords with voicings
            matched from the pitches/chords pairs data.
        melody_MIDI_patch_number: MIDI patch for the melody (channel 3).
        harmonized_accompaniment_MIDI_patch_number: MIDI patch for the chords
            (channel 0).
        base_MIDI_patch_number: MIDI patch for the bass line (channel 2);
            any value below 0 disables the bass line.

    Returns:
        Tuple ``(output_audio, output_plot, output_midi, summary)`` where
        ``output_audio`` is ``(16000, samples)``, ``output_plot`` is the value
        returned by ``TMIDIX.plot_ms_SONG``, ``output_midi`` is the written
        MIDI file name and ``summary`` is a human-readable report string.

    Raises:
        gr.Error: If no file was supplied or the file yields no melody notes.
    """

    # Fail with a user-facing message instead of AttributeError on `.name`.
    if input_src_midi is None:
        raise gr.Error('Please upload a source MIDI file first.')

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    start_time = reqtime.time()

    sfn = os.path.basename(input_src_midi.name)

    print('Input src MIDI name:', sfn)
    print('=' * 70)
    print('Requested settings:')
    print('Source melody transpose value:', source_melody_transpose_value)
    print('Model top_k sampling value:', model_top_k_sampling_value)
    print('Texture harmonized chords:', texture_harmonized_chords)
    print('Melody MIDI patch number:', melody_MIDI_patch_number)
    print('Harmonized accompaniment MIDI patch number:', harmonized_accompaniment_MIDI_patch_number)
    print('Base MIDI patch number:', base_MIDI_patch_number)
    print('=' * 70)

    #==================================================================

    print('=' * 70)
    print('Loading seed melody...')

    #===============================================================================
    # Raw single-track ms score

    raw_score = TMIDIX.midi2single_track_ms_score(input_src_midi.name)

    #===============================================================================
    # Enhanced score notes

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]

    #===============================================================================
    # Augmented enhanced score notes

    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=16)

    # Keep only the first note of every chord so the melody is monophonic.
    cscore = [c[0] for c in TMIDIX.chordify_score([1000, escore_notes])]

    mel_score = TMIDIX.fix_monophonic_score_durations(TMIDIX.recalculate_score_timings(cscore))

    mel_score = TMIDIX.transpose_escore_notes(mel_score, source_melody_transpose_value)

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    # The model works on pitch classes only (tokens 0..11).
    mel_pitches = [p[4] % 12 for p in mel_score]

    print('Melody has', len(mel_pitches), 'notes')
    print('=' * 70)

    if not mel_pitches:
        raise gr.Error('No melody notes were found in the source MIDI.')

    #===============================================================================

    print('=' * 70)
    print('Melody Harmonizer Transformer')
    print('=' * 70)

    print('Loading Melody Harmonizer Transformer Model...')

    SEQ_LEN = 75
    PAD_IDX = 144

    # instantiate the model

    model = TransformerWrapper(
        num_tokens = PAD_IDX+1,
        max_seq_len = SEQ_LEN,
        attn_layers = Decoder(dim = 1024, depth = 12, heads = 16, attn_flash = True)
    )

    model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)

    model_path = 'Melody_Harmonizer_Transformer_Trained_Model_14961_steps_0.4155_loss_0.8664_acc.pth'

    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    model.cpu()

    dtype = torch.bfloat16

    ctx = torch.amp.autocast(device_type='cpu', dtype=dtype)

    model.eval()

    print('Done!')
    print('=' * 70)

    print('Harmonizing...')
    print('=' * 70)

    #===============================================================================

    mel_pitches_ext = _pad_melody_pitches(mel_pitches, 24)

    song = []

    last_window_start = len(mel_pitches_ext) - 24

    # Windows of 24 melody notes, advanced by 12 notes at a time.
    for i in range(0, len(mel_pitches_ext)-12, 12):

        mel_chunk = mel_pitches_ext[i:i+24]

        # Prompt layout: 141, the 24 melody tokens, 142, then (melody, chord) pairs.
        data = [141] + mel_chunk + [142]

        for j in range(24):

            data.append(mel_chunk[j])

            x = torch.tensor([data], dtype=torch.long, device='cpu')

            with ctx:
                out = model.generate(x,
                                     1,
                                     filter_logits_fn=top_k,
                                     filter_kwargs={'k': model_top_k_sampling_value},
                                     temperature=1.0,
                                     return_prime=False,
                                     verbose=False)

            outy = out.tolist()[0]

            data.append(outy[0])

        if i != last_window_start:
            # Windows overlap by 12 notes: keep only the first 12 pairs.
            song.extend(data[26:50])

        else:
            # Last window: keep all 24 pairs.
            song.extend(data[26:])

    # Drop the pairs produced for the padding notes.
    song = song[:len(mel_pitches) * 2]

    #===============================================================================

    print('Harmonized', len(song) // 2, 'out of', len(mel_pitches), 'notes')
    print('Done!')
    print('=' * 70)

    #===============================================================================

    # Only tokens 12..140 are chord tokens; everything else in `song` is skipped.
    harm_chords = [TMIDIX.ALL_CHORDS_FILTERED[s-12] for s in song if 11 < s < 141]

    # Plain rendering: chord tones placed at 48 + tone, bass from the last chord tone.
    plain_voicings = [([48 + c for c in chord], chord[-1]) for chord in harm_chords]

    if texture_harmonized_chords:

        print('=' * 70)
        print('Texturing harmonized chords...')
        print('=' * 70)

        final_song = _texture_chords(harm_chords, chunk_length=2)

        print('=' * 70)
        print(len(final_song))
        print('=' * 70)
        print('Done!')
        print('=' * 70)

        print('Rendering textured results...')
        print('=' * 70)

        # Textured rendering: pitches are used as-is, bass from the last pitch's class.
        voicings = [(s, s[-1] % 12) for s in final_song]

        if not voicings:
            # Not even the opening chunk could be textured: use the plain chords
            # rather than producing an empty composition.
            voicings = plain_voicings

    else:

        print('Rendering results...')
        print('=' * 70)

        voicings = plain_voicings

    patches = [0] * 16
    patches[0] = harmonized_accompaniment_MIDI_patch_number

    if base_MIDI_patch_number > -1:
        patches[2] = base_MIDI_patch_number

    patches[3] = melody_MIDI_patch_number

    output_score = _render_score(mel_score,
                                 voicings,
                                 melody_MIDI_patch_number,
                                 harmonized_accompaniment_MIDI_patch_number,
                                 base_MIDI_patch_number)

    fn1 = "Melody-Harmonizer-Transformer-Composition"

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                                              output_signature = 'Melody Harmonizer Transformer',
                                                              output_file_name = fn1,
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches
                                                              )

    new_fn = fn1+'.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    #========================================================

    output_midi = str(new_fn)
    output_audio = (16000, audio)

    output_plot = TMIDIX.plot_ms_SONG(output_score, plot_title=output_midi, return_plt=True)

    print('Done!')

    #========================================================

    separator = '=' * 70

    harmonization_summary_string = '\n'.join([
        separator,
        'Source melody has ' + str(len(mel_pitches)) + ' monophonic pitches',
        separator,
        'Harmonized ' + str(len(song) // 2) + ' out of ' + str(len(mel_pitches)) + ' source melody pitches',
        separator,
        '',
    ])

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_audio, output_plot, output_midi, harmonization_summary_string


def _pad_melody_pitches(mel_pitches, window):

    """Extend the melody to the next multiple of `window` by repeating it from the start.

    At least one and at most `window` notes are appended. The melody is
    repeated cyclically, so melodies shorter than the required padding still
    reach a full window. `mel_pitches` must not be empty.
    """

    target_len = ((len(mel_pitches) // window) + 1) * window

    return [mel_pitches[k % len(mel_pitches)] for k in range(target_len)]


def _chunk_signature(chunk):

    """Return the flat [mean pitch, number of pitches, ...] signature of a chunk of chords."""

    signature = []

    for chord in chunk:
        signature.extend([sum(chord) / len(chord), len(chord)])

    return signature


def _find_best_match(matches_indexes, previous_match_index):

    """Pick the candidate chunk whose signature is closest to the previous chunk's.

    Distance is `TMIDIX.minkowski_distance` between chunk signatures; on ties
    the first candidate wins.
    """

    psig = _chunk_signature(all_chords_ptcs_chunks[previous_match_index])

    dists = [TMIDIX.minkowski_distance(psig, _chunk_signature(all_chords_ptcs_chunks[midx]))
             for midx in matches_indexes]

    return matches_indexes[dists.index(min(dists))]


def _texture_chords(harm_chords, chunk_length=2):

    """Replace chords with voicings taken from the pitches/chords pairs data.

    Chords are processed in chunks of `chunk_length`. The first chunk takes a
    random matching entry; each later chunk takes the matching entry closest
    to the previously chosen one. Texturing stops at the first chunk with no
    matching entry ("Dead end!"), so the result may be shorter than
    `harm_chords`, or empty.
    """

    if not harm_chords:
        return []

    harm_toks = [TMIDIX.ALL_CHORDS_FILTERED.index(c) for c in harm_chords]

    # Repeat the last chord only as far as needed to complete the final chunk.
    harm_toks += [harm_toks[-1]] * (-len(harm_toks) % chunk_length)

    final_song = []
    pidx = None

    for i in tqdm(range(0, len(harm_toks), chunk_length)):

        trg_chunk = np.array(harm_toks[i:i+chunk_length])

        sidxs = np.where((src_chunks == trg_chunk).all(axis=1))[0].tolist()

        if not sidxs:
            print('Dead end!')
            break

        if pidx is None:
            sidx = random.choice(sidxs)
        else:
            sidx = _find_best_match(sidxs, pidx)

        pidx = sidx

        final_song.extend(all_chords_ptcs_chunks[sidx])

    # Drop voicings that belong to the padding chord.
    return final_song[:len(harm_chords)]


def _render_score(mel_score, voicings, melody_patch, accompaniment_patch, bass_patch):

    """Build the output note list: melody on channel 3, chords on 0, bass on 2.

    `voicings` holds one `(pitches, bass_tone)` pair per melody note, in order.
    The bass line is skipped when `bass_patch` is below 0.
    """

    output_score = []

    for note, (pitches, bass_tone) in zip(mel_score, voicings):

        # Timings were loaded with timings_divider=16; presumably this scales them back.
        start = note[1] * 16
        dur = note[2] * 16

        output_score.append(['note', start, dur, 3, note[4], 115+(note[4] % 12), melody_patch])

        for pitch in pitches:
            output_score.append(['note', start, dur, 0, pitch, max(40, pitch), accompaniment_patch])

        if bass_patch > -1:
            output_score.append(['note', start, dur, 2, bass_tone + 24, 120 - bass_tone, bass_patch])

    return output_score
# =================================================================================================
if __name__ == "__main__":

    # Globals read by Harmonize_Melody: PDT, soundfont, all_chords_ptcs_chunks, src_chunks.

    PDT = timezone('US/Pacific')

    banner = '=' * 70

    print(banner)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print(banner)

    #===============================================================================

    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    print('Loading Melody Harmonizer Transformer Pitches Chords Pairs Data...')
    print(banner)

    all_chords_toks_chunks, all_chords_ptcs_chunks = TMIDIX.Tegridy_Any_Pickle_File_Reader('Melody_Harmonizer_Transformer_Pitches_Chords_Pairs_Data')

    print(banner)
    print('Total number of pitches chords pairs:', len(all_chords_toks_chunks))
    print(banner)

    print('Loading pitches chords pairs...')

    # Array form of the chord-token chunks, used for chunk lookups by comparison.
    src_chunks = np.array(all_chords_toks_chunks)

    print('Done!')
    print(banner)

    #===============================================================================

    app = gr.Blocks()

    with app:

        # ---- Page header ----
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Melody Harmonizer Transformer</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Harmonize any MIDI melody with transformers</h1>")
        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Melody-Harmonizer-Transformer&style=flat)\n\n"
            "This is a demo for Monster MIDI Dataset\n\n"
            "Check out [Monster MIDI Dataset](https://github.com/asigalov61/Monster-MIDI-Dataset) on GitHub!\n\n"
        )

        # ---- Source melody ----
        gr.Markdown("## Upload your MIDI or select a sample example below")
        gr.Markdown("### For best results upload only monophonic melody MIDIs")

        input_src_midi = gr.File(label="Source MIDI", file_types=[".midi", ".mid", ".kar"])

        # ---- Harmonization options ----
        gr.Markdown("## Select harmonization options")

        source_melody_transpose_value = gr.Slider(-6, 6, value=0, step=1, label="Source melody transpose value", info="You can transpose source melody by specified number of semitones if the original melody key does not harmonize well")
        model_top_k_sampling_value = gr.Slider(1, 50, value=25, step=1, label="Model sampling top_k value", info="Decreasing this value may produce better harmonization results in some cases")
        texture_harmonized_chords = gr.Checkbox(label="Texture harmonized chords", value=True, info="Texture harmonized chords for more pleasant listening")
        melody_MIDI_patch_number = gr.Slider(0, 127, value=40, step=1, label="Source melody MIDI patch number")
        harmonized_accompaniment_MIDI_patch_number = gr.Slider(0, 127, value=0, step=1, label="Harmonized accompaniment MIDI patch number")
        base_MIDI_patch_number = gr.Slider(-1, 127, value=35, step=1, label="Base MIDI patch number")

        run_btn = gr.Button("Harmonize Melody", variant="primary")

        # ---- Results ----
        gr.Markdown("## Harmonization results")

        output_summary = gr.Textbox(label="Melody harmonization summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="mp3", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        # Component lists in the order of Harmonize_Melody's parameters and return values.
        harmonizer_inputs = [
            input_src_midi,
            source_melody_transpose_value,
            model_top_k_sampling_value,
            texture_harmonized_chords,
            melody_MIDI_patch_number,
            harmonized_accompaniment_MIDI_patch_number,
            base_MIDI_patch_number,
        ]

        harmonizer_outputs = [output_audio, output_plot, output_midi, output_summary]

        run_event = run_btn.click(Harmonize_Melody, harmonizer_inputs, harmonizer_outputs)

        example_rows = [
            ["USSR Anthem Seed Melody.mid", 0, 25, True, 40, 0, 35],
        ]

        gr.Examples(
            example_rows,
            harmonizer_inputs,
            harmonizer_outputs,
            Harmonize_Melody,
            cache_examples=True,
        )

    app.queue().launch()