|
|
|
|
|
|
|
|
|
import os |
|
import time as reqtime |
|
import datetime |
|
from pytz import timezone |
|
|
|
import torch |
|
from imagen_pytorch import Unet, Imagen, ImagenTrainer |
|
from imagen_pytorch.data import Dataset |
|
|
|
import spaces |
|
import gradio as gr |
|
|
|
import numpy as np |
|
|
|
import random |
|
import tqdm |
|
|
|
import TMIDIX |
|
import TPLOTS |
|
|
|
from midi_to_colab_audio import midi_to_colab_audio |
|
|
|
|
|
|
|
@spaces.GPU
def Generate_POP_Medley(input_num_medley_comps, input_melody_patch):
    """Sample images from the Imagen checkpoint and render them as a MIDI medley.

    The function builds the Unet/Imagen/ImagenTrainer stack, loads the
    checkpoint, samples one image per requested composition, binarizes each
    image, converts it to score notes, joins the scores into one medley and
    renders the medley to MIDI, audio and a plot.

    It reads the module-level names ``PDT`` (timezone used for log
    timestamps) and ``soundfont`` (soundfont path used for audio rendering),
    which must be defined before the function is called.

    Args:
        input_num_medley_comps: Number of images to sample; passed as the
            sampling batch size, so it is also the number of compositions.
        input_melody_patch: MIDI patch number for an added melody. A melody
            is added only when this value is greater than -1.

    Returns:
        A 5-tuple ``(title, summary, midi_file_name, audio, plot)`` where
        ``audio`` is ``(16000, samples)`` and ``summary`` is the string form
        of the first three entries of the patched score.
    """
    rule = '=' * 70

    print(rule)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print(rule)

    print('Req number of medley compositions:', input_num_medley_comps)
    print('Req melody MIDI patch number:', input_melody_patch)
    print(rule)

    # ------------------------------------------------------------------
    # Model construction
    # ------------------------------------------------------------------
    print('Loading model...')

    unet_dim = 64
    channel_count = 1
    diffusion_steps = 1000
    device_name = 'cpu'

    unet = Unet(
        dim=unet_dim,
        dim_mults=(1, 2, 4, 8),
        num_resnet_blocks=1,
        channels=channel_count,
        layer_attns=(False, False, False, True),
        layer_cross_attns=False,
    )

    imagen = Imagen(
        condition_on_text=False,
        unets=unet,
        channels=channel_count,
        image_sizes=128,
        timesteps=diffusion_steps,
    )

    trainer = ImagenTrainer(
        imagen=imagen,
        split_valid_from_train=True,
    ).to(device_name)

    print(rule)
    print('Loading model checkpoint...')
    print(rule)

    trainer.load('Imagen_POP909_64_dim_12638_steps_0.00983_loss.ckpt')

    print(rule)
    print('Done!')
    print(rule)

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------
    print('Generating...')
    print(rule)

    sampled_images = trainer.sample(
        batch_size=input_num_medley_comps,
        return_pil_images=True,
    )

    print(rule)
    print('Done!')
    print(rule)

    # ------------------------------------------------------------------
    # Binarization: values below the threshold become 0, the rest 1
    # ------------------------------------------------------------------
    print('Processing...')

    threshold = 128
    binary_images = [
        np.where(np.array(image) < threshold, 0, 1)
        for image in sampled_images
    ]

    print('Done!')

    # ------------------------------------------------------------------
    # Image -> score conversion, one score per sampled image
    # ------------------------------------------------------------------
    print(rule)
    print('Converting images to scores...')

    medley_compositions_escores = []

    for binary_image in binary_images:
        matrix = TPLOTS.images_to_binary_matrix([binary_image])
        escore = TMIDIX.binary_matrix_to_original_escore_notes(matrix)

        # A patch of -1 (the slider minimum) means "no added melody".
        if -1 < input_melody_patch:
            escore = TMIDIX.add_melody_to_enhanced_score_notes(
                escore,
                melody_patch=input_melody_patch,
            )

        medley_compositions_escores.append(escore)

    print('Done!')
    print(rule)

    # ------------------------------------------------------------------
    # Medley assembly
    # ------------------------------------------------------------------
    print('Creating medley score...')

    composition_count = len(medley_compositions_escores)
    medley_labels = [
        'Imagen POP Medley Composition #' + str(number)
        for number in range(1, composition_count + 1)
    ]

    medley_escore = TMIDIX.escore_notes_medley(
        medley_compositions_escores,
        medley_labels,
        pause_time_value=16,
    )

    # ------------------------------------------------------------------
    # Rendering: MIDI file, audio and plot
    # ------------------------------------------------------------------
    print('Rendering results...')
    print(rule)

    print('Sample INTs', medley_escore[:15])
    print(rule)

    base_name = "Imagen-POP-Music-Medley-Diffusion-Transformer-Composition"
    timing_scale = 256
    audio_rate = 16000

    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(medley_escore)

    # The converter's return value is not used further in this function.
    _conversion_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
        output_score,
        output_signature='Imagen POP Music Medley',
        output_file_name=base_name,
        track_name='Project Los Angeles',
        list_of_MIDI_patches=patches,
        timings_multiplier=timing_scale,
    )

    # NOTE(review): assumes the converter wrote `<output_file_name>.mid` - confirm.
    midi_path = base_name + '.mid'

    audio = midi_to_colab_audio(
        midi_path,
        soundfont_path=soundfont,
        sample_rate=audio_rate,
        volume_scale=10,
        output_for_gradio=True,
    )

    print('Done!')
    print(rule)

    # ------------------------------------------------------------------
    # Output packaging
    # ------------------------------------------------------------------
    output_midi_title = str(base_name)
    output_midi_summary = str(output_score[:3])
    output_midi = str(midi_path)
    output_audio = (audio_rate, audio)

    output_plot = TMIDIX.plot_ms_SONG(
        output_score,
        plot_title=output_midi,
        return_plt=True,
        timings_multiplier=timing_scale,
    )

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print(rule)

    dashes = '-' * 70

    print(dashes)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print(dashes)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
|
|
|
|
|
|
|
if __name__ == "__main__":

    # Globals read by Generate_POP_Medley at request time.
    PDT = timezone('US/Pacific')
    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    # Intro text shown under the page headings.
    intro_markdown = (
        "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Imagen-POP-Music-Medley-Diffusion-Transformer&style=flat)\n\n"
        "This is a demo for MIDI Images dataset\n\n"
        "Please see [MIDI Images](https://huggingface.co/datasets/asigalov61/MIDI-Images) Hugging Face repo for more information\n\n"
    )

    with gr.Blocks() as app:

        # Page header
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Imagen POP Music Medley Diffusion Transformer</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Generate unique POP music medleys with Imagen diffusion transformer</h1>")
        gr.Markdown(intro_markdown)

        # Input controls
        gr.Markdown("## Choose medley settings")

        medley_count_slider = gr.Slider(
            minimum=1,
            maximum=10,
            value=5,
            step=1,
            label="Number of medley compositions",
        )
        melody_patch_slider = gr.Slider(
            minimum=-1,
            maximum=127,
            value=40,
            step=1,
            label="Medley melody MIDI patch number",
        )

        generate_button = gr.Button(value="Generate POP Medley", variant="primary")

        # Output widgets, created in the order they appear on the page
        gr.Markdown("## Generation results")

        title_box = gr.Textbox(label="Output MIDI title")
        summary_box = gr.Textbox(label="Output MIDI summary")
        audio_player = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        score_plot = gr.Plot(label="Output MIDI score plot")
        midi_file = gr.File(label="Output MIDI file", file_types=[".mid"])

        # The output list follows the order of the handler's return tuple:
        # title, summary, MIDI file, audio, plot.
        run_event = generate_button.click(
            fn=Generate_POP_Medley,
            inputs=[medley_count_slider, melody_patch_slider],
            outputs=[title_box, summary_box, midi_file, audio_player, score_plot],
        )

    app.queue().launch()