File size: 4,579 Bytes
33e3a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import argparse
import json
from tqdm import tqdm
from copy import deepcopy

import numpy as np
import torch

import random

# Seed the NumPy, PyTorch and Python `random` generators with 0 so that
# repeated runs of this script are reproducible.
np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

from scipy.io.wavfile import write as wavwrite

from dataset import load_CleanNoisyPairDataset
from util import find_max_epoch, print_size, sampling
from network import CleanUNet


def denoise(output_directory, ckpt_iter, subset, dump=False):
    """
    Denoise every utterance of a dataset subset with a CleanUNet checkpoint.

    Reads the module-level ``train_config``, ``trainset_config`` and
    ``network_config`` dicts, which the ``__main__`` section of this script
    populates from the JSON config file; calling this function before those
    globals exist raises ``NameError``.

    Parameters:

    output_directory (str):         root for enhanced speech; files are written to
                                    <output_directory>/<exp_path>/speech/<tag>
    ckpt_iter (int or str):         checkpoint to load: an iteration number,
                                    'max' to automatically select the maximum
                                    iteration, or 'pretrained' to load
                                    'pretrained.pkl'
    subset (str):                   training, testing, validation
    dump (bool):                    if True, save each enhanced (denoised) waveform
                                    as a WAV file; if False, keep the waveforms in
                                    memory and return them

    Returns:

    (list, list):                   (all_clean_audio, all_generated_audio), one
                                    NumPy array per utterance; both lists are
                                    empty when dump is True
    """

    # setup local experiment path
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # load data; crop_length_sec is overridden on a deep copy so the shared
    # trainset_config is left untouched (presumably 0 disables cropping so
    # whole utterances are denoised -- confirm in dataset.py)
    loader_config = deepcopy(trainset_config)
    loader_config["crop_length_sec"] = 0
    dataloader = load_CleanNoisyPairDataset(
        **loader_config,
        subset=subset,
        batch_size=1,
        num_gpus=1,
    )

    # predefine model
    net = CleanUNet(**network_config).cuda()
    print_size(net)

    # load checkpoint: weights are read onto the CPU, then copied into the CUDA model
    ckpt_directory = os.path.join(train_config["log"]["directory"], exp_path, 'checkpoint')
    if ckpt_iter == 'max':
        ckpt_iter = find_max_epoch(ckpt_directory)
    if ckpt_iter != 'pretrained':
        ckpt_iter = int(ckpt_iter)
    model_path = os.path.join(ckpt_directory, '{}.pkl'.format(ckpt_iter))
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # get output directory ready; numeric checkpoints are tagged in thousands
    # of iterations (e.g. 25000 -> '25k')
    if ckpt_iter == "pretrained":
        speech_directory = os.path.join(output_directory, exp_path, 'speech', ckpt_iter)
    else:
        speech_directory = os.path.join(
            output_directory, exp_path, 'speech', '{}k'.format(ckpt_iter // 1000))
    if dump and not os.path.isdir(speech_directory):
        # exist_ok: do not crash if the directory appears between the check and here
        os.makedirs(speech_directory, exist_ok=True)
        os.chmod(speech_directory, 0o775)
    print("speech_directory: ", speech_directory, flush=True)

    def output_name(path):
        # Keep only the basename and drop its first '_'-separated token,
        # e.g. 'a/prefix_rest_1.wav' -> 'rest_1.wav'.
        return '_'.join(path.split('/')[-1].split('_')[1:])

    # inference; no_grad is explicit so no autograd graph is built here,
    # regardless of how `sampling` is implemented
    all_generated_audio = []
    all_clean_audio = []
    with torch.no_grad():
        for clean_audio, noisy_audio, fileid in tqdm(dataloader):
            # batch_size=1, so [0] selects the only item in the batch;
            # presumably fileid[0][0] is that item's file path -- confirm in dataset.py
            filename = output_name(fileid[0][0])

            noisy_audio = noisy_audio.cuda()
            generated_audio = sampling(net, noisy_audio)

            if dump:
                wavwrite(os.path.join(speech_directory, 'enhanced_{}'.format(filename)),
                         trainset_config["sample_rate"],
                         generated_audio[0].squeeze().cpu().numpy())
            else:
                all_clean_audio.append(clean_audio[0].squeeze().cpu().numpy())
                all_generated_audio.append(generated_audio[0].squeeze().cpu().numpy())

    return all_clean_audio, all_generated_audio


if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='config.json',
                        help='JSON file for configuration')
    parser.add_argument('-ckpt_iter', '--ckpt_iter', default='max',
                        help='Which checkpoint to use; assign a number or "max" or "pretrained"')
    parser.add_argument('-subset', '--subset', type=str, choices=['training', 'testing', 'validation'],
                        default='testing', help='subset for denoising')
    args = parser.parse_args()

    # Parse configs. These names are assigned at module scope, so they are
    # module globals that denoise() reads directly (no `global` statement needed).
    with open(args.config) as f:
        config = json.load(f)
    gen_config              = config["gen_config"]
    network_config          = config["network_config"]      # to define the CleanUNet model
    train_config            = config["train_config"]        # train config
    trainset_config         = config["trainset_config"]     # to read trainset configurations

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    if args.subset == "testing":
        denoise(gen_config["output_directory"],
                subset=args.subset,
                ckpt_iter=args.ckpt_iter,
                dump=True)
    else:
        # Only the testing subset is denoised by this script; report that
        # instead of exiting silently after accepting the argument.
        print('Nothing to do: this script only denoises --subset testing '
              '(got --subset {}).'.format(args.subset), file=sys.stderr)