# File size: 8,167 Bytes
# 96e64e9
# (line-number gutter 1-222 from the file viewer omitted)
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
import torch
import json
from env import AttrDict
from bigvgan import BigVGAN
from time import time
from tqdm import tqdm
from meldataset import mel_spectrogram, MAX_WAV_VALUE
from scipy.io.wavfile import write
import numpy as np
import argparse
torch.backends.cudnn.benchmark = True
# For easier debugging
torch.set_printoptions(linewidth=200, threshold=10_000)
def generate_soundwave(duration=5.0, sr=24000):
    """Synthesize a frequency-modulated sine tone used as demo input audio.

    The frequency term swings between 220 and 1760, driven by a single sine
    cycle that spans the whole clip. The waveform is then rescaled so that its
    largest absolute sample equals 0.95.

    Args:
        duration: Clip length in seconds.
        sr: Sampling rate in Hz.

    Returns:
        A tuple ``(soundwave, sr)``: the float32 waveform holding
        ``int(sr * duration)`` samples, and the sampling rate that was passed in.
    """
    low_hz, high_hz = 220, 1760
    num_points = int(sr * duration)

    # endpoint=False: the time axis stops one step short of `duration`.
    timeline = np.linspace(0, duration, num_points, endpoint=False, dtype=np.float32)

    # One slow sine cycle over the clip, with values in [-1, 1].
    sweep = np.sin(2 * np.pi * timeline / duration)

    # Map the sweep from [-1, 1] onto [low_hz, high_hz].
    freq_track = low_hz + (high_hz - low_hz) * (sweep + 1) / 2

    tone = np.sin(2 * np.pi * freq_track * timeline)

    # Peak-normalize, leaving a little headroom below full scale.
    peak = np.max(np.abs(tone))
    return tone / peak * 0.95, sr
def get_mel(x, h):
    """Compute the mel spectrogram of ``x`` using the settings stored on ``h``.

    Args:
        x: Audio input forwarded unchanged as the first argument of
            ``mel_spectrogram``.
        h: Configuration object exposing ``n_fft``, ``num_mels``,
            ``sampling_rate``, ``hop_size``, ``win_size``, ``fmin`` and ``fmax``.

    Returns:
        Whatever ``mel_spectrogram`` returns for these arguments.
    """
    # Collected in the positional order that mel_spectrogram is called with.
    spectrogram_args = (
        h.n_fft,
        h.num_mels,
        h.sampling_rate,
        h.hop_size,
        h.win_size,
        h.fmin,
        h.fmax,
    )
    return mel_spectrogram(x, *spectrogram_args)
def load_checkpoint(filepath, device):
    """Load a checkpoint from disk with ``torch.load``.

    Args:
        filepath: Path to the checkpoint file.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load``.

    Raises:
        FileNotFoundError: If ``filepath`` is not an existing regular file.
    """
    # Raise explicitly instead of using `assert`: assertions are stripped when
    # Python runs with -O, which would silently skip this validation.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint file not found: '{filepath}'")
    print(f"Loading '{filepath}'")
    # NOTE(review): torch.load unpickles the file contents; only load
    # checkpoints that come from a trusted source.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
def _benchmark_generator(generator, num_mels, num_mel_frame, num_sample):
    """Time ``generator`` on random mel inputs and record peak VRAM per run.

    Args:
        generator: Callable invoked on a ``(1, num_mels, num_mel_frame)`` CUDA
            tensor of uniform random values.
        num_mels: Number of mel channels of the random input.
        num_mel_frame: Number of mel frames of the random input.
        num_sample: Number of timed runs.

    Returns:
        A tuple ``(elapsed_total, peak_bytes_total, audio_length_total)``:
        summed wall-clock seconds, summed per-run peak allocated bytes, and
        summed size of the last output dimension over all runs.
    """
    elapsed_total = 0
    peak_bytes_total = 0
    audio_length_total = 0
    for _ in tqdm(range(num_sample)):
        # Reset before allocating the input so the peak covers input + forward pass.
        torch.cuda.reset_peak_memory_stats(device="cuda")
        data = torch.rand((1, num_mels, num_mel_frame), device="cuda")
        # Synchronize around the forward pass so the timer sees finished GPU work.
        torch.cuda.synchronize()
        tic = time()
        with torch.inference_mode():
            audio = generator(data)
        torch.cuda.synchronize()
        elapsed_total += time() - tic
        audio_length_total += audio.shape[-1]
        peak_bytes_total += torch.cuda.max_memory_allocated(device="cuda")
        # Free the tensors and cached blocks so runs do not influence each other.
        del data, audio
        torch.cuda.empty_cache()
    return elapsed_total, peak_bytes_total, audio_length_total


def _to_int16_pcm(audio):
    """Scale a waveform tensor by MAX_WAV_VALUE and return it as an int16 NumPy array."""
    scaled = audio.squeeze() * MAX_WAV_VALUE
    # Clamp before the cast so samples outside the int16 range saturate
    # instead of wrapping around.
    return scaled.clamp(min=-32768, max=32767).cpu().numpy().astype("int16")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Test script to check CUDA kernel correctness."
    )
    parser.add_argument(
        "--checkpoint_file",
        type=str,
        required=True,
        help="Path to the checkpoint file. Assumes config.json exists in the directory.",
    )
    args = parser.parse_args()

    # config.json is looked up next to the checkpoint file.
    config_file = os.path.join(os.path.split(args.checkpoint_file)[0], "config.json")
    # Decode explicitly as UTF-8 rather than the locale's default encoding.
    with open(config_file, encoding="utf-8") as f:
        json_config = json.load(f)
    h = AttrDict({**json_config})

    print("loading plain Pytorch BigVGAN")
    generator_original = BigVGAN(h).to("cuda")
    print("loading CUDA kernel BigVGAN with auto-build")
    generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda")

    # Both generators receive the same weights so their outputs are comparable.
    state_dict_g = load_checkpoint(args.checkpoint_file, "cuda")
    generator_original.load_state_dict(state_dict_g["generator"])
    generator_cuda_kernel.load_state_dict(state_dict_g["generator"])

    generator_original.remove_weight_norm()
    generator_original.eval()
    generator_cuda_kernel.remove_weight_norm()
    generator_cuda_kernel.eval()

    # define number of samples and length of mel frame to benchmark
    num_sample = 10
    num_mel_frame = 16384

    # CUDA kernel correctness check
    diff = 0.0
    for _ in tqdm(range(num_sample)):
        # Random mel
        data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")

        with torch.inference_mode():
            audio_original = generator_original(data)

        with torch.inference_mode():
            audio_cuda_kernel = generator_cuda_kernel(data)

        # Both outputs should be (almost) the same
        test_result = (audio_original - audio_cuda_kernel).abs()
        diff += test_result.mean(dim=-1).item()

    # Average of the per-run mean absolute differences.
    diff /= num_sample
    if (
        diff <= 2e-3
    ):  # We can expect a small difference (~1e-3) which does not affect perceptual quality
        print(
            f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}"
            f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test CUDA fused vs. plain torch BigVGAN inference"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, "
            f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
        )

    # Drop every tensor from the correctness check (including the difference
    # tensor) so none of them counts towards the VRAM figures measured below.
    del data, audio_original, audio_cuda_kernel, test_result

    # Measure Original inference in isolation
    toc_total_original, vram_used_original_total, _ = _benchmark_generator(
        generator_original, h.num_mels, num_mel_frame, num_sample
    )

    # Measure CUDA kernel inference in isolation
    (
        toc_total_cuda_kernel,
        vram_used_cuda_kernel_total,
        audio_length_total,
    ) = _benchmark_generator(
        generator_cuda_kernel, h.num_mels, num_mel_frame, num_sample
    )

    # Calculate metrics
    audio_second = audio_length_total / h.sampling_rate
    khz_original = audio_length_total / toc_total_original / 1000
    khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
    vram_used_original_gb = vram_used_original_total / num_sample / (1024 ** 3)
    vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024 ** 3)

    # Print results
    print(
        f"Original BigVGAN: took {toc_total_original:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_original:.1f}kHz, {audio_second / toc_total_original:.1f} faster than realtime, VRAM used {vram_used_original_gb:.1f} GB"
    )
    print(
        f"CUDA kernel BigVGAN: took {toc_total_cuda_kernel:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_cuda_kernel:.1f}kHz, {audio_second / toc_total_cuda_kernel:.1f} faster than realtime, VRAM used {vram_used_cuda_kernel_gb:.1f} GB"
    )
    print(f"speedup of CUDA kernel: {khz_cuda_kernel / khz_original}")
    print(f"VRAM saving of CUDA kernel: {vram_used_original_gb / vram_used_cuda_kernel_gb}")

    # Use artificial sine waves for inference test
    audio_real, sr = generate_soundwave(duration=5.0, sr=h.sampling_rate)
    audio_real = torch.tensor(audio_real).to("cuda")
    # Compute mel spectrogram from the ground truth audio
    x = get_mel(audio_real.unsqueeze(0), h)

    with torch.inference_mode():
        y_g_hat_original = generator_original(x)
        y_g_hat_cuda_kernel = generator_cuda_kernel(x)

    # Convert the reference and both generated waveforms to 16-bit PCM.
    audio_real = _to_int16_pcm(audio_real)
    audio_original = _to_int16_pcm(y_g_hat_original)
    audio_cuda_kernel = _to_int16_pcm(y_g_hat_cuda_kernel)

    os.makedirs("tmp", exist_ok=True)
    output_file_real = os.path.join("tmp", "audio_real.wav")
    output_file_original = os.path.join("tmp", "audio_generated_original.wav")
    output_file_cuda_kernel = os.path.join("tmp", "audio_generated_cuda_kernel.wav")
    write(output_file_real, h.sampling_rate, audio_real)
    write(output_file_original, h.sampling_rate, audio_original)
    write(output_file_cuda_kernel, h.sampling_rate, audio_cuda_kernel)
    print("Example generated audios of original vs. fused CUDA kernel written to tmp!")
    print("Done")
# |