# CleanUNet speech-denoising inference script: loads a trained checkpoint
# and enhances (denoises) audio from a clean/noisy paired dataset.
# --- Standard library ---
import os
import argparse
import json
import random
from copy import deepcopy

# --- Third party ---
from tqdm import tqdm
import numpy as np
import torch

# Fix every RNG seed up front so inference is reproducible.
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

from scipy.io.wavfile import write as wavwrite

# --- Project local ---
from dataset import load_CleanNoisyPairDataset
from util import find_max_epoch, print_size, sampling
from network import CleanUNet
def denoise(output_directory, ckpt_iter, subset, dump=False):
    """
    Denoise audio with a pretrained CleanUNet checkpoint.

    Relies on the module-level globals ``train_config``, ``trainset_config``
    and ``network_config`` being populated before the call (done in
    ``__main__`` after the JSON config is parsed).

    Parameters:
    output_directory (str):     save generated speeches to this path
    ckpt_iter (int or 'max' or 'pretrained'):
                                the pretrained checkpoint to be loaded;
                                automatically selects the maximum iteration
                                if 'max' is selected
    subset (str):               training, testing, validation
    dump (bool):                whether to save enhanced (denoised) audio
                                to disk instead of collecting it in memory

    Returns:
    (all_clean_audio, all_generated_audio): lists of 1-D numpy arrays;
    both lists are empty when dump=True (audio is written to
    ``speech_directory`` instead).
    """
    # Setup local experiment path.
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # Load data. crop_length_sec=0 disables cropping so each full
    # utterance is denoised in one shot.
    loader_config = deepcopy(trainset_config)
    loader_config["crop_length_sec"] = 0
    dataloader = load_CleanNoisyPairDataset(
        **loader_config,
        subset=subset,
        batch_size=1,
        num_gpus=1
    )

    # Predefine model.
    net = CleanUNet(**network_config).cuda()
    print_size(net)

    # Load checkpoint. map_location='cpu' keeps loading device-agnostic;
    # the weights are already on the GPU via .cuda() above.
    ckpt_directory = os.path.join(train_config["log"]["directory"], exp_path, 'checkpoint')
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(ckpt_directory)
    if ckpt_iter != 'pretrained':
        ckpt_iter = int(ckpt_iter)
    model_path = os.path.join(ckpt_directory, '{}.pkl'.format(ckpt_iter))
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # Get output directory ready.
    if ckpt_iter == "pretrained":
        speech_directory = os.path.join(output_directory, exp_path, 'speech', ckpt_iter)
    else:
        speech_directory = os.path.join(output_directory, exp_path, 'speech', '{}k'.format(ckpt_iter // 1000))
    if dump and not os.path.isdir(speech_directory):
        os.makedirs(speech_directory)
        os.chmod(speech_directory, 0o775)
    print("speech_directory: ", speech_directory, flush=True)

    # Inference.
    all_generated_audio = []
    all_clean_audio = []
    # File ids look like '<prefix>_<rest...>'; keep everything after the
    # first underscore of the basename as the output filename.
    sortkey = lambda name: '_'.join(name.split('/')[-1].split('_')[1:])
    # Inference only: torch.no_grad() skips autograd bookkeeping, cutting
    # memory use and speeding up the forward pass. (Original code also
    # computed an unused LENGTH local here; removed.)
    with torch.no_grad():
        for clean_audio, noisy_audio, fileid in tqdm(dataloader):
            filename = sortkey(fileid[0][0])
            noisy_audio = noisy_audio.cuda()
            generated_audio = sampling(net, noisy_audio)
            if dump:
                wavwrite(os.path.join(speech_directory, 'enhanced_{}'.format(filename)),
                         trainset_config["sample_rate"],
                         generated_audio[0].squeeze().cpu().numpy())
            else:
                all_clean_audio.append(clean_audio[0].squeeze().cpu().numpy())
                all_generated_audio.append(generated_audio[0].squeeze().cpu().numpy())

    return all_clean_audio, all_generated_audio
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='config.json',
                        help='JSON file for configuration')
    parser.add_argument('-ckpt_iter', '--ckpt_iter', default='max',
                        help='Which checkpoint to use; assign a number or "max" or "pretrained"')
    parser.add_argument('-subset', '--subset', type=str, choices=['training', 'testing', 'validation'],
                        default='testing', help='subset for denoising')
    args = parser.parse_args()

    # Parse configs. Module-level names are read as globals by denoise().
    # (Dropped the original bare `global` statements: at module scope they
    # are no-ops — every top-level assignment is already global.)
    with open(args.config) as f:
        config = json.load(f)
    gen_config = config["gen_config"]
    network_config = config["network_config"]    # to build the CleanUNet model
    train_config = config["train_config"]        # training/experiment config
    trainset_config = config["trainset_config"]  # to read trainset configurations

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # NOTE(review): only the "testing" subset triggers denoising; the other
    # argparse choices are accepted but currently do nothing.
    if args.subset == "testing":
        denoise(gen_config["output_directory"],
                subset=args.subset,
                ckpt_iter=args.ckpt_iter,
                dump=True)