"""Command-line inference script: split an audio file into instrument and
vocal tracks with a pretrained ``nets.CascadedNet`` mask model."""
import argparse
import os

import librosa
import numpy as np
import soundfile as sf
import torch
from tqdm import tqdm

from lib import dataset
from lib import nets
from lib import spec_utils
from lib import utils


class Separator(object):
    """Apply a mask-predicting model to a complex spectrogram.

    The wrapped model must expose an ``offset`` attribute and a
    ``predict_mask(batch)`` method; both are used below.
    """

    def __init__(self, model, device, batchsize, cropsize, postprocess=False):
        """
        Args:
            model: network providing ``offset`` and ``predict_mask``.
            device: torch device that input batches are moved to.
            batchsize: number of windows fed to the model per call.
            cropsize: width, in frames, of each window given to the model.
            postprocess: if True, pass the mask through
                ``spec_utils.merge_artifacts`` before applying it.
        """
        self.model = model
        self.offset = model.offset
        self.device = device
        self.batchsize = batchsize
        self.cropsize = cropsize
        self.postprocess = postprocess

    def _pad_and_normalize(self, X_mag, pad_l, pad_r):
        """Zero-pad ``X_mag`` along axis 2 and scale it so its peak is 1.

        A completely silent input (peak of 0) is returned unscaled instead
        of being divided by zero, which would fill it with NaNs and make the
        written audio NaN as well.
        """
        X_mag_pad = np.pad(
            X_mag, ((0, 0), (0, 0), (pad_l, pad_r)), mode='constant')
        peak = X_mag_pad.max()
        if peak > 0:
            X_mag_pad /= peak
        return X_mag_pad

    def _separate(self, X_mag_pad, roi_size):
        """Predict a mask for the whole padded magnitude, window by window.

        Windows ``cropsize`` frames wide are taken every ``roi_size`` frames
        along axis 2, run through the model ``batchsize`` at a time, and the
        per-window predictions are concatenated back along axis 2.
        """
        patches = (X_mag_pad.shape[2] - 2 * self.offset) // roi_size
        X_dataset = np.asarray([
            X_mag_pad[:, :, i * roi_size:i * roi_size + self.cropsize]
            for i in range(patches)
        ])

        self.model.eval()
        with torch.no_grad():
            mask = []
            # To reduce the overhead, dataloader is not used.
            for i in tqdm(range(0, patches, self.batchsize)):
                X_batch = X_dataset[i: i + self.batchsize]
                X_batch = torch.from_numpy(X_batch).to(self.device)

                pred = self.model.predict_mask(X_batch)

                pred = pred.detach().cpu().numpy()
                # Join the windows of this batch end-to-end along axis 2.
                pred = np.concatenate(pred, axis=2)
                mask.append(pred)

            mask = np.concatenate(mask, axis=2)

        return mask

    def _preprocess(self, X_spec):
        """Split a complex spectrogram into (magnitude, phase)."""
        X_mag = np.abs(X_spec)
        X_phase = np.angle(X_spec)

        return X_mag, X_phase

    def _postprocess(self, mask, X_mag, X_phase):
        """Apply ``mask`` and its complement to the original spectrogram.

        Returns:
            (y_spec, v_spec): complex spectrograms built from ``mask`` and
            ``1 - mask`` respectively, each recombined with the original
            phase. ``main`` writes ``y_spec`` as instruments and ``v_spec``
            as vocals.
        """
        if self.postprocess:
            mask = spec_utils.merge_artifacts(mask)

        y_spec = mask * X_mag * np.exp(1j * X_phase)
        v_spec = (1 - mask) * X_mag * np.exp(1j * X_phase)

        return y_spec, v_spec

    def separate(self, X_spec):
        """Separate ``X_spec`` with a single pass of the model.

        Axis 2 of the spectrogram is treated as the frame axis.
        """
        X_mag, X_phase = self._preprocess(X_spec)

        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(
            n_frame, self.cropsize, self.offset)
        X_mag_pad = self._pad_and_normalize(X_mag, pad_l, pad_r)

        mask = self._separate(X_mag_pad, roi_size)
        # Drop the mask frames that belong to the right-hand padding.
        mask = mask[:, :, :n_frame]

        y_spec, v_spec = self._postprocess(mask, X_mag, X_phase)

        return y_spec, v_spec

    def separate_tta(self, X_spec):
        """Separate ``X_spec`` with test-time augmentation.

        A second pass is run with the window grid shifted by half a
        ``roi_size``, and the two masks are averaged.
        """
        X_mag, X_phase = self._preprocess(X_spec)

        n_frame = X_mag.shape[2]
        pad_l, pad_r, roi_size = dataset.make_padding(
            n_frame, self.cropsize, self.offset)
        X_mag_pad = self._pad_and_normalize(X_mag, pad_l, pad_r)

        mask = self._separate(X_mag_pad, roi_size)

        # Shift the window grid by half a region and run again.
        pad_l += roi_size // 2
        pad_r += roi_size // 2
        X_mag_pad = self._pad_and_normalize(X_mag, pad_l, pad_r)

        mask_tta = self._separate(X_mag_pad, roi_size)
        # Undo the extra left padding so both masks are frame-aligned.
        mask_tta = mask_tta[:, :, roi_size // 2:]
        mask = (mask[:, :, :n_frame] + mask_tta[:, :, :n_frame]) * 0.5

        y_spec, v_spec = self._postprocess(mask, X_mag, X_phase)

        return y_spec, v_spec


def main():
    """Parse CLI arguments, separate the input file and write the results."""
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--pretrained_model', '-P', type=str, default='models/baseline.pth')
    p.add_argument('--input', '-i', required=True)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--n_fft', '-f', type=int, default=2048)
    p.add_argument('--hop_length', '-H', type=int, default=1024)
    p.add_argument('--batchsize', '-B', type=int, default=4)
    p.add_argument('--cropsize', '-c', type=int, default=256)
    p.add_argument('--output_image', '-I', action='store_true')
    p.add_argument('--postprocess', '-p', action='store_true')
    p.add_argument('--tta', '-t', action='store_true')
    p.add_argument('--output_dir', '-o', type=str, default="")
    args = p.parse_args()

    print('loading model...', end=' ')
    device = torch.device('cpu')
    model = nets.CascadedNet(args.n_fft, 32, 128)
    # Weights are always loaded onto the CPU first, then moved if a GPU
    # was requested and is available.
    model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device('cuda:{}'.format(args.gpu))
        model.to(device)
    print('done')

    print('loading wave source...', end=' ')
    # sr/mono must be passed by keyword: librosa >= 0.10 made every
    # argument after the path keyword-only.
    X, sr = librosa.load(
        args.input, sr=args.sr, mono=False, dtype=np.float32,
        res_type='kaiser_fast')
    basename = os.path.splitext(os.path.basename(args.input))[0]
    print('done')

    if X.ndim == 1:
        # mono to stereo
        X = np.asarray([X, X])

    print('stft of wave source...', end=' ')
    X_spec = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft)
    print('done')

    sp = Separator(model, device, args.batchsize, args.cropsize, args.postprocess)

    if args.tta:
        y_spec, v_spec = sp.separate_tta(X_spec)
    else:
        y_spec, v_spec = sp.separate(X_spec)

    print('validating output directory...', end=' ')
    output_dir = args.output_dir
    if output_dir != "":
        # Normalize to exactly one trailing '/' so it can prefix file names.
        output_dir = output_dir.rstrip('/') + '/'
        os.makedirs(output_dir, exist_ok=True)
    print('done')

    print('inverse stft of instruments...', end=' ')
    wave = spec_utils.spectrogram_to_wave(y_spec, hop_length=args.hop_length)
    print('done')
    # NOTE(review): unlike the vocals and image outputs, the instruments
    # track is written without the input basename (the basename variant was
    # left commented out previously). Kept as-is because callers may depend
    # on the fixed name -- confirm whether this is intentional.
    sf.write('{}Instruments.wav'.format(output_dir), wave.T, sr)

    print('inverse stft of vocals...', end=' ')
    wave = spec_utils.spectrogram_to_wave(v_spec, hop_length=args.hop_length)
    print('done')
    sf.write('{}{}_Vocals.wav'.format(output_dir, basename), wave.T, sr)

    if args.output_image:
        image = spec_utils.spectrogram_to_image(y_spec)
        utils.imwrite('{}{}_Instruments.jpg'.format(output_dir, basename), image)

        image = spec_utils.spectrogram_to_image(v_spec)
        utils.imwrite('{}{}_Vocals.jpg'.format(output_dir, basename), image)


if __name__ == '__main__':
    main()