import torch
import random
import numpy as np
import gradio as gr
import librosa
import spaces
from accelerate import Accelerator
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import DDIMScheduler
from src.models.conditioners import MaskDiT
from src.models.controlnet import DiTControlNet
from src.models.conditions import Conditioner
from src.modules.autoencoder_wrapper import Autoencoder
from src.inference_controlnet import inference
from src.utils import load_yaml_with_includes


# Load model and configs
def load_models(config_name, ckpt_path, controlnet_path, vae_path, device):
    """Build and load every model needed for ControlNet-guided generation.

    Args:
        config_name: Path to the YAML config (loaded with ``load_yaml_with_includes``).
        ckpt_path: Checkpoint for the MaskDiT backbone; its ``['model']`` entry
            is loaded as the state dict.
        controlnet_path: Checkpoint for the DiT ControlNet; its ``['model']``
            entry is loaded as the state dict.
        vae_path: Checkpoint path handed to the ``Autoencoder`` wrapper.
        device: Torch device string the models are moved to (e.g. ``'cuda'``).

    Returns:
        Tuple ``(autoencoder, unet, controlnet, conditioner, tokenizer,
        text_encoder, noise_scheduler, params)`` where ``params`` is the
        parsed config dict.
    """
    params = load_yaml_with_includes(config_name)

    # Load codec model
    autoencoder = Autoencoder(
        ckpt_path=vae_path,
        model_type=params['autoencoder']['name'],
        quantization_first=params['autoencoder']['q_first'],
    ).to(device)
    autoencoder.eval()

    # Load text encoder
    tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
    text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model']).to(device)
    text_encoder.eval()

    # Load the main diffusion backbone (MaskDiT).
    # NOTE: torch.load unpickles; these checkpoint paths must point at trusted local files.
    unet = MaskDiT(**params['model']).to(device)
    unet.load_state_dict(torch.load(ckpt_path, map_location='cpu')['model'])
    unet.eval()

    # The ControlNet reuses the backbone config, overridden by the 'controlnet' section.
    controlnet_config = params['model'].copy()
    controlnet_config.update(params['controlnet'])
    controlnet = DiTControlNet(**controlnet_config).to(device)
    controlnet.eval()
    controlnet.load_state_dict(torch.load(controlnet_path, map_location='cpu')['model'])

    conditioner = Conditioner(**params['conditioner']).to(device)

    # accelerate rejects fp16 mixed precision when no GPU is present, so only
    # request it on CUDA; CPU-only machines fall back to full precision instead
    # of crashing at startup.
    mixed_precision = "fp16" if torch.device(device).type == "cuda" else "no"
    accelerator = Accelerator(mixed_precision=mixed_precision)
    unet, controlnet = accelerator.prepare(unet, controlnet)

    # Load noise scheduler
    noise_scheduler = DDIMScheduler(**params['diff'])

    # One throwaway add_noise call on random tensors (result discarded).
    # NOTE(review): presumably a warm-up; kept as in the original.
    latents = torch.randn((1, 128, 128), device=device)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device)
    _ = noise_scheduler.add_noise(latents, noise, timesteps)

    return (autoencoder, unet, controlnet, conditioner, tokenizer,
            text_encoder, noise_scheduler, params)


MAX_SEED = np.iinfo(np.int32).max
# Reference clips longer than this (in seconds) are truncated before conditioning.
MAX_DURATION_S = 10

# Model and config paths
config_name = 'ckpts/controlnet/energy_l.yml'
ckpt_path = 'ckpts/s3/ezaudio_s3_l.pt'
controlnet_path = 'ckpts/controlnet/s3_l_energy.pt'
vae_path = 'ckpts/vae/1m.pt'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
(autoencoder, unet, controlnet, conditioner,
 tokenizer, text_encoder, noise_scheduler, params) = load_models(
    config_name, ckpt_path, controlnet_path, vae_path, device)


def _prepare_reference(gt, sr, latent_sr, surpass_noise, max_duration=MAX_DURATION_S):
    """Peak-normalize, noise-gate and pad/trim a non-empty mono reference waveform.

    Args:
        gt: 1-D numpy waveform sampled at ``sr``.
        sr: Sample rate of ``gt`` in Hz.
        latent_sr: Latent frames per second of the autoencoder.
        surpass_noise: Samples whose normalized amplitude is ``<=`` this value
            are zeroed; disabled when ``<= 0``.
        max_duration: Longest duration in seconds kept from the reference.

    Returns:
        ``(gt, original_length, audio_frames)``. ``gt`` is the processed
        waveform, padded or trimmed to the duration rounded up to the next
        0.5 s and capped at ``max_duration``. ``original_length`` is the input
        sample count. ``audio_frames`` is the matching number of latent frames.
    """
    gt = gt / (np.max(np.abs(gt)) + 1e-9)  # peak-normalize; epsilon guards silence
    if surpass_noise > 0:
        gt[np.abs(gt) <= surpass_noise] = 0

    original_length = len(gt)

    duration_seconds = min(len(gt) / sr, max_duration)
    # np.ceil rounds UP to the next 0.5 s step (not to the nearest one).
    quantized_duration = np.ceil(duration_seconds * 2) / 2
    num_samples = int(quantized_duration * sr)
    audio_frames = round(num_samples / sr * latent_sr)

    if len(gt) < num_samples:
        gt = np.pad(gt, (0, num_samples - len(gt)), 'constant')
    else:
        gt = gt[:num_samples]
    return gt, original_length, audio_frames


@spaces.GPU
def generate_audio(text, audio_path, surpass_noise, guidance_scale, guidance_rescale,
                   ddim_steps, eta, conditioning_scale, random_seed, randomize_seed):
    """Generate audio for ``text`` whose timing follows a reference clip.

    Args:
        text: Text prompt describing the sound.
        audio_path: File path of the uploaded reference audio (Gradio ``filepath``).
        surpass_noise: Amplitude threshold below which reference samples are zeroed.
        guidance_scale, guidance_rescale, ddim_steps, eta, conditioning_scale:
            Sampling settings forwarded to ``inference``.
        random_seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.

    Returns:
        ``(sr, waveform)`` suitable for a ``gr.Audio(type="numpy")`` output.
        The waveform is cut to the reference's original sample count.

    Raises:
        gr.Error: If no reference audio was provided or it contains no samples.
    """
    if audio_path is None:
        raise gr.Error("Please upload a reference audio clip.")

    sr = params['autoencoder']['sr']
    gt, _ = librosa.load(audio_path, sr=sr)
    if gt.size == 0:
        raise gr.Error("The reference audio contains no samples.")

    gt, original_length, audio_frames = _prepare_reference(
        gt, sr, params['autoencoder']['latent_sr'], surpass_noise)

    # Pure inference: make sure no autograd graph is built (harmless if the
    # helpers already disable gradients themselves).
    with torch.no_grad():
        gt_audio = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device)
        gt_latent = autoencoder(audio=gt_audio)
        condition = conditioner(gt_audio.squeeze(1), gt_latent.shape)

        # Handle random seed
        if randomize_seed:
            random_seed = random.randint(0, MAX_SEED)

        # Perform inference
        pred = inference(autoencoder, unet, controlnet, None, None, condition,
                         tokenizer, text_encoder, params, noise_scheduler,
                         text, neg_text=None,
                         audio_frames=audio_frames,
                         guidance_scale=guidance_scale,
                         guidance_rescale=guidance_rescale,
                         # Gradio sliders may deliver floats; step counts and seeds must be ints.
                         ddim_steps=int(ddim_steps), eta=eta,
                         random_seed=int(random_seed),
                         conditioning_scale=conditioning_scale,
                         device=device)

    pred = pred.cpu().numpy().squeeze(0).squeeze(0)[:original_length]
    return sr, pred


# CSS styling (optional)
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
"""

examples_energy = [
    ["Dog barking in the background", "reference.mp3"],
    ["Duck quacking", "reference2.mp3"],
    ["Truck honking on the street", "reference3.mp3"],
]

# Kept at column 0: indented Markdown lines would render as a code block.
DESCRIPTION = """
# EzAudio-ControlNet: Interactive and Creative Control for Text-to-Audio Generation
EzAudio-ControlNet enables control over the timing of sound effects within audio generation.
Learn more about 😈**EzAudio** on the [EzAudio Homepage](https://haidog-yaqub.github.io/EzAudio-Page/).
Explore **Vanilla Text-to-Audio**, **Editing**, and **Inpainting** features on the [🤗EzAudio Space](https://huggingface.co/spaces/OpenSound/EzAudio).
"""

# Gradio Blocks layout
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Input for the text prompt (used for generating new audio)
        text_input = gr.Textbox(
            label="Text Prompt",
            show_label=True,
            max_lines=2,
            placeholder="Describe the sound you want to generate",
            value="A dog barking in the background",
            scale=4,
        )
        # Button to generate the audio
        generate_button = gr.Button("Generate")

    # Reference audio whose timing/energy the generation follows
    audio_file_input = gr.Audio(label="Upload Reference Audio (less than 10s)",
                                value='reference.mp3', type="filepath")

    # Output Component for the generated audio
    generated_audio_output = gr.Audio(label="Generated Audio", type="numpy")

    with gr.Accordion("Advanced Settings", open=False):
        # Amplitude threshold below which reference samples are zeroed (0 disables it)
        surpass_noise = gr.Slider(minimum=0, maximum=0.1, step=0.01, value=0.0,
                                  label="Noise Threshold (Amplitude)")
        guidance_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=5.0,
                                   label="Guidance Scale")
        guidance_rescale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5,
                                     label="Guidance Rescale")
        ddim_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps")
        eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="Eta")
        conditioning_scale = gr.Slider(minimum=0.0, maximum=2.0, step=0.25, value=1.0,
                                       label="Conditioning Scale")
        random_seed = gr.Slider(minimum=0, maximum=10000, step=1, value=0, label="Random Seed")
        randomize_seed = gr.Checkbox(label="Randomize Seed (Disable Seed)", value=True)

    gr.Examples(
        examples=examples_energy,
        inputs=[text_input, audio_file_input],
    )

    # Both the button and pressing Enter in the prompt trigger the same generation.
    generation_inputs = [
        text_input, audio_file_input, surpass_noise, guidance_scale,
        guidance_rescale, ddim_steps, eta, conditioning_scale,
        random_seed, randomize_seed,
    ]
    generate_button.click(
        fn=generate_audio,
        inputs=generation_inputs,
        outputs=[generated_audio_output],
    )
    text_input.submit(
        fn=generate_audio,
        inputs=generation_inputs,
        outputs=[generated_audio_output],
    )

# Launch the Gradio demo
demo.launch()