MusiConGen / app.py
fffiloni's picture
Update app.py
df8718b verified
raw
history blame
6.87 kB
import gradio as gr
import spaces
import huggingface_hub
import numpy as np
import pandas as pd
import os
import shutil
import torch
from audiocraft.data.audio import audio_write
import audiocraft.models
# download models: fetch both MusiConGen checkpoint files from the Hugging Face
# Hub into ./ckpt/musicongen, the directory infer() loads the model from.
for _checkpoint_file in ('compression_state_dict.bin', 'state_dict.bin'):
    huggingface_hub.hf_hub_download(
        repo_id='Cyan0731/MusiConGen',
        filename=_checkpoint_file,
        local_dir='./ckpt/musicongen',
    )
def print_directory_contents(path):
    """Print a tree view of ``path`` and everything below it.

    Each directory is printed with a trailing ``/`` and every nesting level
    is indented by four spaces. Entries are listed in sorted order so the
    log output is deterministic.

    Args:
        path: Root directory to walk. A trailing path separator is accepted.
    """
    for root, dirs, files in os.walk(path):
        # Sorting ``dirs`` in place also fixes the order os.walk descends in.
        dirs.sort()
        # Depth relative to ``path``. relpath avoids the substring pitfalls of
        # str.replace (e.g. a trailing separator on ``path`` or ``path``
        # reappearing deeper in the tree).
        rel = os.path.relpath(root, path)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath strips a trailing separator so basename is never empty.
        print(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 4 * (level + 1)
        for f in sorted(files):
            print(f"{subindent}{f}")
def check_outputs_folder(folder_path):
    """Empty ``folder_path`` in place, keeping the folder itself.

    Despite its name, this function does more than check: every file, symlink
    and sub-directory inside the folder is deleted. A failure on one entry is
    reported and skipped, so a single undeletable entry does not abort the
    cleanup.

    Args:
        folder_path: Directory to empty. If it is missing (or is not a
            directory), a message is printed and nothing is deleted.
    """
    # os.path.isdir() is already False for missing paths.
    if not os.path.isdir(folder_path):
        print(f'The folder {folder_path} does not exist.')
        return
    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        try:
            # Symlinks (including ones pointing at directories) are unlinked
            # rather than recursed into; shutil.rmtree refuses symlinks.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)  # Remove file or link
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)  # Remove directory
        except OSError as e:
            # Best-effort cleanup: report and continue with the other entries.
            print(f'Failed to delete {file_path}. Reason: {e}')
def check_for_wav_in_outputs(outputs_folder='./output_samples/example_1'):
    """Return the path of a ``.wav`` file in ``outputs_folder``.

    Only the top level of the folder is searched.

    Args:
        outputs_folder: Directory to search. Defaults to the folder that
            ``infer`` writes its generated audio to.

    Returns:
        Path of the alphabetically first regular file ending in ``.wav``, or
        ``None`` if the folder is missing, is not a directory, or contains no
        such file.
    """
    if not os.path.isdir(outputs_folder):
        return None
    # Sorted so the pick is deterministic when several samples are present
    # (os.listdir order is arbitrary); isfile skips directories named *.wav.
    wav_files = sorted(
        name for name in os.listdir(outputs_folder)
        if name.endswith('.wav') and os.path.isfile(os.path.join(outputs_folder, name))
    )
    return os.path.join(outputs_folder, wav_files[0]) if wav_files else None
def _output_stem(chord_progression, description, max_bytes=240):
    """Build a single-component, length-bounded file stem for one sample.

    Keeps the ``<chords>|<description>`` naming with spaces replaced by
    underscores. Path separators are replaced with ``-`` so a ``/`` in the
    chord text or the description cannot create nested directories. The
    result is truncated to ``max_bytes`` UTF-8 bytes so long descriptions
    stay under the usual 255-byte filename limit once an extension is added.
    """
    stem = f"{chord_progression}|{description}".replace(" ", "_")
    for sep in ("/", os.sep):
        stem = stem.replace(sep, "-")
    # Truncate on a byte budget, dropping any multi-byte character cut in half.
    return stem.encode("utf-8")[:max_bytes].decode("utf-8", "ignore")


@spaces.GPU()
def infer(prompt_in, chords, duration, bpms):
    """Generate one music sample conditioned on text, chords and tempo.

    Args:
        prompt_in: Text description of the music to generate.
        chords: Chord progression string, e.g. ``'B:min D F#:min E'``.
        duration: Sample length passed to ``set_generation_params``; half of
            it is used as ``extend_stride``.
        bpms: Tempo in beats per minute.

    Returns:
        Path of the ``.wav`` file found by ``check_for_wav_in_outputs()``
        after generation, or ``None`` if none was found.
    """
    output_dir = 'example_1'  ### change this output directory
    output_folder = os.path.join('./output_samples', output_dir)
    # Clear the previous request's results so the lookup at the end finds
    # this run's file.
    check_outputs_folder(output_folder)

    num_samples = 1
    bs = 1

    # load your model
    musicgen = audiocraft.models.MusicGen.get_pretrained('./ckpt/musicongen')  ### change this path
    musicgen.set_generation_params(duration=duration, extend_stride=duration//2, top_k=250)

    # One entry per sample in every conditioning list, so per-sample indexing
    # below stays valid for any num_samples.
    chords = [chords] * num_samples
    descriptions = [prompt_in] * num_samples
    bpms = [bpms] * num_samples
    meters = [4] * num_samples  # presumably 4 beats per bar — confirm with MusiConGen docs

    wav = []
    for i in range(num_samples // bs):
        print(f"starting {i} batch...")
        batch = slice(i * bs, (i + 1) * bs)
        temp = musicgen.generate_with_chords_and_beats(
            descriptions[batch],
            chords[batch],
            bpms[batch],
            meters[batch]
        )
        wav.extend(temp.cpu())

    # save generated audio (tensors are already on the CPU)
    for idx, one_wav in enumerate(wav):
        sav_path = os.path.join(output_folder, _output_stem(chords[idx], descriptions[idx]))
        audio_write(sav_path, one_wav, musicgen.sample_rate, strategy='loudness', loudness_compressor=True)

    # Print the outputs directory contents
    print_directory_contents('./output_samples')
    wav_file_path = check_for_wav_in_outputs()
    print(wav_file_path)
    return wav_file_path
# Page styles passed to gr.Blocks(css=...) below. The first rule centres the
# main column (elem_id="col-container") and caps its width at 800px. The
# second rule enlarges the text of button elements inside the
# chord-progression examples (elem_id="chords-examples").
css="""
#col-container{
max-width: 800px;
margin: 0 auto;
}
#chords-examples button{
font-size: 20px;
}
"""
# Build the web UI: a text prompt, chord progression, duration and tempo
# inputs, a Submit button wired to infer(), and clickable example prompts
# and chord progressions.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# MusiConGen")
        gr.Markdown("## Rhythm and Chord Control for Transformer-Based Text-to-Music Generation")
        with gr.Column():
            with gr.Group():
                prompt_in = gr.Textbox(label="Music description", value="A smooth acid jazz track with a laid-back groove, silky electric piano, and a cool bass, providing a modern take on jazz. Instruments: electric piano, bass, drums.")
                with gr.Row():
                    # Gradio documents `scale` as an integer and warns on
                    # fractional values; 2 is the nearest integer to the
                    # previous 1.75, keeping the chords box the widest input.
                    chords = gr.Textbox(label="Chords progression", value='B:min D F#:min E', scale=2)
                    duration = gr.Slider(label="Sample duration", minimum=4, maximum=30, step=1, value=30)
                    bpms = gr.Slider(label="BPMs", minimum=50, maximum=220, step=1, value=120)
            submit_btn = gr.Button("Submit")
        wav_out = gr.Audio(label="Wav Result")
        with gr.Row():
            gr.Examples(
                label = "Audio description examples",
                examples = [
                    ["A laid-back blues shuffle with a relaxed tempo, warm guitar tones, and a comfortable groove, perfect for a slow dance or a night in. Instruments: electric guitar, bass, drums."],
                    ["A smooth acid jazz track with a laid-back groove, silky electric piano, and a cool bass, providing a modern take on jazz. Instruments: electric piano, bass, drums."],
                    ["A classic rock n' roll tune with catchy guitar riffs, driving drums, and a pulsating bass line, reminiscent of the golden era of rock. Instruments: electric guitar, bass, drums."],
                    ["A high-energy funk tune with slap bass, rhythmic guitar riffs, and a tight horn section, guaranteed to get you grooving. Instruments: bass, guitar, trumpet, saxophone, drums."],
                    ["A heavy metal onslaught with double kick drum madness, aggressive guitar riffs, and an unrelenting bass, embodying the spirit of metal. Instruments: electric guitar, bass guitar, drums."]
                ],
                inputs = [prompt_in]
            )
            gr.Examples(
                label = "Chords progression examples",
                elem_id = "chords-examples",
                examples = ['C G A:min F',
                            'A:min F C G',
                            'C F G F',
                            'C A:min F G',
                            'D:min G C A:min',
                            'D:min7 G:7 C:maj7 C:maj7',
                            'F G E:min A:min',
                            'B:min D F#:min E',
                            'F G E A:min',
                            'C Bb F C'
                ],
                inputs = [chords]
            )
    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        fn = infer,
        inputs = [prompt_in, chords, duration, bpms],
        outputs = [wav_out]
    )

demo.launch(show_api=False, show_error=True)