# Will be fixed soon, but meanwhile: import os if os.getenv('SPACES_ZERO_GPU') == "true": os.environ['SPACES_ZERO_GPU'] = "1" import gradio as gr import random import torch import os from torch import inference_mode from typing import Optional, List import numpy as np from models import load_model import utils import spaces import huggingface_hub from inversion_utils import inversion_forward_process, inversion_reverse_process LDM2 = "cvssp/audioldm2" MUSIC = "cvssp/audioldm2-music" LDM2_LARGE = "cvssp/audioldm2-large" STABLEAUD = "stabilityai/stable-audio-open-1.0" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ldm2 = load_model(model_id=LDM2, device=device) ldm2_large = load_model(model_id=LDM2_LARGE, device=device) ldm2_music = load_model(model_id=MUSIC, device=device) ldm_stableaud = load_model(model_id=STABLEAUD, device=device, token=os.getenv('PRIV_TOKEN')) def randomize_seed_fn(seed, randomize_seed): if randomize_seed: seed = random.randint(0, np.iinfo(np.int32).max) torch.manual_seed(seed) return seed def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute): # ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device) with inference_mode(): w0 = ldm_stable.vae_encode(x0) # find Zs and wts - forward process _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1, prompts=[prompt_src], cfg_scales=[cfg_scale_src], num_inference_steps=num_diffusion_steps, numerical_fix=True, duration=duration, save_compute=save_compute) return zs, wts, extra_info def sample(ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute): # reverse process (via Zs and wT) tstart = torch.tensor(tstart, dtype=torch.int) w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart, etas=1., prompts=[prompt_tar], neg_prompts=[""], cfg_scales=[cfg_scale_tar], zs=zs[:int(tstart)], duration=duration, extra_info=extra_info, save_compute=save_compute) # 
vae decode image with inference_mode(): x0_dec = ldm_stable.vae_decode(w0) if 'stable-audio' not in ldm_stable.model_id: if x0_dec.dim() < 4: x0_dec = x0_dec[None, :, :, :] with torch.no_grad(): audio = ldm_stable.decode_to_mel(x0_dec) else: audio = x0_dec.squeeze(0).T return (ldm_stable.get_sr(), audio.squeeze().cpu().numpy()) def get_duration(input_audio, model_id: str, do_inversion: bool, wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], extra_info: Optional[List], saved_inv_model: str, source_prompt: str = "", target_prompt: str = "", steps: int = 200, cfg_scale_src: float = 3.5, cfg_scale_tar: float = 12, t_start: int = 45, randomize_seed: bool = True, save_compute: bool = True, oauth_token: Optional[gr.OAuthToken] = None): if model_id == LDM2: factor = 1 elif model_id == LDM2_LARGE: factor = 2.5 elif model_id == STABLEAUD: factor = 3.2 else: # MUSIC factor = 1 forwards = 0 if do_inversion or randomize_seed: forwards = steps if source_prompt == "" else steps * 2 # x2 when there is a prompt text forwards += int(t_start / 100 * steps) * 2 duration = min(utils.get_duration(input_audio), utils.MAX_DURATION) time_for_maxlength = factor * forwards * 0.15 # 0.25 is the time per forward pass print('expected time:', time_for_maxlength / utils.MAX_DURATION * duration) spare_time = 5 return max(10, time_for_maxlength / utils.MAX_DURATION * duration + spare_time) def verify_model_params(model_id: str, input_audio, src_prompt: str, tar_prompt: str, cfg_scale_src: float, oauth_token: gr.OAuthToken | None): if input_audio is None: raise gr.Error('Input audio missing!') if tar_prompt == "": raise gr.Error("Please provide a target prompt to edit the audio.") if src_prompt != "": if model_id == STABLEAUD and cfg_scale_src != 1: gr.Info("Consider using Source Guidance Scale=1 for Stable Audio Open 1.0.") elif model_id != STABLEAUD and cfg_scale_src != 3: gr.Info(f"Consider using Source Guidance Scale=3 for {model_id}.") if model_id == STABLEAUD: if oauth_token is None: 
raise gr.Error("You must be logged in to use Stable Audio Open 1.0. Please log in and try again.") try: huggingface_hub.get_hf_file_metadata(huggingface_hub.hf_hub_url(STABLEAUD, 'transformer/config.json'), token=oauth_token.token) print('Has Access') # except huggingface_hub.utils._errors.GatedRepoError: except huggingface_hub.errors.GatedRepoError: raise gr.Error("You need to accept the license agreement to use Stable Audio Open 1.0. " "Visit the " "model page to get access.") @spaces.GPU(duration=get_duration) def edit(input_audio, model_id: str, do_inversion: bool, wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], extra_info: Optional[List], saved_inv_model: str, source_prompt: str = "", target_prompt: str = "", steps: int = 200, cfg_scale_src: float = 3.5, cfg_scale_tar: float = 12, t_start: int = 45, randomize_seed: bool = True, save_compute: bool = True, oauth_token: Optional[gr.OAuthToken] = None): print(model_id) if model_id == LDM2: ldm_stable = ldm2 elif model_id == LDM2_LARGE: ldm_stable = ldm2_large elif model_id == STABLEAUD: ldm_stable = ldm_stableaud else: # MUSIC ldm_stable = ldm2_music ldm_stable.model.scheduler.set_timesteps(steps, device=device) # If the inversion was done for a different model, we need to re-run the inversion if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id): do_inversion = True if input_audio is None: raise gr.Error('Input audio missing!') x0, _, duration = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device, stft=('stable-audio' not in ldm_stable.model_id), model_sr=ldm_stable.get_sr()) if wts is None or zs is None: do_inversion = True if do_inversion or randomize_seed: # always re-run inversion zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src, duration=duration, save_compute=save_compute) wts = wts_tensor zs = zs_tensor extra_info = extra_info_list 
saved_inv_model = model_id do_inversion = False else: wts_tensor = wts.to(device) zs_tensor = zs.to(device) extra_info_list = [e.to(device) for e in extra_info if e is not None] output = sample(ldm_stable, zs_tensor, wts_tensor, extra_info_list, prompt_tar=target_prompt, tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar, duration=duration, save_compute=save_compute) return output, wts.cpu(), zs.cpu(), [e.cpu() for e in extra_info if e is not None], saved_inv_model, do_inversion # return output, wtszs_file, saved_inv_model, do_inversion def get_example(): case = [ ['Examples/Beethoven.mp3', '', 'A recording of an arcade game soundtrack.', 45, 'cvssp/audioldm2-music', '27s', 'Examples/Beethoven_arcade.mp3', ], ['Examples/Beethoven.mp3', 'A high quality recording of wind instruments and strings playing.', 'A high quality recording of a piano playing.', 45, 'cvssp/audioldm2-music', '27s', 'Examples/Beethoven_piano.mp3', ], ['Examples/Beethoven.mp3', '', 'Heavy Rock.', 40, 'stabilityai/stable-audio-open-1.0', '27s', 'Examples/Beethoven_rock.mp3', ], ['Examples/ModalJazz.mp3', 'Trumpets playing alongside a piano, bass and drums in an upbeat old-timey cool jazz song.', 'A banjo playing alongside a piano, bass and drums in an upbeat old-timey cool country song.', 45, 'cvssp/audioldm2-music', '106s', 'Examples/ModalJazz_banjo.mp3',], ['Examples/Shadows.mp3', '', '8-bit arcade game soundtrack.', 40, 'stabilityai/stable-audio-open-1.0', '34s', 'Examples/Shadows_arcade.mp3',], ['Examples/Cat.mp3', '', 'A dog barking.', 75, 'cvssp/audioldm2-large', '10s', 'Examples/Cat_dog.mp3',] ] return case intro = """

ZETA Editing 🎧

Zero-Shot Text-Based Audio Editing Using DDPM Inversion 🎛️

[Paper] |  [Project page] |  [Code]

For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. Duplicate Space

NEW - 15.10.24: You can now edit using Stable Audio Open 1.0. You must be logged in after accepting the license agreement to use it.

NEW - 15.10.24: Parallel editing is enabled by default. To disable, uncheck Efficient editing under "More Options". Saves a bit of time.

""" help = """
Instructions:
""" css = '.gradio-container {max-width: 1000px !important; padding-top: 1.5rem !important;}' \ '.audio-upload .wrap {min-height: 0px;}' # with gr.Blocks(css='style.css') as demo: with gr.Blocks(css=css) as demo: def reset_do_inversion(do_inversion_user, do_inversion): # do_inversion = gr.State(value=True) do_inversion = True do_inversion_user = True return do_inversion_user, do_inversion # handle the case where the user clicked the button but the inversion was not done def clear_do_inversion_user(do_inversion_user): do_inversion_user = False return do_inversion_user def post_match_do_inversion(do_inversion_user, do_inversion): if do_inversion_user: do_inversion = True do_inversion_user = False return do_inversion_user, do_inversion gr.HTML(intro) wts = gr.State() zs = gr.State() extra_info = gr.State() saved_inv_model = gr.State() do_inversion = gr.State(value=True) # To save some runtime when editing the same thing over and over do_inversion_user = gr.State(value=False) with gr.Group(): gr.Markdown("๐Ÿ’ก **note**: input longer than **30 sec** is automatically trimmed " "(for unlimited input, see the Help section below)") with gr.Row(equal_height=True): input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input Audio", interactive=True, scale=1, format='wav', elem_classes=['audio-upload']) output_audio = gr.Audio(label="Edited Audio", interactive=False, scale=1, format='wav') with gr.Row(): tar_prompt = gr.Textbox(label="Prompt", info="Describe your desired edited output", placeholder="a recording of a happy upbeat arcade game soundtrack", lines=2, interactive=True) with gr.Row(): t_start = gr.Slider(minimum=15, maximum=85, value=45, step=1, label="T-start (%)", interactive=True, scale=3, info="Lower T-start -> closer to original audio. 
Higher T-start -> stronger edit.") model_id = gr.Dropdown(label="Model Version", choices=[LDM2, LDM2_LARGE, MUSIC, STABLEAUD], info="Choose a checkpoint suitable for your audio and edit", value="cvssp/audioldm2-music", interactive=True, type="value", scale=2) with gr.Row(): submit = gr.Button("Edit", variant="primary", scale=3) gr.LoginButton(value="Login to HF (For Stable Audio)", scale=1) with gr.Accordion("More Options", open=False): with gr.Row(): src_prompt = gr.Textbox(label="Source Prompt", lines=2, interactive=True, info="Optional: Describe the original audio input", placeholder="A recording of a happy upbeat classical music piece",) with gr.Row(equal_height=True): cfg_scale_src = gr.Number(value=3, minimum=0.5, maximum=25, precision=None, label="Source Guidance Scale", interactive=True, scale=1) cfg_scale_tar = gr.Number(value=12, minimum=0.5, maximum=25, precision=None, label="Target Guidance Scale", interactive=True, scale=1) steps = gr.Number(value=50, step=1, minimum=10, maximum=300, info="Higher values (e.g. 
200) yield higher-quality generation.", label="Num Diffusion Steps", interactive=True, scale=2) with gr.Row(equal_height=True): seed = gr.Number(value=0, precision=0, label="Seed", interactive=True) randomize_seed = gr.Checkbox(label='Randomize seed', value=False) save_compute = gr.Checkbox(label='Efficient editing', value=True) length = gr.Number(label="Length", interactive=False, visible=False) with gr.Accordion("Help๐Ÿ’ก", open=False): gr.HTML(help) submit.click( fn=verify_model_params, inputs=[model_id, input_audio, src_prompt, tar_prompt, cfg_scale_src], outputs=[] ).success( fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=[seed], queue=False ).then( fn=clear_do_inversion_user, inputs=[do_inversion_user], outputs=[do_inversion_user] ).then( fn=edit, inputs=[input_audio, model_id, do_inversion, wts, zs, extra_info, saved_inv_model, src_prompt, tar_prompt, steps, cfg_scale_src, cfg_scale_tar, t_start, randomize_seed, save_compute, ], outputs=[output_audio, wts, zs, extra_info, saved_inv_model, do_inversion] ).success( fn=post_match_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion] ) # If sources changed we have to rerun inversion gr.on( triggers=[input_audio.change, src_prompt.change, model_id.change, cfg_scale_src.change, steps.change, save_compute.change], fn=reset_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion] ) gr.Examples( label="Examples", examples=get_example(), inputs=[input_audio, src_prompt, tar_prompt, t_start, model_id, length, output_audio], outputs=[output_audio] ) demo.queue() demo.launch(state_session_capacity=15)