# nvidia_denoiser/denoise.py
import os
import argparse
import json
from tqdm import tqdm
from copy import deepcopy
import numpy as np
import torch
import random
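# fix all random seeds so that enhancement runs are reproducible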
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)
from scipy.io.wavfile import write as wavwrite
from dataset import load_CleanNoisyPairDataset
from util import find_max_epoch, print_size, sampling
from network import CleanUNet


def denoise(output_directory, ckpt_iter, subset, dump=False):
"""
Denoise audio
Parameters:
output_directory (str): save generated speeches to this path
ckpt_iter (int or 'max'): the pretrained checkpoint to be loaded;
automitically selects the maximum iteration if 'max' is selected
subset (str): training, testing, validation
dump (bool): whether save enhanced (denoised) audio
"""
# setup local experiment path
exp_path = train_config["exp_path"]
print('exp_path:', exp_path)
# load data
loader_config = deepcopy(trainset_config)
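    # crop_length_sec = 0 is expected to disable random cropping so full-length utterances are denoised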
loader_config["crop_length_sec"] = 0
dataloader = load_CleanNoisyPairDataset(
**loader_config,
subset=subset,
batch_size=1,
num_gpus=1
)
# predefine model
net = CleanUNet(**network_config).cuda()
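    # print the model size (helper from util.py)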
print_size(net)
# load checkpoint
ckpt_directory = os.path.join(train_config["log"]["directory"], exp_path, 'checkpoint')
if ckpt_iter == 'max':
ckpt_iter = find_max_epoch(ckpt_directory)
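    # numeric iterations resolve to '<iter>.pkl'; the special value 'pretrained' loads 'pretrained.pkl'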
if ckpt_iter != 'pretrained':
ckpt_iter = int(ckpt_iter)
model_path = os.path.join(ckpt_directory, '{}.pkl'.format(ckpt_iter))
checkpoint = torch.load(model_path, map_location='cpu')
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()
# get output directory ready
if ckpt_iter == "pretrained":
speech_directory = os.path.join(output_directory, exp_path, 'speech', ckpt_iter)
else:
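        # name the output folder by iteration count in thousands, e.g. checkpoint 500000 -> 'speech/500k'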
speech_directory = os.path.join(output_directory, exp_path, 'speech', '{}k'.format(ckpt_iter//1000))
if dump and not os.path.isdir(speech_directory):
os.makedirs(speech_directory)
os.chmod(speech_directory, 0o775)
print("speech_directory: ", speech_directory, flush=True)
# inference
all_generated_audio = []
all_clean_audio = []
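    # basename with its first '_'-separated token dropped; used to name the enhanced output file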
sortkey = lambda name: '_'.join(name.split('/')[-1].split('_')[1:])
for clean_audio, noisy_audio, fileid in tqdm(dataloader):
filename = sortkey(fileid[0][0])
noisy_audio = noisy_audio.cuda()
LENGTH = len(noisy_audio[0].squeeze())
generated_audio = sampling(net, noisy_audio)
if dump:
wavwrite(os.path.join(speech_directory, 'enhanced_{}'.format(filename)),
trainset_config["sample_rate"],
generated_audio[0].squeeze().cpu().numpy())
else:
all_clean_audio.append(clean_audio[0].squeeze().cpu().numpy())
all_generated_audio.append(generated_audio[0].squeeze().cpu().numpy())
    return all_clean_audio, all_generated_audio


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='config.json',
help='JSON file for configuration')
parser.add_argument('-ckpt_iter', '--ckpt_iter', default='max',
help='Which checkpoint to use; assign a number or "max" or "pretrained"')
parser.add_argument('-subset', '--subset', type=str, choices=['training', 'testing', 'validation'],
default='testing', help='subset for denoising')
args = parser.parse_args()
    # Parse configs; module-level globals are convenient in this case
with open(args.config) as f:
data = f.read()
config = json.loads(data)
gen_config = config["gen_config"]
global network_config
    network_config = config["network_config"]     # to define the CleanUNet model
global train_config
train_config = config["train_config"] # train config
global trainset_config
trainset_config = config["trainset_config"] # to read trainset configurations
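    # enable cuDNN and let it benchmark/select the fastest convolution kernels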
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
if args.subset == "testing":
denoise(gen_config["output_directory"],
subset=args.subset,
ckpt_iter=args.ckpt_iter,
dump=True)
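
# Example invocation (assuming a CleanUNet-style config.json with gen_config,
# network_config, train_config, and trainset_config sections):
#   python denoise.py -c config.json --ckpt_iter pretrained --subset testing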