|
import spaces |
|
from pathlib import Path |
|
import yaml |
|
import time |
|
import uuid |
|
|
|
import numpy as np |
|
import audiotools as at |
|
import argbind |
|
import shutil |
|
import torch |
|
from datetime import datetime |
|
|
|
import gradio as gr |
|
from vampnet.interface import Interface, signal_concat |
|
from vampnet import mask as pmask |
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
interface = Interface.default() |
|
init_model_choice = open("DEFAULT_MODEL").read().strip() |
|
|
|
interface.load_finetuned(init_model_choice) |
|
|
|
def to_output(sig): |
|
return sig.sample_rate, sig.cpu().detach().numpy()[0][0] |
|
|
|
MAX_DURATION_S = 10 |
|
def load_audio(file): |
|
print(file) |
|
if isinstance(file, str): |
|
filepath = file |
|
elif isinstance(file, tuple): |
|
|
|
sr, samples = file |
|
samples = samples / np.iinfo(samples.dtype).max |
|
return sr, samples |
|
else: |
|
filepath = file.name |
|
sig = at.AudioSignal.salient_excerpt( |
|
filepath, duration=MAX_DURATION_S |
|
) |
|
sig = at.AudioSignal(filepath) |
|
return to_output(sig) |
|
|
|
|
|
def load_example_audio(): |
|
return load_audio("./assets/example.wav") |
|
|
|
from torch_pitch_shift import pitch_shift, get_fast_shifts |
|
def shift_pitch(signal, interval: int): |
|
signal.samples = pitch_shift( |
|
signal.samples, |
|
shift=interval, |
|
sample_rate=signal.sample_rate |
|
) |
|
return signal |
|
|
|
|
|
@spaces.GPU |
|
def _vamp( |
|
seed, input_audio, model_choice, |
|
pitch_shift_amt, periodic_p, |
|
n_mask_codebooks, periodic_w, onset_mask_width, |
|
dropout, sampletemp, typical_filtering, |
|
typical_mass, typical_min_tokens, top_p, |
|
sample_cutoff, stretch_factor, api=False |
|
): |
|
|
|
t0 = time.time() |
|
interface.to("cuda" if torch.cuda.is_available() else "cpu") |
|
print(f"using device {interface.device}") |
|
_seed = seed if seed > 0 else None |
|
if _seed is None: |
|
_seed = int(torch.randint(0, 2**32, (1,)).item()) |
|
at.util.seed(_seed) |
|
|
|
sr, input_audio = input_audio |
|
input_audio = input_audio / np.iinfo(input_audio.dtype).max |
|
|
|
sig = at.AudioSignal(input_audio, sr) |
|
|
|
|
|
interface.load_finetuned(model_choice) |
|
|
|
if pitch_shift_amt != 0: |
|
sig = shift_pitch(sig, pitch_shift_amt) |
|
|
|
codes = interface.encode(sig) |
|
|
|
mask = interface.build_mask( |
|
codes, sig, |
|
rand_mask_intensity=1.0, |
|
prefix_s=0.0, |
|
suffix_s=0.0, |
|
periodic_prompt=int(periodic_p), |
|
periodic_prompt_width=periodic_w, |
|
onset_mask_width=onset_mask_width, |
|
_dropout=dropout, |
|
upper_codebook_mask=int(n_mask_codebooks), |
|
) |
|
|
|
|
|
|
|
interface.set_chunk_size(10.0) |
|
codes, mask = interface.vamp( |
|
codes, mask, |
|
batch_size=1 if api else 1, |
|
feedback_steps=1, |
|
_sampling_steps=12 if sig.duration <6.0 else 24, |
|
time_stretch_factor=stretch_factor, |
|
return_mask=True, |
|
temperature=sampletemp, |
|
typical_filtering=typical_filtering, |
|
typical_mass=typical_mass, |
|
typical_min_tokens=typical_min_tokens, |
|
top_p=None, |
|
seed=_seed, |
|
sample_cutoff=1.0, |
|
) |
|
print(f"vamp took {time.time() - t0} seconds") |
|
|
|
sig = interface.decode(codes) |
|
|
|
return to_output(sig) |
|
|
|
def vamp(data): |
|
return _vamp( |
|
seed=data[seed], |
|
input_audio=data[input_audio], |
|
model_choice=data[model_choice], |
|
pitch_shift_amt=data[pitch_shift_amt], |
|
periodic_p=data[periodic_p], |
|
n_mask_codebooks=data[n_mask_codebooks], |
|
periodic_w=data[periodic_w], |
|
onset_mask_width=data[onset_mask_width], |
|
dropout=data[dropout], |
|
sampletemp=data[sampletemp], |
|
typical_filtering=data[typical_filtering], |
|
typical_mass=data[typical_mass], |
|
typical_min_tokens=data[typical_min_tokens], |
|
top_p=data[top_p], |
|
sample_cutoff=data[sample_cutoff], |
|
stretch_factor=data[stretch_factor], |
|
api=False, |
|
) |
|
|
|
def api_vamp(data): |
|
return _vamp( |
|
seed=data[seed], |
|
input_audio=data[input_audio], |
|
model_choice=data[model_choice], |
|
pitch_shift_amt=data[pitch_shift_amt], |
|
periodic_p=data[periodic_p], |
|
n_mask_codebooks=data[n_mask_codebooks], |
|
periodic_w=data[periodic_w], |
|
onset_mask_width=data[onset_mask_width], |
|
dropout=data[dropout], |
|
sampletemp=data[sampletemp], |
|
typical_filtering=data[typical_filtering], |
|
typical_mass=data[typical_mass], |
|
typical_min_tokens=data[typical_min_tokens], |
|
top_p=data[top_p], |
|
sample_cutoff=data[sample_cutoff], |
|
stretch_factor=data[stretch_factor], |
|
api=True, |
|
) |
|
|
|
OUT_DIR = Path("gradio-outputs") |
|
OUT_DIR.mkdir(exist_ok=True) |
|
def harp_vamp(input_audio_file, periodic_p, n_mask_codebooks): |
|
sig = at.AudioSignal(input_audio_file) |
|
sr, samples = sig.sample_rate, sig.samples[0][0].detach().cpu().numpy() |
|
|
|
samples = (samples * np.iinfo(np.int32).max).astype(np.int32) |
|
sr, samples = _vamp( |
|
seed=0, |
|
input_audio=(sr, samples), |
|
model_choice=init_model_choice, |
|
pitch_shift_amt=0, |
|
periodic_p=periodic_p, |
|
n_mask_codebooks=n_mask_codebooks, |
|
periodic_w=1, |
|
onset_mask_width=0, |
|
dropout=0.0, |
|
sampletemp=1.0, |
|
typical_filtering=True, |
|
typical_mass=0.15, |
|
typical_min_tokens=64, |
|
top_p=0.0, |
|
sample_cutoff=1.0, |
|
stretch_factor=1, |
|
) |
|
|
|
sig = at.AudioSignal(samples, sr) |
|
|
|
|
|
for p in OUT_DIR.glob("*"): |
|
p.unlink() |
|
OUT_DIR.mkdir(exist_ok=True) |
|
outpath = OUT_DIR / f"{uuid.uuid4()}.wav" |
|
sig.write(outpath) |
|
from pyharp import AudioLabel, LabelList |
|
output_labels = LabelList() |
|
output_labels.append(AudioLabel(label='~', t=0.0, amplitude=0.5, description='generated audio')) |
|
return outpath, output_labels |
|
|
|
|
|
with gr.Blocks() as demo: |
|
with gr.Row(): |
|
with gr.Column(): |
|
manual_audio_upload = gr.File( |
|
label=f"upload some audio (will be randomly trimmed to max of 100s)", |
|
file_types=["audio"] |
|
) |
|
load_example_audio_button = gr.Button("or load example audio") |
|
|
|
input_audio = gr.Audio( |
|
label="input audio", |
|
interactive=False, |
|
type="numpy", |
|
) |
|
|
|
audio_mask = gr.Audio( |
|
label="audio mask (listen to this to hear the mask hints)", |
|
interactive=False, |
|
type="numpy", |
|
) |
|
|
|
|
|
load_example_audio_button.click( |
|
fn=load_example_audio, |
|
inputs=[], |
|
outputs=[ input_audio] |
|
) |
|
|
|
manual_audio_upload.change( |
|
fn=load_audio, |
|
inputs=[manual_audio_upload], |
|
outputs=[ input_audio] |
|
) |
|
|
|
|
|
|
|
with gr.Column(): |
|
with gr.Accordion("manual controls", open=True): |
|
periodic_p = gr.Slider( |
|
label="periodic prompt", |
|
minimum=0, |
|
maximum=13, |
|
step=1, |
|
value=7, |
|
) |
|
|
|
onset_mask_width = gr.Slider( |
|
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ", |
|
minimum=0, |
|
maximum=100, |
|
step=1, |
|
value=0, visible=False |
|
) |
|
|
|
n_mask_codebooks = gr.Slider( |
|
label="compression prompt ", |
|
value=3, |
|
minimum=1, |
|
maximum=14, |
|
step=1, |
|
) |
|
|
|
maskimg = gr.Image( |
|
label="mask image", |
|
interactive=False, |
|
type="filepath" |
|
) |
|
|
|
with gr.Accordion("extras ", open=False): |
|
pitch_shift_amt = gr.Slider( |
|
label="pitch shift amount (semitones)", |
|
minimum=-12, |
|
maximum=12, |
|
step=1, |
|
value=0, |
|
) |
|
|
|
stretch_factor = gr.Slider( |
|
label="time stretch factor", |
|
minimum=0, |
|
maximum=8, |
|
step=1, |
|
value=1, |
|
) |
|
|
|
periodic_w = gr.Slider( |
|
label="periodic prompt width (steps, 1 step ~= 10milliseconds)", |
|
minimum=1, |
|
maximum=20, |
|
step=1, |
|
value=1, |
|
) |
|
|
|
|
|
with gr.Accordion("sampling settings", open=False): |
|
sampletemp = gr.Slider( |
|
label="sample temperature", |
|
minimum=0.1, |
|
maximum=10.0, |
|
value=1.0, |
|
step=0.001 |
|
) |
|
|
|
top_p = gr.Slider( |
|
label="top p (0.0 = off)", |
|
minimum=0.0, |
|
maximum=1.0, |
|
value=0.0 |
|
) |
|
typical_filtering = gr.Checkbox( |
|
label="typical filtering ", |
|
value=True |
|
) |
|
typical_mass = gr.Slider( |
|
label="typical mass (should probably stay between 0.1 and 0.5)", |
|
minimum=0.01, |
|
maximum=0.99, |
|
value=0.15 |
|
) |
|
typical_min_tokens = gr.Slider( |
|
label="typical min tokens (should probably stay between 1 and 256)", |
|
minimum=1, |
|
maximum=256, |
|
step=1, |
|
value=64 |
|
) |
|
sample_cutoff = gr.Slider( |
|
label="sample cutoff", |
|
minimum=0.0, |
|
maximum=0.9, |
|
value=1.0, |
|
step=0.01 |
|
) |
|
|
|
|
|
dropout = gr.Slider( |
|
label="mask dropout", |
|
minimum=0.0, |
|
maximum=1.0, |
|
step=0.01, |
|
value=0.0 |
|
) |
|
|
|
|
|
seed = gr.Number( |
|
label="seed (0 for random)", |
|
value=0, |
|
precision=0, |
|
) |
|
|
|
|
|
|
|
with gr.Column(): |
|
|
|
model_choice = gr.Dropdown( |
|
label="model choice", |
|
choices=list(interface.available_models()), |
|
value=init_model_choice, |
|
visible=True |
|
) |
|
|
|
|
|
vamp_button = gr.Button("generate (vamp)!!!") |
|
|
|
|
|
audio_outs = [] |
|
use_as_input_btns = [] |
|
for i in range(1): |
|
with gr.Column(): |
|
audio_outs.append(gr.Audio( |
|
label=f"output audio {i+1}", |
|
interactive=False, |
|
type="numpy" |
|
)) |
|
use_as_input_btns.append( |
|
gr.Button(f"use as input (feedback)") |
|
) |
|
|
|
thank_you = gr.Markdown("") |
|
|
|
|
|
|
|
|
|
|
|
_inputs = { |
|
input_audio, |
|
sampletemp, |
|
top_p, |
|
periodic_p, periodic_w, |
|
dropout, |
|
stretch_factor, |
|
onset_mask_width, |
|
typical_filtering, |
|
typical_mass, |
|
typical_min_tokens, |
|
seed, |
|
model_choice, |
|
n_mask_codebooks, |
|
pitch_shift_amt, |
|
sample_cutoff, |
|
} |
|
|
|
|
|
vamp_button.click( |
|
fn=vamp, |
|
inputs=_inputs, |
|
outputs=[audio_outs[0]], |
|
) |
|
|
|
api_vamp_button = gr.Button("api vamp", visible=True) |
|
api_vamp_button.click( |
|
fn=api_vamp, |
|
inputs=_inputs, |
|
outputs=[audio_outs[0]], |
|
api_name="vamp" |
|
) |
|
|
|
from pyharp import ModelCard, build_endpoint |
|
card = ModelCard( |
|
name="vampnet", |
|
description="vampnet! is a model for generating audio from audio", |
|
author="hugo flores garcía", |
|
tags=["music generation"], |
|
midi_in=False, |
|
midi_out=False |
|
) |
|
|
|
|
|
app = build_endpoint(model_card=card, |
|
components=[ |
|
periodic_p, |
|
n_mask_codebooks, |
|
], |
|
process_fn=harp_vamp) |
|
|
|
|
|
|
|
try: |
|
demo.queue() |
|
demo.launch(share=True) |
|
except KeyboardInterrupt: |
|
shutil.rmtree("gradio-outputs", ignore_errors=True) |
|
raise |