L0SG's picture
update
aa76dec
raw
history blame
14.7 kB
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import os
import json
from pathlib import Path
from collections import namedtuple
from typing import Optional, List, Union, Dict
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
import activations
from utils import init_weights, get_padding
from alias_free_torch.act import Activation1d as TorchActivation1d
from env import AttrDict
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
def load_hparams_from_json(path) -> AttrDict:
with open(path) as f:
data = f.read()
h = json.loads(data)
return AttrDict(h)
class AMPBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
super(AMPBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
# select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
# faster CUDA kernel implementation of Activation1d
from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
Activation1d = CudaActivation1d
else:
Activation1d = TorchActivation1d
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
def forward(self, x):
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class AMPBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
super(AMPBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs.apply(init_weights)
self.num_layers = len(self.convs) # total number of conv layers
# select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
# faster CUDA kernel implementation of Activation1d
from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
Activation1d = CudaActivation1d
else:
Activation1d = TorchActivation1d
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
Activation1d(
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
])
else:
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
def forward(self, x):
for c, a in zip (self.convs, self.activations):
xt = a(x)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class BigVGAN(
torch.nn.Module,
PyTorchModelHubMixin,
library_name="bigvgan",
repo_url="https://github.com/NVIDIA/BigVGAN",
docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
pipeline_tag="audio-to-audio",
license="mit",
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"]
):
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
# New in v2: if use_cuda_kernel is set to True, it loads optimized CUDA kernels for AMP.
# NOTE: use_cuda_kernel=True should be used for inference only (training is not supported).
def __init__(
self,
h,
use_cuda_kernel: bool=False
):
super(BigVGAN, self).__init__()
self.h = h
self.h["use_cuda_kernel"] = use_cuda_kernel # add it to global hyperparameters (h)
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
# pre conv
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
# transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(nn.ModuleList([
weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
h.upsample_initial_channel // (2 ** (i + 1)),
k, u, padding=(k - u) // 2))
]))
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
# select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
# faster CUDA kernel implementation of Activation1d
from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
Activation1d = CudaActivation1d
else:
Activation1d = TorchActivation1d
# post conv
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
self.activation_post = Activation1d(activation=activation_post)
else:
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
# whether to use bias for the final conv_post. Defaults to True for backward compatibility
self.use_bias_at_final = h.get("use_bias_at_final", True)
self.conv_post = weight_norm(Conv1d(
ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final
))
# weight initialization
for i in range(len(self.ups)):
self.ups[i].apply(init_weights)
self.conv_post.apply(init_weights)
# final tanh activation. Defaults to True for backward compatibility
self.use_tanh_at_final = h.get("use_tanh_at_final", True)
def forward(self, x):
# pre conv
x = self.conv_pre(x)
for i in range(self.num_upsamples):
# upsampling
for i_up in range(len(self.ups[i])):
x = self.ups[i][i_up](x)
# AMP blocks
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
# post conv
x = self.activation_post(x)
x = self.conv_post(x)
# final tanh activation
if self.use_tanh_at_final:
x = torch.tanh(x)
else:
x = torch.clamp(x, min=-1., max=1.) # bound the output to [-1, 1]
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
for l_i in l:
remove_weight_norm(l_i)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
##################################################################
# additional methods for huggingface_hub support
##################################################################
def _save_pretrained(self, save_directory: Path) -> None:
"""Save weights and config.json from a Pytorch model to a local directory."""
model_path = save_directory / 'bigvgan_generator.pt'
torch.save(
{'generator': self.state_dict()},
model_path
)
config_path = save_directory / 'config.json'
with open(config_path, 'w') as config_file:
json.dump(self.h, config_file, indent=4)
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
revision: str,
cache_dir: str,
force_download: bool,
proxies: Optional[Dict],
resume_download: bool,
local_files_only: bool,
token: Union[str, bool, None],
map_location: str = "cpu", # additional argument
strict: bool = False, # additional argument
use_cuda_kernel: bool = False,
**model_kwargs,
):
"""Load Pytorch pretrained weights and return the loaded model."""
##################################################################
# download and load hyperparameters (h) used by BigVGAN
##################################################################
config_file = hf_hub_download(
repo_id=model_id,
filename='config.json',
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
h = load_hparams_from_json(config_file)
##################################################################
# instantiate BigVGAN using h
##################################################################
if use_cuda_kernel:
print(f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!")
print(f"[WARNING] You need nvcc and ninja installed in your system to build the kernel. For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis")
model = cls(h, use_cuda_kernel=use_cuda_kernel)
##################################################################
# download and load pretrained generator weight
##################################################################
if os.path.isdir(model_id):
print("Loading weights from local directory")
model_file = os.path.join(model_id, 'bigvgan_generator.pt')
else:
print(f"Downloading weights from {model_id}")
model_file = hf_hub_download(
repo_id=model_id,
filename='bigvgan_generator.pt',
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
checkpoint_dict = torch.load(model_file, map_location=map_location)
model.load_state_dict(checkpoint_dict['generator'])
return model