OpenSound committed on
Commit
575f48c
1 Parent(s): 9d3cb0a

Upload controlnet_app.py

Files changed (1)
  1. controlnet_app.py +190 -0
controlnet_app.py ADDED
@@ -0,0 +1,190 @@
import torch
import random
import numpy as np
import gradio as gr
import librosa
from accelerate import Accelerator
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import DDIMScheduler
from src.models.conditioners import MaskDiT
from src.models.controlnet import DiTControlNet
from src.models.conditions import Conditioner
from src.modules.autoencoder_wrapper import Autoencoder
from src.inference_controlnet import inference
from src.utils import load_yaml_with_includes


# Load models and configs
def load_models(config_name, ckpt_path, controlnet_path, vae_path, device):
    params = load_yaml_with_includes(config_name)

    # Load codec model
    autoencoder = Autoencoder(ckpt_path=vae_path,
                              model_type=params['autoencoder']['name'],
                              quantization_first=params['autoencoder']['q_first']).to(device)
    autoencoder.eval()

    # Load text encoder
    tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
    text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model']).to(device)
    text_encoder.eval()

    # Load main diffusion transformer (named `unet` to match the inference API)
    unet = MaskDiT(**params['model']).to(device)
    unet.load_state_dict(torch.load(ckpt_path, map_location='cpu')['model'])
    unet.eval()

    # Load ControlNet branch, built from the base model config plus the controlnet overrides
    controlnet_config = params['model'].copy()
    controlnet_config.update(params['controlnet'])
    controlnet = DiTControlNet(**controlnet_config).to(device)
    controlnet.eval()
    controlnet.load_state_dict(torch.load(controlnet_path, map_location='cpu')['model'])
    conditioner = Conditioner(**params['conditioner']).to(device)

    accelerator = Accelerator(mixed_precision="fp16")
    unet, controlnet = accelerator.prepare(unet, controlnet)

    # Load noise scheduler
    noise_scheduler = DDIMScheduler(**params['diff'])

    # Dummy add_noise call (appears intended to warm up / initialize the scheduler)
    latents = torch.randn((1, 128, 128), device=device)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device)
    _ = noise_scheduler.add_noise(latents, noise, timesteps)

    return autoencoder, unet, controlnet, conditioner, tokenizer, text_encoder, noise_scheduler, params


MAX_SEED = np.iinfo(np.int32).max

# Model and config paths
config_name = 'ckpts/controlnet/energy_l.yml'
ckpt_path = 'ckpts/s3/ezaudio_s3_l.pt'
controlnet_path = 'ckpts/controlnet/s3_l_energy.pt'
vae_path = 'ckpts/vae/1m.pt'
# save_path = 'output/'
# os.makedirs(save_path, exist_ok=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

(autoencoder, unet, controlnet, conditioner,
 tokenizer, text_encoder, noise_scheduler, params) = load_models(config_name, ckpt_path, controlnet_path, vae_path, device)


def generate_audio(text,
                   audio_path, surpass_noise,
                   guidance_scale, guidance_rescale,
                   ddim_steps, eta,
                   conditioning_scale,
                   random_seed, randomize_seed):
    sr = params['autoencoder']['sr']

    gt, _ = librosa.load(audio_path, sr=sr)
    gt = gt / (np.max(np.abs(gt)) + 1e-9)  # Normalize audio

    # Zero out samples whose amplitude falls below the noise threshold
    if surpass_noise > 0:
        mask = np.abs(gt) <= surpass_noise
        gt[mask] = 0

    original_length = len(gt)
    # Ensure the audio is of the correct length by padding or trimming
    duration_seconds = len(gt) / sr
    quantized_duration = np.ceil(duration_seconds * 2) / 2  # Round up to the next 0.5 seconds
    num_samples = int(quantized_duration * sr)
    audio_frames = round(num_samples / sr * params['autoencoder']['latent_sr'])

    if len(gt) < num_samples:
        padding = num_samples - len(gt)
        gt = np.pad(gt, (0, padding), 'constant')
    else:
        gt = gt[:num_samples]

    # Encode the reference audio and build the control condition
    gt_audio = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device)
    gt = autoencoder(audio=gt_audio)
    condition = conditioner(gt_audio.squeeze(1), gt.shape)

    # Handle random seed
    if randomize_seed:
        random_seed = random.randint(0, MAX_SEED)

    # Perform inference
    pred = inference(autoencoder, unet, controlnet,
                     None, None, condition,
                     tokenizer, text_encoder,
                     params, noise_scheduler,
                     text, neg_text=None,
                     audio_frames=audio_frames,
                     guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
                     ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
                     conditioning_scale=conditioning_scale, device=device)

    pred = pred.cpu().numpy().squeeze(0).squeeze(0)[:original_length]

    return sr, pred


# CSS styling (optional)
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
"""

# Gradio Blocks layout
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# EzAudio: High-quality Text-to-Audio Generator
Generate and edit audio from text using a diffusion transformer. Adjust advanced settings for more control.

[Learn more about 😈EzAudio](https://haidog-yaqub.github.io/EzAudio-Page/)
""")

    with gr.Row():
        # Input for the text prompt (used for generating new audio)
        text_input = gr.Textbox(
            label="Text Prompt",
            show_label=True,
            max_lines=2,
            placeholder="Describe the sound you want to generate",
            value="A dog barking in the background",
            scale=4
        )
        # Button to generate the audio
        generate_button = gr.Button("Generate")

    # Audio input to use as the reference
    audio_file_input = gr.Audio(label="Upload Reference Audio (less than 10s)", value='reference.mp3', type="filepath")

    # Output component for the generated audio
    generated_audio_output = gr.Audio(label="Generated Audio", type="numpy")

    with gr.Accordion("Advanced Settings", open=False):
        # Amplitude threshold below which the reference audio is zeroed out
        surpass_noise = gr.Slider(minimum=0, maximum=0.2, step=0.01, value=0.05, label="Noise Threshold (Amplitude)")
        guidance_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=5.0, label="Guidance Scale")
        guidance_rescale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Guidance Rescale")
        ddim_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps")
        eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="Eta")
        conditioning_scale = gr.Slider(minimum=0.0, maximum=2.0, step=0.25, value=1.0, label="Conditioning Scale")
        random_seed = gr.Slider(minimum=0, maximum=10000, step=1, value=0, label="Random Seed")
        randomize_seed = gr.Checkbox(label="Randomize Seed (Disable Seed)", value=True)

    # Link the inputs to the function
    generate_button.click(
        fn=generate_audio,
        inputs=[
            text_input, audio_file_input, surpass_noise, guidance_scale, guidance_rescale,
            ddim_steps, eta, conditioning_scale, random_seed, randomize_seed
        ],
        outputs=[generated_audio_output]
    )

    text_input.submit(
        fn=generate_audio,
        inputs=[
            text_input, audio_file_input, surpass_noise, guidance_scale, guidance_rescale,
            ddim_steps, eta, conditioning_scale, random_seed, randomize_seed
        ],
        outputs=[generated_audio_output]
    )

# Launch the Gradio demo
demo.launch(share=True)
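
For reference, a minimal sketch of driving the model without the Gradio UI by calling `generate_audio` directly. This is not part of the uploaded file: the `soundfile` dependency, the local `reference.mp3` clip, and the output path are assumptions; it presumes the checkpoints listed above are already in place.

# Hypothetical usage sketch (not part of controlnet_app.py): call generate_audio directly.
import soundfile as sf  # assumed extra dependency for writing the result to disk

sr, audio = generate_audio(
    text="A dog barking in the background",
    audio_path="reference.mp3",          # assumed local reference clip supplying the control signal
    surpass_noise=0.05,                  # zero out reference amplitudes below this threshold
    guidance_scale=5.0, guidance_rescale=0.5,
    ddim_steps=50, eta=1.0,
    conditioning_scale=1.0,
    random_seed=0, randomize_seed=True,  # with randomize_seed=True the fixed seed is ignored
)
sf.write("generated.wav", audio, sr)     # audio is a 1-D float array at sample rate sr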