Eddycrack864 commited on
Commit
dfc1efe
1 Parent(s): becfb80

Upload 17 files

Browse files
demucs/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
demucs/__main__.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+ from dataclasses import dataclass, field
12
+ from fractions import Fraction
13
+
14
+ import torch as th
15
+ from torch import distributed, nn
16
+ from torch.nn.parallel.distributed import DistributedDataParallel
17
+
18
+ from .augment import FlipChannels, FlipSign, Remix, Shift
19
+ from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks
20
+ from .model import Demucs
21
+ from .parser import get_name, get_parser
22
+ from .raw import Rawset
23
+ from .tasnet import ConvTasNet
24
+ from .test import evaluate
25
+ from .train import train_model, validate_model
26
+ from .utils import human_seconds, load_model, save_model, sizeof_fmt
27
+
28
+
29
+ @dataclass
30
+ class SavedState:
31
+ metrics: list = field(default_factory=list)
32
+ last_state: dict = None
33
+ best_state: dict = None
34
+ optimizer: dict = None
35
+
36
+
37
+ def main():
38
+ parser = get_parser()
39
+ args = parser.parse_args()
40
+ name = get_name(parser, args)
41
+ print(f"Experiment {name}")
42
+
43
+ if args.musdb is None and args.rank == 0:
44
+ print(
45
+ "You must provide the path to the MusDB dataset with the --musdb flag. "
46
+ "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
47
+ file=sys.stderr)
48
+ sys.exit(1)
49
+
50
+ eval_folder = args.evals / name
51
+ eval_folder.mkdir(exist_ok=True, parents=True)
52
+ args.logs.mkdir(exist_ok=True)
53
+ metrics_path = args.logs / f"{name}.json"
54
+ eval_folder.mkdir(exist_ok=True, parents=True)
55
+ args.checkpoints.mkdir(exist_ok=True, parents=True)
56
+ args.models.mkdir(exist_ok=True, parents=True)
57
+
58
+ if args.device is None:
59
+ device = "cpu"
60
+ if th.cuda.is_available():
61
+ device = "cuda"
62
+ else:
63
+ device = args.device
64
+
65
+ th.manual_seed(args.seed)
66
+ # Prevents too many threads to be started when running `museval` as it can be quite
67
+ # inefficient on NUMA architectures.
68
+ os.environ["OMP_NUM_THREADS"] = "1"
69
+
70
+ if args.world_size > 1:
71
+ if device != "cuda" and args.rank == 0:
72
+ print("Error: distributed training is only available with cuda device", file=sys.stderr)
73
+ sys.exit(1)
74
+ th.cuda.set_device(args.rank % th.cuda.device_count())
75
+ distributed.init_process_group(backend="nccl",
76
+ init_method="tcp://" + args.master,
77
+ rank=args.rank,
78
+ world_size=args.world_size)
79
+
80
+ checkpoint = args.checkpoints / f"{name}.th"
81
+ checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
82
+ if args.restart and checkpoint.exists():
83
+ checkpoint.unlink()
84
+
85
+ if args.test:
86
+ args.epochs = 1
87
+ args.repeat = 0
88
+ model = load_model(args.models / args.test)
89
+ elif args.tasnet:
90
+ model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X)
91
+ else:
92
+ model = Demucs(
93
+ audio_channels=args.audio_channels,
94
+ channels=args.channels,
95
+ context=args.context,
96
+ depth=args.depth,
97
+ glu=args.glu,
98
+ growth=args.growth,
99
+ kernel_size=args.kernel_size,
100
+ lstm_layers=args.lstm_layers,
101
+ rescale=args.rescale,
102
+ rewrite=args.rewrite,
103
+ sources=4,
104
+ stride=args.conv_stride,
105
+ upsample=args.upsample,
106
+ samplerate=args.samplerate
107
+ )
108
+ model.to(device)
109
+ if args.show:
110
+ print(model)
111
+ size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
112
+ print(f"Model size {size}")
113
+ return
114
+
115
+ optimizer = th.optim.Adam(model.parameters(), lr=args.lr)
116
+
117
+ try:
118
+ saved = th.load(checkpoint, map_location='cpu')
119
+ except IOError:
120
+ saved = SavedState()
121
+ else:
122
+ model.load_state_dict(saved.last_state)
123
+ optimizer.load_state_dict(saved.optimizer)
124
+
125
+ if args.save_model:
126
+ if args.rank == 0:
127
+ model.to("cpu")
128
+ model.load_state_dict(saved.best_state)
129
+ save_model(model, args.models / f"{name}.th")
130
+ return
131
+
132
+ if args.rank == 0:
133
+ done = args.logs / f"{name}.done"
134
+ if done.exists():
135
+ done.unlink()
136
+
137
+ if args.augment:
138
+ augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride),
139
+ Remix(group_size=args.remix_group_size)).to(device)
140
+ else:
141
+ augment = Shift(args.data_stride)
142
+
143
+ if args.mse:
144
+ criterion = nn.MSELoss()
145
+ else:
146
+ criterion = nn.L1Loss()
147
+
148
+ # Setting number of samples so that all convolution windows are full.
149
+ # Prevents hard to debug mistake with the prediction being shifted compared
150
+ # to the input mixture.
151
+ samples = model.valid_length(args.samples)
152
+ print(f"Number of training samples adjusted to {samples}")
153
+
154
+ if args.raw:
155
+ train_set = Rawset(args.raw / "train",
156
+ samples=samples + args.data_stride,
157
+ channels=args.audio_channels,
158
+ streams=[0, 1, 2, 3, 4],
159
+ stride=args.data_stride)
160
+
161
+ valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
162
+ else:
163
+ if not args.metadata.is_file() and args.rank == 0:
164
+ build_musdb_metadata(args.metadata, args.musdb, args.workers)
165
+ if args.world_size > 1:
166
+ distributed.barrier()
167
+ metadata = json.load(open(args.metadata))
168
+ duration = Fraction(samples + args.data_stride, args.samplerate)
169
+ stride = Fraction(args.data_stride, args.samplerate)
170
+ train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
171
+ metadata,
172
+ duration=duration,
173
+ stride=stride,
174
+ samplerate=args.samplerate,
175
+ channels=args.audio_channels)
176
+ valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
177
+ metadata,
178
+ samplerate=args.samplerate,
179
+ channels=args.audio_channels)
180
+
181
+ best_loss = float("inf")
182
+ for epoch, metrics in enumerate(saved.metrics):
183
+ print(f"Epoch {epoch:03d}: "
184
+ f"train={metrics['train']:.8f} "
185
+ f"valid={metrics['valid']:.8f} "
186
+ f"best={metrics['best']:.4f} "
187
+ f"duration={human_seconds(metrics['duration'])}")
188
+ best_loss = metrics['best']
189
+
190
+ if args.world_size > 1:
191
+ dmodel = DistributedDataParallel(model,
192
+ device_ids=[th.cuda.current_device()],
193
+ output_device=th.cuda.current_device())
194
+ else:
195
+ dmodel = model
196
+
197
+ for epoch in range(len(saved.metrics), args.epochs):
198
+ begin = time.time()
199
+ model.train()
200
+ train_loss = train_model(epoch,
201
+ train_set,
202
+ dmodel,
203
+ criterion,
204
+ optimizer,
205
+ augment,
206
+ batch_size=args.batch_size,
207
+ device=device,
208
+ repeat=args.repeat,
209
+ seed=args.seed,
210
+ workers=args.workers,
211
+ world_size=args.world_size)
212
+ model.eval()
213
+ valid_loss = validate_model(epoch,
214
+ valid_set,
215
+ model,
216
+ criterion,
217
+ device=device,
218
+ rank=args.rank,
219
+ split=args.split_valid,
220
+ world_size=args.world_size)
221
+
222
+ duration = time.time() - begin
223
+ if valid_loss < best_loss:
224
+ best_loss = valid_loss
225
+ saved.best_state = {
226
+ key: value.to("cpu").clone()
227
+ for key, value in model.state_dict().items()
228
+ }
229
+ saved.metrics.append({
230
+ "train": train_loss,
231
+ "valid": valid_loss,
232
+ "best": best_loss,
233
+ "duration": duration
234
+ })
235
+ if args.rank == 0:
236
+ json.dump(saved.metrics, open(metrics_path, "w"))
237
+
238
+ saved.last_state = model.state_dict()
239
+ saved.optimizer = optimizer.state_dict()
240
+ if args.rank == 0 and not args.test:
241
+ th.save(saved, checkpoint_tmp)
242
+ checkpoint_tmp.rename(checkpoint)
243
+
244
+ print(f"Epoch {epoch:03d}: "
245
+ f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} "
246
+ f"duration={human_seconds(duration)}")
247
+
248
+ del dmodel
249
+ model.load_state_dict(saved.best_state)
250
+ if args.eval_cpu:
251
+ device = "cpu"
252
+ model.to(device)
253
+ model.eval()
254
+ evaluate(model,
255
+ args.musdb,
256
+ eval_folder,
257
+ rank=args.rank,
258
+ world_size=args.world_size,
259
+ device=device,
260
+ save=args.save,
261
+ split=args.split_valid,
262
+ shifts=args.shifts,
263
+ workers=args.eval_workers)
264
+ model.to("cpu")
265
+ save_model(model, args.models / f"{name}.th")
266
+ if args.rank == 0:
267
+ print("done")
268
+ done.write_text("done")
269
+
270
+
271
+ if __name__ == "__main__":
272
+ main()
demucs/apply.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Code to apply a model to a mix. It will handle chunking with overlaps and
8
+ inteprolation between chunks, as well as the "shift trick".
9
+ """
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ import random
12
+ import typing as tp
13
+ from multiprocessing import Process,Queue,Pipe
14
+
15
+ import torch as th
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ import tqdm
19
+ import tkinter as tk
20
+
21
+ from .demucs import Demucs
22
+ from .hdemucs import HDemucs
23
+ from .utils import center_trim, DummyPoolExecutor
24
+
25
+ Model = tp.Union[Demucs, HDemucs]
26
+
27
+ progress_bar_num = 0
28
+
29
+ class BagOfModels(nn.Module):
30
+ def __init__(self, models: tp.List[Model],
31
+ weights: tp.Optional[tp.List[tp.List[float]]] = None,
32
+ segment: tp.Optional[float] = None):
33
+ """
34
+ Represents a bag of models with specific weights.
35
+ You should call `apply_model` rather than calling directly the forward here for
36
+ optimal performance.
37
+
38
+ Args:
39
+ models (list[nn.Module]): list of Demucs/HDemucs models.
40
+ weights (list[list[float]]): list of weights. If None, assumed to
41
+ be all ones, otherwise it should be a list of N list (N number of models),
42
+ each containing S floats (S number of sources).
43
+ segment (None or float): overrides the `segment` attribute of each model
44
+ (this is performed inplace, be careful if you reuse the models passed).
45
+ """
46
+
47
+ super().__init__()
48
+ assert len(models) > 0
49
+ first = models[0]
50
+ for other in models:
51
+ assert other.sources == first.sources
52
+ assert other.samplerate == first.samplerate
53
+ assert other.audio_channels == first.audio_channels
54
+ if segment is not None:
55
+ other.segment = segment
56
+
57
+ self.audio_channels = first.audio_channels
58
+ self.samplerate = first.samplerate
59
+ self.sources = first.sources
60
+ self.models = nn.ModuleList(models)
61
+
62
+ if weights is None:
63
+ weights = [[1. for _ in first.sources] for _ in models]
64
+ else:
65
+ assert len(weights) == len(models)
66
+ for weight in weights:
67
+ assert len(weight) == len(first.sources)
68
+ self.weights = weights
69
+
70
+ def forward(self, x):
71
+ raise NotImplementedError("Call `apply_model` on this.")
72
+
73
+ class TensorChunk:
74
+ def __init__(self, tensor, offset=0, length=None):
75
+ total_length = tensor.shape[-1]
76
+ assert offset >= 0
77
+ assert offset < total_length
78
+
79
+ if length is None:
80
+ length = total_length - offset
81
+ else:
82
+ length = min(total_length - offset, length)
83
+
84
+ if isinstance(tensor, TensorChunk):
85
+ self.tensor = tensor.tensor
86
+ self.offset = offset + tensor.offset
87
+ else:
88
+ self.tensor = tensor
89
+ self.offset = offset
90
+ self.length = length
91
+ self.device = tensor.device
92
+
93
+ @property
94
+ def shape(self):
95
+ shape = list(self.tensor.shape)
96
+ shape[-1] = self.length
97
+ return shape
98
+
99
+ def padded(self, target_length):
100
+ delta = target_length - self.length
101
+ total_length = self.tensor.shape[-1]
102
+ assert delta >= 0
103
+
104
+ start = self.offset - delta // 2
105
+ end = start + target_length
106
+
107
+ correct_start = max(0, start)
108
+ correct_end = min(total_length, end)
109
+
110
+ pad_left = correct_start - start
111
+ pad_right = end - correct_end
112
+
113
+ out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
114
+ assert out.shape[-1] == target_length
115
+ return out
116
+
117
+ def tensor_chunk(tensor_or_chunk):
118
+ if isinstance(tensor_or_chunk, TensorChunk):
119
+ return tensor_or_chunk
120
+ else:
121
+ assert isinstance(tensor_or_chunk, th.Tensor)
122
+ return TensorChunk(tensor_or_chunk)
123
+
124
+ def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1., static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
125
+ """
126
+ Apply model to a given mixture.
127
+
128
+ Args:
129
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
130
+ and apply the oppositve shift to the output. This is repeated `shifts` time and
131
+ all predictions are averaged. This effectively makes the model time equivariant
132
+ and improves SDR by up to 0.2 points.
133
+ split (bool): if True, the input will be broken down in 8 seconds extracts
134
+ and predictions will be performed individually on each and concatenated.
135
+ Useful for model with large memory footprint like Tasnet.
136
+ progress (bool): if True, show a progress bar (requires split=True)
137
+ device (torch.device, str, or None): if provided, device on which to
138
+ execute the computation, otherwise `mix.device` is assumed.
139
+ When `device` is different from `mix.device`, only local computations will
140
+ be on `device`, while the entire tracks will be stored on `mix.device`.
141
+ """
142
+
143
+ global fut_length
144
+ global bag_num
145
+ global prog_bar
146
+
147
+ if device is None:
148
+ device = mix.device
149
+ else:
150
+ device = th.device(device)
151
+ if pool is None:
152
+ if num_workers > 0 and device.type == 'cpu':
153
+ pool = ThreadPoolExecutor(num_workers)
154
+ else:
155
+ pool = DummyPoolExecutor()
156
+
157
+ kwargs = {
158
+ 'shifts': shifts,
159
+ 'split': split,
160
+ 'overlap': overlap,
161
+ 'transition_power': transition_power,
162
+ 'progress': progress,
163
+ 'device': device,
164
+ 'pool': pool,
165
+ 'set_progress_bar': set_progress_bar,
166
+ 'static_shifts': static_shifts,
167
+ }
168
+
169
+ if isinstance(model, BagOfModels):
170
+ # Special treatment for bag of model.
171
+ # We explicitely apply multiple times `apply_model` so that the random shifts
172
+ # are different for each model.
173
+
174
+ estimates = 0
175
+ totals = [0] * len(model.sources)
176
+ bag_num = len(model.models)
177
+ fut_length = 0
178
+ prog_bar = 0
179
+ current_model = 0 #(bag_num + 1)
180
+ for sub_model, weight in zip(model.models, model.weights):
181
+ original_model_device = next(iter(sub_model.parameters())).device
182
+ sub_model.to(device)
183
+ fut_length += fut_length
184
+ current_model += 1
185
+ out = apply_model(sub_model, mix, **kwargs)
186
+ sub_model.to(original_model_device)
187
+ for k, inst_weight in enumerate(weight):
188
+ out[:, k, :, :] *= inst_weight
189
+ totals[k] += inst_weight
190
+ estimates += out
191
+ del out
192
+
193
+ for k in range(estimates.shape[1]):
194
+ estimates[:, k, :, :] /= totals[k]
195
+ return estimates
196
+
197
+ model.to(device)
198
+ model.eval()
199
+ assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
200
+ batch, channels, length = mix.shape
201
+
202
+ if shifts:
203
+ kwargs['shifts'] = 0
204
+ max_shift = int(0.5 * model.samplerate)
205
+ mix = tensor_chunk(mix)
206
+ padded_mix = mix.padded(length + 2 * max_shift)
207
+ out = 0
208
+ for _ in range(shifts):
209
+ offset = random.randint(0, max_shift)
210
+ shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
211
+ shifted_out = apply_model(model, shifted, **kwargs)
212
+ out += shifted_out[..., max_shift - offset:]
213
+ out /= shifts
214
+ return out
215
+ elif split:
216
+ kwargs['split'] = False
217
+ out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
218
+ sum_weight = th.zeros(length, device=mix.device)
219
+ segment = int(model.samplerate * model.segment)
220
+ stride = int((1 - overlap) * segment)
221
+ offsets = range(0, length, stride)
222
+ scale = float(format(stride / model.samplerate, ".2f"))
223
+ # We start from a triangle shaped weight, with maximal weight in the middle
224
+ # of the segment. Then we normalize and take to the power `transition_power`.
225
+ # Large values of transition power will lead to sharper transitions.
226
+ weight = th.cat([th.arange(1, segment // 2 + 1, device=device),
227
+ th.arange(segment - segment // 2, 0, -1, device=device)])
228
+ assert len(weight) == segment
229
+ # If the overlap < 50%, this will translate to linear transition when
230
+ # transition_power is 1.
231
+ weight = (weight / weight.max())**transition_power
232
+ futures = []
233
+ for offset in offsets:
234
+ chunk = TensorChunk(mix, offset, segment)
235
+ future = pool.submit(apply_model, model, chunk, **kwargs)
236
+ futures.append((future, offset))
237
+ offset += segment
238
+ if progress:
239
+ futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
240
+ for future, offset in futures:
241
+ if set_progress_bar:
242
+ fut_length = (len(futures) * bag_num * static_shifts)
243
+ prog_bar += 1
244
+ set_progress_bar(0.1, (0.8/fut_length*prog_bar))
245
+ chunk_out = future.result()
246
+ chunk_length = chunk_out.shape[-1]
247
+ out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
248
+ sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device)
249
+ assert sum_weight.min() > 0
250
+ out /= sum_weight
251
+ return out
252
+ else:
253
+ if hasattr(model, 'valid_length'):
254
+ valid_length = model.valid_length(length)
255
+ else:
256
+ valid_length = length
257
+ mix = tensor_chunk(mix)
258
+ padded_mix = mix.padded(valid_length).to(device)
259
+ with th.no_grad():
260
+ out = model(padded_mix)
261
+ return center_trim(out, length)
262
+
263
+ def demucs_segments(demucs_segment, demucs_model):
264
+
265
+ if demucs_segment == 'Default':
266
+ segment = None
267
+ if isinstance(demucs_model, BagOfModels):
268
+ if segment is not None:
269
+ for sub in demucs_model.models:
270
+ sub.segment = segment
271
+ else:
272
+ if segment is not None:
273
+ sub.segment = segment
274
+ else:
275
+ try:
276
+ segment = int(demucs_segment)
277
+ if isinstance(demucs_model, BagOfModels):
278
+ if segment is not None:
279
+ for sub in demucs_model.models:
280
+ sub.segment = segment
281
+ else:
282
+ if segment is not None:
283
+ sub.segment = segment
284
+ except:
285
+ segment = None
286
+ if isinstance(demucs_model, BagOfModels):
287
+ if segment is not None:
288
+ for sub in demucs_model.models:
289
+ sub.segment = segment
290
+ else:
291
+ if segment is not None:
292
+ sub.segment = segment
293
+
294
+ return demucs_model
demucs/demucs.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import julius
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from .states import capture_init
16
+ from .utils import center_trim, unfold
17
+
18
+
19
+ class BLSTM(nn.Module):
20
+ """
21
+ BiLSTM with same hidden units as input dim.
22
+ If `max_steps` is not None, input will be splitting in overlapping
23
+ chunks and the LSTM applied separately on each chunk.
24
+ """
25
+ def __init__(self, dim, layers=1, max_steps=None, skip=False):
26
+ super().__init__()
27
+ assert max_steps is None or max_steps % 4 == 0
28
+ self.max_steps = max_steps
29
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
30
+ self.linear = nn.Linear(2 * dim, dim)
31
+ self.skip = skip
32
+
33
+ def forward(self, x):
34
+ B, C, T = x.shape
35
+ y = x
36
+ framed = False
37
+ if self.max_steps is not None and T > self.max_steps:
38
+ width = self.max_steps
39
+ stride = width // 2
40
+ frames = unfold(x, width, stride)
41
+ nframes = frames.shape[2]
42
+ framed = True
43
+ x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
44
+
45
+ x = x.permute(2, 0, 1)
46
+
47
+ x = self.lstm(x)[0]
48
+ x = self.linear(x)
49
+ x = x.permute(1, 2, 0)
50
+ if framed:
51
+ out = []
52
+ frames = x.reshape(B, -1, C, width)
53
+ limit = stride // 2
54
+ for k in range(nframes):
55
+ if k == 0:
56
+ out.append(frames[:, k, :, :-limit])
57
+ elif k == nframes - 1:
58
+ out.append(frames[:, k, :, limit:])
59
+ else:
60
+ out.append(frames[:, k, :, limit:-limit])
61
+ out = torch.cat(out, -1)
62
+ out = out[..., :T]
63
+ x = out
64
+ if self.skip:
65
+ x = x + y
66
+ return x
67
+
68
+
69
+ def rescale_conv(conv, reference):
70
+ """Rescale initial weight scale. It is unclear why it helps but it certainly does.
71
+ """
72
+ std = conv.weight.std().detach()
73
+ scale = (std / reference)**0.5
74
+ conv.weight.data /= scale
75
+ if conv.bias is not None:
76
+ conv.bias.data /= scale
77
+
78
+
79
+ def rescale_module(module, reference):
80
+ for sub in module.modules():
81
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
82
+ rescale_conv(sub, reference)
83
+
84
+
85
+ class LayerScale(nn.Module):
86
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
87
+ This rescales diagonaly residual outputs close to 0 initially, then learnt.
88
+ """
89
+ def __init__(self, channels: int, init: float = 0):
90
+ super().__init__()
91
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
92
+ self.scale.data[:] = init
93
+
94
+ def forward(self, x):
95
+ return self.scale[:, None] * x
96
+
97
+
98
+ class DConv(nn.Module):
99
+ """
100
+ New residual branches in each encoder layer.
101
+ This alternates dilated convolutions, potentially with LSTMs and attention.
102
+ Also before entering each residual branch, dimension is projected on a smaller subspace,
103
+ e.g. of dim `channels // compress`.
104
+ """
105
+ def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
106
+ norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
107
+ kernel=3, dilate=True):
108
+ """
109
+ Args:
110
+ channels: input/output channels for residual branch.
111
+ compress: amount of channel compression inside the branch.
112
+ depth: number of layers in the residual branch. Each layer has its own
113
+ projection, and potentially LSTM and attention.
114
+ init: initial scale for LayerNorm.
115
+ norm: use GroupNorm.
116
+ attn: use LocalAttention.
117
+ heads: number of heads for the LocalAttention.
118
+ ndecay: number of decay controls in the LocalAttention.
119
+ lstm: use LSTM.
120
+ gelu: Use GELU activation.
121
+ kernel: kernel size for the (dilated) convolutions.
122
+ dilate: if true, use dilation, increasing with the depth.
123
+ """
124
+
125
+ super().__init__()
126
+ assert kernel % 2 == 1
127
+ self.channels = channels
128
+ self.compress = compress
129
+ self.depth = abs(depth)
130
+ dilate = depth > 0
131
+
132
+ norm_fn: tp.Callable[[int], nn.Module]
133
+ norm_fn = lambda d: nn.Identity() # noqa
134
+ if norm:
135
+ norm_fn = lambda d: nn.GroupNorm(1, d) # noqa
136
+
137
+ hidden = int(channels / compress)
138
+
139
+ act: tp.Type[nn.Module]
140
+ if gelu:
141
+ act = nn.GELU
142
+ else:
143
+ act = nn.ReLU
144
+
145
+ self.layers = nn.ModuleList([])
146
+ for d in range(self.depth):
147
+ dilation = 2 ** d if dilate else 1
148
+ padding = dilation * (kernel // 2)
149
+ mods = [
150
+ nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
151
+ norm_fn(hidden), act(),
152
+ nn.Conv1d(hidden, 2 * channels, 1),
153
+ norm_fn(2 * channels), nn.GLU(1),
154
+ LayerScale(channels, init),
155
+ ]
156
+ if attn:
157
+ mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
158
+ if lstm:
159
+ mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
160
+ layer = nn.Sequential(*mods)
161
+ self.layers.append(layer)
162
+
163
+ def forward(self, x):
164
+ for layer in self.layers:
165
+ x = x + layer(x)
166
+ return x
167
+
168
+
169
+ class LocalState(nn.Module):
170
+ """Local state allows to have attention based only on data (no positional embedding),
171
+ but while setting a constraint on the time window (e.g. decaying penalty term).
172
+
173
+ Also a failed experiments with trying to provide some frequency based attention.
174
+ """
175
+ def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
176
+ super().__init__()
177
+ assert channels % heads == 0, (channels, heads)
178
+ self.heads = heads
179
+ self.nfreqs = nfreqs
180
+ self.ndecay = ndecay
181
+ self.content = nn.Conv1d(channels, channels, 1)
182
+ self.query = nn.Conv1d(channels, channels, 1)
183
+ self.key = nn.Conv1d(channels, channels, 1)
184
+ if nfreqs:
185
+ self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
186
+ if ndecay:
187
+ self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
188
+ # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
189
+ self.query_decay.weight.data *= 0.01
190
+ assert self.query_decay.bias is not None # stupid type checker
191
+ self.query_decay.bias.data[:] = -2
192
+ self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
193
+
194
+ def forward(self, x):
195
+ B, C, T = x.shape
196
+ heads = self.heads
197
+ indexes = torch.arange(T, device=x.device, dtype=x.dtype)
198
+ # left index are keys, right index are queries
199
+ delta = indexes[:, None] - indexes[None, :]
200
+
201
+ queries = self.query(x).view(B, heads, -1, T)
202
+ keys = self.key(x).view(B, heads, -1, T)
203
+ # t are keys, s are queries
204
+ dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
205
+ dots /= keys.shape[2]**0.5
206
+ if self.nfreqs:
207
+ periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
208
+ freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
209
+ freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
210
+ dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
211
+ if self.ndecay:
212
+ decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
213
+ decay_q = self.query_decay(x).view(B, heads, -1, T)
214
+ decay_q = torch.sigmoid(decay_q) / 2
215
+ decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
216
+ dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
217
+
218
+ # Kill self reference.
219
+ dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
220
+ weights = torch.softmax(dots, dim=2)
221
+
222
+ content = self.content(x).view(B, heads, -1, T)
223
+ result = torch.einsum("bhts,bhct->bhcs", weights, content)
224
+ if self.nfreqs:
225
+ time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
226
+ result = torch.cat([result, time_sig], 2)
227
+ result = result.reshape(B, -1, T)
228
+ return x + self.proj(result)
229
+
230
+
231
+ class Demucs(nn.Module):
232
+ @capture_init
233
+ def __init__(self,
234
+ sources,
235
+ # Channels
236
+ audio_channels=2,
237
+ channels=64,
238
+ growth=2.,
239
+ # Main structure
240
+ depth=6,
241
+ rewrite=True,
242
+ lstm_layers=0,
243
+ # Convolutions
244
+ kernel_size=8,
245
+ stride=4,
246
+ context=1,
247
+ # Activations
248
+ gelu=True,
249
+ glu=True,
250
+ # Normalization
251
+ norm_starts=4,
252
+ norm_groups=4,
253
+ # DConv residual branch
254
+ dconv_mode=1,
255
+ dconv_depth=2,
256
+ dconv_comp=4,
257
+ dconv_attn=4,
258
+ dconv_lstm=4,
259
+ dconv_init=1e-4,
260
+ # Pre/post processing
261
+ normalize=True,
262
+ resample=True,
263
+ # Weight init
264
+ rescale=0.1,
265
+ # Metadata
266
+ samplerate=44100,
267
+ segment=4 * 10):
268
+ """
269
+ Args:
270
+ sources (list[str]): list of source names
271
+ audio_channels (int): stereo or mono
272
+ channels (int): first convolution channels
273
+ depth (int): number of encoder/decoder layers
274
+ growth (float): multiply (resp divide) number of channels by that
275
+ for each layer of the encoder (resp decoder)
276
+ depth (int): number of layers in the encoder and in the decoder.
277
+ rewrite (bool): add 1x1 convolution to each layer.
278
+ lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
279
+ by default, as this is now replaced by the smaller and faster small LSTMs
280
+ in the DConv branches.
281
+ kernel_size (int): kernel size for convolutions
282
+ stride (int): stride for convolutions
283
+ context (int): kernel size of the convolution in the
284
+ decoder before the transposed convolution. If > 1,
285
+ will provide some context from neighboring time steps.
286
+ gelu: use GELU activation function.
287
+ glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
288
+ norm_starts: layer at which group norm starts being used.
289
+ decoder layers are numbered in reverse order.
290
+ norm_groups: number of groups for group norm.
291
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
292
+ dconv_depth: depth of residual DConv branch.
293
+ dconv_comp: compression of DConv branch.
294
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
295
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
296
+ dconv_init: initial scale for the DConv branch LayerScale.
297
+ normalize (bool): normalizes the input audio on the fly, and scales back
298
+ the output by the same amount.
299
+ resample (bool): upsample x2 the input and downsample /2 the output.
300
+ rescale (int): rescale initial weights of convolutions
301
+ to get their standard deviation closer to `rescale`.
302
+ samplerate (int): stored as meta information for easing
303
+ future evaluations of the model.
304
+ segment (float): duration of the chunks of audio to ideally evaluate the model on.
305
+ This is used by `demucs.apply.apply_model`.
306
+ """
307
+
308
+ super().__init__()
309
+ self.audio_channels = audio_channels
310
+ self.sources = sources
311
+ self.kernel_size = kernel_size
312
+ self.context = context
313
+ self.stride = stride
314
+ self.depth = depth
315
+ self.resample = resample
316
+ self.channels = channels
317
+ self.normalize = normalize
318
+ self.samplerate = samplerate
319
+ self.segment = segment
320
+ self.encoder = nn.ModuleList()
321
+ self.decoder = nn.ModuleList()
322
+ self.skip_scales = nn.ModuleList()
323
+
324
+ if glu:
325
+ activation = nn.GLU(dim=1)
326
+ ch_scale = 2
327
+ else:
328
+ activation = nn.ReLU()
329
+ ch_scale = 1
330
+ if gelu:
331
+ act2 = nn.GELU
332
+ else:
333
+ act2 = nn.ReLU
334
+
335
+ in_channels = audio_channels
336
+ padding = 0
337
+ for index in range(depth):
338
+ norm_fn = lambda d: nn.Identity() # noqa
339
+ if index >= norm_starts:
340
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
341
+
342
+ encode = []
343
+ encode += [
344
+ nn.Conv1d(in_channels, channels, kernel_size, stride),
345
+ norm_fn(channels),
346
+ act2(),
347
+ ]
348
+ attn = index >= dconv_attn
349
+ lstm = index >= dconv_lstm
350
+ if dconv_mode & 1:
351
+ encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
352
+ compress=dconv_comp, attn=attn, lstm=lstm)]
353
+ if rewrite:
354
+ encode += [
355
+ nn.Conv1d(channels, ch_scale * channels, 1),
356
+ norm_fn(ch_scale * channels), activation]
357
+ self.encoder.append(nn.Sequential(*encode))
358
+
359
+ decode = []
360
+ if index > 0:
361
+ out_channels = in_channels
362
+ else:
363
+ out_channels = len(self.sources) * audio_channels
364
+ if rewrite:
365
+ decode += [
366
+ nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
367
+ norm_fn(ch_scale * channels), activation]
368
+ if dconv_mode & 2:
369
+ decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
370
+ compress=dconv_comp, attn=attn, lstm=lstm)]
371
+ decode += [nn.ConvTranspose1d(channels, out_channels,
372
+ kernel_size, stride, padding=padding)]
373
+ if index > 0:
374
+ decode += [norm_fn(out_channels), act2()]
375
+ self.decoder.insert(0, nn.Sequential(*decode))
376
+ in_channels = channels
377
+ channels = int(growth * channels)
378
+
379
+ channels = in_channels
380
+ if lstm_layers:
381
+ self.lstm = BLSTM(channels, lstm_layers)
382
+ else:
383
+ self.lstm = None
384
+
385
+ if rescale:
386
+ rescale_module(self, reference=rescale)
387
+
388
+ def valid_length(self, length):
389
+ """
390
+ Return the nearest valid length to use with the model so that
391
+ there is no time steps left over in a convolution, e.g. for all
392
+ layers, size of the input - kernel_size % stride = 0.
393
+
394
+ Note that input are automatically padded if necessary to ensure that the output
395
+ has the same length as the input.
396
+ """
397
+ if self.resample:
398
+ length *= 2
399
+
400
+ for _ in range(self.depth):
401
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
402
+ length = max(1, length)
403
+
404
+ for idx in range(self.depth):
405
+ length = (length - 1) * self.stride + self.kernel_size
406
+
407
+ if self.resample:
408
+ length = math.ceil(length / 2)
409
+ return int(length)
410
+
411
+ def forward(self, mix):
412
+ x = mix
413
+ length = x.shape[-1]
414
+
415
+ if self.normalize:
416
+ mono = mix.mean(dim=1, keepdim=True)
417
+ mean = mono.mean(dim=-1, keepdim=True)
418
+ std = mono.std(dim=-1, keepdim=True)
419
+ x = (x - mean) / (1e-5 + std)
420
+ else:
421
+ mean = 0
422
+ std = 1
423
+
424
+ delta = self.valid_length(length) - length
425
+ x = F.pad(x, (delta // 2, delta - delta // 2))
426
+
427
+ if self.resample:
428
+ x = julius.resample_frac(x, 1, 2)
429
+
430
+ saved = []
431
+ for encode in self.encoder:
432
+ x = encode(x)
433
+ saved.append(x)
434
+
435
+ if self.lstm:
436
+ x = self.lstm(x)
437
+
438
+ for decode in self.decoder:
439
+ skip = saved.pop(-1)
440
+ skip = center_trim(skip, x)
441
+ x = decode(x + skip)
442
+
443
+ if self.resample:
444
+ x = julius.resample_frac(x, 2, 1)
445
+ x = x * std + mean
446
+ x = center_trim(x, length)
447
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
448
+ return x
449
+
450
+ def load_state_dict(self, state, strict=True):
451
+ # fix a mismatch with previous generation Demucs models.
452
+ for idx in range(self.depth):
453
+ for a in ['encoder', 'decoder']:
454
+ for b in ['bias', 'weight']:
455
+ new = f'{a}.{idx}.3.{b}'
456
+ old = f'{a}.{idx}.2.{b}'
457
+ if old in state and new not in state:
458
+ state[new] = state.pop(old)
459
+ super().load_state_dict(state, strict=strict)
demucs/filtering.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch import Tensor
5
+ from torch.utils.data import DataLoader
6
+
7
+ def atan2(y, x):
8
+ r"""Element-wise arctangent function of y/x.
9
+ Returns a new tensor with signed angles in radians.
10
+ It is an alternative implementation of torch.atan2
11
+
12
+ Args:
13
+ y (Tensor): First input tensor
14
+ x (Tensor): Second input tensor [shape=y.shape]
15
+
16
+ Returns:
17
+ Tensor: [shape=y.shape].
18
+ """
19
+ pi = 2 * torch.asin(torch.tensor(1.0))
20
+ x += ((x == 0) & (y == 0)) * 1.0
21
+ out = torch.atan(y / x)
22
+ out += ((y >= 0) & (x < 0)) * pi
23
+ out -= ((y < 0) & (x < 0)) * pi
24
+ out *= 1 - ((y > 0) & (x == 0)) * 1.0
25
+ out += ((y > 0) & (x == 0)) * (pi / 2)
26
+ out *= 1 - ((y < 0) & (x == 0)) * 1.0
27
+ out += ((y < 0) & (x == 0)) * (-pi / 2)
28
+ return out
29
+
30
+
31
+ # Define basic complex operations on torch.Tensor objects whose last dimension
32
+ # consists in the concatenation of the real and imaginary parts.
33
+
34
+
35
+ def _norm(x: torch.Tensor) -> torch.Tensor:
36
+ r"""Computes the norm value of a torch Tensor, assuming that it
37
+ comes as real and imaginary part in its last dimension.
38
+
39
+ Args:
40
+ x (Tensor): Input Tensor of shape [shape=(..., 2)]
41
+
42
+ Returns:
43
+ Tensor: shape as x excluding the last dimension.
44
+ """
45
+ return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
46
+
47
+
48
+ def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
49
+ """Element-wise multiplication of two complex Tensors described
50
+ through their real and imaginary parts.
51
+ The result is added to the `out` tensor"""
52
+
53
+ # check `out` and allocate it if needed
54
+ target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
55
+ if out is None or out.shape != target_shape:
56
+ out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
57
+ if out is a:
58
+ real_a = a[..., 0]
59
+ out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
60
+ out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
61
+ else:
62
+ out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
63
+ out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
64
+ return out
65
+
66
+
67
+ def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
68
+ """Element-wise multiplication of two complex Tensors described
69
+ through their real and imaginary parts
70
+ can work in place in case out is a only"""
71
+ target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
72
+ if out is None or out.shape != target_shape:
73
+ out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
74
+ if out is a:
75
+ real_a = a[..., 0]
76
+ out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
77
+ out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
78
+ else:
79
+ out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
80
+ out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
81
+ return out
82
+
83
+
84
+ def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
85
+ """Element-wise multiplicative inverse of a Tensor with complex
86
+ entries described through their real and imaginary parts.
87
+ can work in place in case out is z"""
88
+ ez = _norm(z)
89
+ if out is None or out.shape != z.shape:
90
+ out = torch.zeros_like(z)
91
+ out[..., 0] = z[..., 0] / ez
92
+ out[..., 1] = -z[..., 1] / ez
93
+ return out
94
+
95
+
96
+ def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
97
+ """Element-wise complex conjugate of a Tensor with complex entries
98
+ described through their real and imaginary parts.
99
+ can work in place in case out is z"""
100
+ if out is None or out.shape != z.shape:
101
+ out = torch.zeros_like(z)
102
+ out[..., 0] = z[..., 0]
103
+ out[..., 1] = -z[..., 1]
104
+ return out
105
+
106
+
107
+ def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
108
+ """
109
+ Invert 1x1 or 2x2 matrices
110
+
111
+ Will generate errors if the matrices are singular: user must handle this
112
+ through his own regularization schemes.
113
+
114
+ Args:
115
+ M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
116
+ matrices to invert: must be square along dimensions -3 and -2
117
+
118
+ Returns:
119
+ invM (Tensor): [shape=M.shape]
120
+ inverses of M
121
+ """
122
+ nb_channels = M.shape[-2]
123
+
124
+ if out is None or out.shape != M.shape:
125
+ out = torch.empty_like(M)
126
+
127
+ if nb_channels == 1:
128
+ # scalar case
129
+ out = _inv(M, out)
130
+ elif nb_channels == 2:
131
+ # two channels case: analytical expression
132
+
133
+ # first compute the determinent
134
+ det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
135
+ det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
136
+ # invert it
137
+ invDet = _inv(det)
138
+
139
+ # then fill out the matrix with the inverse
140
+ out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
141
+ out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
142
+ out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
143
+ out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
144
+ else:
145
+ raise Exception("Only 2 channels are supported for the torch version.")
146
+ return out
147
+
148
+
149
+ # Now define the signal-processing low-level functions used by the Separator
150
+
151
+
152
+ def expectation_maximization(
153
+ y: torch.Tensor,
154
+ x: torch.Tensor,
155
+ iterations: int = 2,
156
+ eps: float = 1e-10,
157
+ batch_size: int = 200,
158
+ ):
159
+ r"""Expectation maximization algorithm, for refining source separation
160
+ estimates.
161
+
162
+ This algorithm allows to make source separation results better by
163
+ enforcing multichannel consistency for the estimates. This usually means
164
+ a better perceptual quality in terms of spatial artifacts.
165
+
166
+ The implementation follows the details presented in [1]_, taking
167
+ inspiration from the original EM algorithm proposed in [2]_ and its
168
+ weighted refinement proposed in [3]_, [4]_.
169
+ It works by iteratively:
170
+
171
+ * Re-estimate source parameters (power spectral densities and spatial
172
+ covariance matrices) through :func:`get_local_gaussian_model`.
173
+
174
+ * Separate again the mixture with the new parameters by first computing
175
+ the new modelled mixture covariance matrices with :func:`get_mix_model`,
176
+ prepare the Wiener filters through :func:`wiener_gain` and apply them
177
+ with :func:`apply_filter``.
178
+
179
+ References
180
+ ----------
181
+ .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
182
+ N. Takahashi and Y. Mitsufuji, "Improving music source separation based
183
+ on deep neural networks through data augmentation and network
184
+ blending." 2017 IEEE International Conference on Acoustics, Speech
185
+ and Signal Processing (ICASSP). IEEE, 2017.
186
+
187
+ .. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined
188
+ reverberant audio source separation using a full-rank spatial
189
+ covariance model." IEEE Transactions on Audio, Speech, and Language
190
+ Processing 18.7 (2010): 1830-1840.
191
+
192
+ .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
193
+ separation with deep neural networks." IEEE/ACM Transactions on Audio,
194
+ Speech, and Language Processing 24.9 (2016): 1652-1664.
195
+
196
+ .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
197
+ separation with deep neural networks." 2016 24th European Signal
198
+ Processing Conference (EUSIPCO). IEEE, 2016.
199
+
200
+ .. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
201
+ source separation." IEEE Transactions on Signal Processing
202
+ 62.16 (2014): 4298-4310.
203
+
204
+ Args:
205
+ y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
206
+ initial estimates for the sources
207
+ x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
208
+ complex STFT of the mixture signal
209
+ iterations (int): [scalar]
210
+ number of iterations for the EM algorithm.
211
+ eps (float or None): [scalar]
212
+ The epsilon value to use for regularization and filters.
213
+
214
+ Returns:
215
+ y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
216
+ estimated sources after iterations
217
+ v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
218
+ estimated power spectral densities
219
+ R (Tensor): [shape=(nb_bins, nb_channels, nb_channels, 2, nb_sources)]
220
+ estimated spatial covariance matrices
221
+
222
+ Notes:
223
+ * You need an initial estimate for the sources to apply this
224
+ algorithm. This is precisely what the :func:`wiener` function does.
225
+ * This algorithm *is not* an implementation of the "exact" EM
226
+ proposed in [1]_. In particular, it does compute the posterior
227
+ covariance matrices the same (exact) way. Instead, it uses the
228
+ simplified approximate scheme initially proposed in [5]_ and further
229
+ refined in [3]_, [4]_, that boils down to just take the empirical
230
+ covariance of the recent source estimates, followed by a weighted
231
+ average for the update of the spatial covariance matrix. It has been
232
+ empirically demonstrated that this simplified algorithm is more
233
+ robust for music separation.
234
+
235
+ Warning:
236
+ It is *very* important to make sure `x.dtype` is `torch.float64`
237
+ if you want double precision, because this function will **not**
238
+ do such conversion for you from `torch.complex32`, in case you want the
239
+ smaller RAM usage on purpose.
240
+
241
+ It is usually always better in terms of quality to have double
242
+ precision, by e.g. calling :func:`expectation_maximization`
243
+ with ``x.to(torch.float64)``.
244
+ """
245
+ # dimensions
246
+ (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
247
+ nb_sources = y.shape[-1]
248
+
249
+ regularization = torch.cat(
250
+ (
251
+ torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None],
252
+ torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device),
253
+ ),
254
+ dim=2,
255
+ )
256
+ regularization = torch.sqrt(torch.as_tensor(eps)) * (
257
+ regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1))
258
+ )
259
+
260
+ # allocate the spatial covariance matrices
261
+ R = [
262
+ torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device)
263
+ for j in range(nb_sources)
264
+ ]
265
+ weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
266
+
267
+ v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
268
+ for it in range(iterations):
269
+ # constructing the mixture covariance matrix. Doing it with a loop
270
+ # to avoid storing anytime in RAM the whole 6D tensor
271
+
272
+ # update the PSD as the average spectrogram over channels
273
+ v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
274
+
275
+ # update spatial covariance matrices (weighted update)
276
+ for j in range(nb_sources):
277
+ R[j] = torch.tensor(0.0, device=x.device)
278
+ weight = torch.tensor(eps, device=x.device)
279
+ pos: int = 0
280
+ batch_size = batch_size if batch_size else nb_frames
281
+ while pos < nb_frames:
282
+ t = torch.arange(pos, min(nb_frames, pos + batch_size))
283
+ pos = int(t[-1]) + 1
284
+
285
+ R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
286
+ weight = weight + torch.sum(v[t, ..., j], dim=0)
287
+ R[j] = R[j] / weight[..., None, None, None]
288
+ weight = torch.zeros_like(weight)
289
+
290
+ # cloning y if we track gradient, because we're going to update it
291
+ if y.requires_grad:
292
+ y = y.clone()
293
+
294
+ pos = 0
295
+ while pos < nb_frames:
296
+ t = torch.arange(pos, min(nb_frames, pos + batch_size))
297
+ pos = int(t[-1]) + 1
298
+
299
+ y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)
300
+
301
+ # compute mix covariance matrix
302
+ Cxx = regularization
303
+ for j in range(nb_sources):
304
+ Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())
305
+
306
+ # invert it
307
+ inv_Cxx = _invert(Cxx)
308
+
309
+ # separate the sources
310
+ for j in range(nb_sources):
311
+
312
+ # create a wiener gain for this source
313
+ gain = torch.zeros_like(inv_Cxx)
314
+
315
+ # computes multichannel Wiener gain as v_j R_j inv_Cxx
316
+ indices = torch.cartesian_prod(
317
+ torch.arange(nb_channels),
318
+ torch.arange(nb_channels),
319
+ torch.arange(nb_channels),
320
+ )
321
+ for index in indices:
322
+ gain[:, :, index[0], index[1], :] = _mul_add(
323
+ R[j][None, :, index[0], index[2], :].clone(),
324
+ inv_Cxx[:, :, index[2], index[1], :],
325
+ gain[:, :, index[0], index[1], :],
326
+ )
327
+ gain = gain * v[t, ..., None, None, None, j]
328
+
329
+ # apply it to the mixture
330
+ for i in range(nb_channels):
331
+ y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])
332
+
333
+ return y, v, R
334
+
335
+
336
+ def wiener(
337
+ targets_spectrograms: torch.Tensor,
338
+ mix_stft: torch.Tensor,
339
+ iterations: int = 1,
340
+ softmask: bool = False,
341
+ residual: bool = False,
342
+ scale_factor: float = 10.0,
343
+ eps: float = 1e-10,
344
+ ):
345
+ """Wiener-based separation for multichannel audio.
346
+
347
+ The method uses the (possibly multichannel) spectrograms of the
348
+ sources to separate the (complex) Short Term Fourier Transform of the
349
+ mix. Separation is done in a sequential way by:
350
+
351
+ * Getting an initial estimate. This can be done in two ways: either by
352
+ directly using the spectrograms with the mixture phase, or
353
+ by using a softmasking strategy. This initial phase is controlled
354
+ by the `softmask` flag.
355
+
356
+ * If required, adding an additional residual target as the mix minus
357
+ all targets.
358
+
359
+ * Refinining these initial estimates through a call to
360
+ :func:`expectation_maximization` if the number of iterations is nonzero.
361
+
362
+ This implementation also allows to specify the epsilon value used for
363
+ regularization. It is based on [1]_, [2]_, [3]_, [4]_.
364
+
365
+ References
366
+ ----------
367
+ .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
368
+ N. Takahashi and Y. Mitsufuji, "Improving music source separation based
369
+ on deep neural networks through data augmentation and network
370
+ blending." 2017 IEEE International Conference on Acoustics, Speech
371
+ and Signal Processing (ICASSP). IEEE, 2017.
372
+
373
+ .. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
374
+ separation with deep neural networks." IEEE/ACM Transactions on Audio,
375
+ Speech, and Language Processing 24.9 (2016): 1652-1664.
376
+
377
+ .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
378
+ separation with deep neural networks." 2016 24th European Signal
379
+ Processing Conference (EUSIPCO). IEEE, 2016.
380
+
381
+ .. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
382
+ source separation." IEEE Transactions on Signal Processing
383
+ 62.16 (2014): 4298-4310.
384
+
385
+ Args:
386
+ targets_spectrograms (Tensor): spectrograms of the sources
387
+ [shape=(nb_frames, nb_bins, nb_channels, nb_sources)].
388
+ This is a nonnegative tensor that is
389
+ usually the output of the actual separation method of the user. The
390
+ spectrograms may be mono, but they need to be 4-dimensional in all
391
+ cases.
392
+ mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
393
+ STFT of the mixture signal.
394
+ iterations (int): [scalar]
395
+ number of iterations for the EM algorithm
396
+ softmask (bool): Describes how the initial estimates are obtained.
397
+ * if `False`, then the mixture phase will directly be used with the
398
+ spectrogram as initial estimates.
399
+ * if `True`, initial estimates are obtained by multiplying the
400
+ complex mix element-wise with the ratio of each target spectrogram
401
+ with the sum of them all. This strategy is better if the model are
402
+ not really good, and worse otherwise.
403
+ residual (bool): if `True`, an additional target is created, which is
404
+ equal to the mixture minus the other targets, before application of
405
+ expectation maximization
406
+ eps (float): Epsilon value to use for computing the separations.
407
+ This is used whenever division with a model energy is
408
+ performed, i.e. when softmasking and when iterating the EM.
409
+ It can be understood as the energy of the additional white noise
410
+ that is taken out when separating.
411
+
412
+ Returns:
413
+ Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
414
+ STFT of estimated sources
415
+
416
+ Notes:
417
+ * Be careful that you need *magnitude spectrogram estimates* for the
418
+ case `softmask==False`.
419
+ * `softmask=False` is recommended
420
+ * The epsilon value will have a huge impact on performance. If it's
421
+ large, only the parts of the signal with a significant energy will
422
+ be kept in the sources. This epsilon then directly controls the
423
+ energy of the reconstruction error.
424
+
425
+ Warning:
426
+ As in :func:`expectation_maximization`, we recommend converting the
427
+ mixture `x` to double precision `torch.float64` *before* calling
428
+ :func:`wiener`.
429
+ """
430
+ if softmask:
431
+ # if we use softmask, we compute the ratio mask for all targets and
432
+ # multiply by the mix stft
433
+ y = (
434
+ mix_stft[..., None]
435
+ * (
436
+ targets_spectrograms
437
+ / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype))
438
+ )[..., None, :]
439
+ )
440
+ else:
441
+ # otherwise, we just multiply the targets spectrograms with mix phase
442
+ # we tacitly assume that we have magnitude estimates.
443
+ angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
444
+ nb_sources = targets_spectrograms.shape[-1]
445
+ y = torch.zeros(
446
+ mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device
447
+ )
448
+ y[..., 0, :] = targets_spectrograms * torch.cos(angle)
449
+ y[..., 1, :] = targets_spectrograms * torch.sin(angle)
450
+
451
+ if residual:
452
+ # if required, adding an additional target as the mix minus
453
+ # available targets
454
+ y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
455
+
456
+ if iterations == 0:
457
+ return y
458
+
459
+ # we need to refine the estimates. Scales down the estimates for
460
+ # numerical stability
461
+ max_abs = torch.max(
462
+ torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device),
463
+ torch.sqrt(_norm(mix_stft)).max() / scale_factor,
464
+ )
465
+
466
+ mix_stft = mix_stft / max_abs
467
+ y = y / max_abs
468
+
469
+ # call expectation maximization
470
+ y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
471
+
472
+ # scale estimates up again
473
+ y = y * max_abs
474
+ return y
475
+
476
+
477
+ def _covariance(y_j):
478
+ """
479
+ Compute the empirical covariance for a source.
480
+
481
+ Args:
482
+ y_j (Tensor): complex stft of the source.
483
+ [shape=(nb_frames, nb_bins, nb_channels, 2)].
484
+
485
+ Returns:
486
+ Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
487
+ just y_j * conj(y_j.T): empirical covariance for each TF bin.
488
+ """
489
+ (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
490
+ Cj = torch.zeros(
491
+ (nb_frames, nb_bins, nb_channels, nb_channels, 2),
492
+ dtype=y_j.dtype,
493
+ device=y_j.device,
494
+ )
495
+ indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
496
+ for index in indices:
497
+ Cj[:, :, index[0], index[1], :] = _mul_add(
498
+ y_j[:, :, index[0], :],
499
+ _conj(y_j[:, :, index[1], :]),
500
+ Cj[:, :, index[0], index[1], :],
501
+ )
502
+ return Cj
demucs/hdemucs.py ADDED
@@ -0,0 +1,782 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ This code contains the spectrogram and Hybrid version of Demucs.
8
+ """
9
+ from copy import deepcopy
10
+ import math
11
+ import typing as tp
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+ from .filtering import wiener
16
+ from .demucs import DConv, rescale_module
17
+ from .states import capture_init
18
+ from .spec import spectro, ispectro
19
+
20
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
21
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
22
+ If this is the case, we insert extra 0 padding to the right before the reflection happen."""
23
+ x0 = x
24
+ length = x.shape[-1]
25
+ padding_left, padding_right = paddings
26
+ if mode == 'reflect':
27
+ max_pad = max(padding_left, padding_right)
28
+ if length <= max_pad:
29
+ extra_pad = max_pad - length + 1
30
+ extra_pad_right = min(padding_right, extra_pad)
31
+ extra_pad_left = extra_pad - extra_pad_right
32
+ paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
33
+ x = F.pad(x, (extra_pad_left, extra_pad_right))
34
+ out = F.pad(x, paddings, mode, value)
35
+ assert out.shape[-1] == length + padding_left + padding_right
36
+ assert (out[..., padding_left: padding_left + length] == x0).all()
37
+ return out
38
+
39
+ class ScaledEmbedding(nn.Module):
40
+ """
41
+ Boost learning rate for embeddings (with `scale`).
42
+ Also, can make embeddings continuous with `smooth`.
43
+ """
44
+ def __init__(self, num_embeddings: int, embedding_dim: int,
45
+ scale: float = 10., smooth=False):
46
+ super().__init__()
47
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
48
+ if smooth:
49
+ weight = torch.cumsum(self.embedding.weight.data, dim=0)
50
+ # when summing gaussian, overscale raises as sqrt(n), so we nornalize by that.
51
+ weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
52
+ self.embedding.weight.data[:] = weight
53
+ self.embedding.weight.data /= scale
54
+ self.scale = scale
55
+
56
+ @property
57
+ def weight(self):
58
+ return self.embedding.weight * self.scale
59
+
60
+ def forward(self, x):
61
+ out = self.embedding(x) * self.scale
62
+ return out
63
+
64
+
65
+ class HEncLayer(nn.Module):
66
+ def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
67
+ freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
68
+ rewrite=True):
69
+ """Encoder layer. This used both by the time and the frequency branch.
70
+
71
+ Args:
72
+ chin: number of input channels.
73
+ chout: number of output channels.
74
+ norm_groups: number of groups for group norm.
75
+ empty: used to make a layer with just the first conv. this is used
76
+ before merging the time and freq. branches.
77
+ freq: this is acting on frequencies.
78
+ dconv: insert DConv residual branches.
79
+ norm: use GroupNorm.
80
+ context: context size for the 1x1 conv.
81
+ dconv_kw: list of kwargs for the DConv class.
82
+ pad: pad the input. Padding is done so that the output size is
83
+ always the input size / stride.
84
+ rewrite: add 1x1 conv at the end of the layer.
85
+ """
86
+ super().__init__()
87
+ norm_fn = lambda d: nn.Identity() # noqa
88
+ if norm:
89
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
90
+ if pad:
91
+ pad = kernel_size // 4
92
+ else:
93
+ pad = 0
94
+ klass = nn.Conv1d
95
+ self.freq = freq
96
+ self.kernel_size = kernel_size
97
+ self.stride = stride
98
+ self.empty = empty
99
+ self.norm = norm
100
+ self.pad = pad
101
+ if freq:
102
+ kernel_size = [kernel_size, 1]
103
+ stride = [stride, 1]
104
+ pad = [pad, 0]
105
+ klass = nn.Conv2d
106
+ self.conv = klass(chin, chout, kernel_size, stride, pad)
107
+ if self.empty:
108
+ return
109
+ self.norm1 = norm_fn(chout)
110
+ self.rewrite = None
111
+ if rewrite:
112
+ self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
113
+ self.norm2 = norm_fn(2 * chout)
114
+
115
+ self.dconv = None
116
+ if dconv:
117
+ self.dconv = DConv(chout, **dconv_kw)
118
+
119
+ def forward(self, x, inject=None):
120
+ """
121
+ `inject` is used to inject the result from the time branch into the frequency branch,
122
+ when both have the same stride.
123
+ """
124
+ if not self.freq and x.dim() == 4:
125
+ B, C, Fr, T = x.shape
126
+ x = x.view(B, -1, T)
127
+
128
+ if not self.freq:
129
+ le = x.shape[-1]
130
+ if not le % self.stride == 0:
131
+ x = F.pad(x, (0, self.stride - (le % self.stride)))
132
+ y = self.conv(x)
133
+ if self.empty:
134
+ return y
135
+ if inject is not None:
136
+ assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
137
+ if inject.dim() == 3 and y.dim() == 4:
138
+ inject = inject[:, :, None]
139
+ y = y + inject
140
+ y = F.gelu(self.norm1(y))
141
+ if self.dconv:
142
+ if self.freq:
143
+ B, C, Fr, T = y.shape
144
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
145
+ y = self.dconv(y)
146
+ if self.freq:
147
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
148
+ if self.rewrite:
149
+ z = self.norm2(self.rewrite(y))
150
+ z = F.glu(z, dim=1)
151
+ else:
152
+ z = y
153
+ return z
154
+
155
+
156
+ class MultiWrap(nn.Module):
157
+ """
158
+ Takes one layer and replicate it N times. each replica will act
159
+ on a frequency band. All is done so that if the N replica have the same weights,
160
+ then this is exactly equivalent to applying the original module on all frequencies.
161
+
162
+ This is a bit over-engineered to avoid edge artifacts when splitting
163
+ the frequency bands, but it is possible the naive implementation would work as well...
164
+ """
165
+ def __init__(self, layer, split_ratios):
166
+ """
167
+ Args:
168
+ layer: module to clone, must be either HEncLayer or HDecLayer.
169
+ split_ratios: list of float indicating which ratio to keep for each band.
170
+ """
171
+ super().__init__()
172
+ self.split_ratios = split_ratios
173
+ self.layers = nn.ModuleList()
174
+ self.conv = isinstance(layer, HEncLayer)
175
+ assert not layer.norm
176
+ assert layer.freq
177
+ assert layer.pad
178
+ if not self.conv:
179
+ assert not layer.context_freq
180
+ for k in range(len(split_ratios) + 1):
181
+ lay = deepcopy(layer)
182
+ if self.conv:
183
+ lay.conv.padding = (0, 0)
184
+ else:
185
+ lay.pad = False
186
+ for m in lay.modules():
187
+ if hasattr(m, 'reset_parameters'):
188
+ m.reset_parameters()
189
+ self.layers.append(lay)
190
+
191
+ def forward(self, x, skip=None, length=None):
192
+ B, C, Fr, T = x.shape
193
+
194
+ ratios = list(self.split_ratios) + [1]
195
+ start = 0
196
+ outs = []
197
+ for ratio, layer in zip(ratios, self.layers):
198
+ if self.conv:
199
+ pad = layer.kernel_size // 4
200
+ if ratio == 1:
201
+ limit = Fr
202
+ frames = -1
203
+ else:
204
+ limit = int(round(Fr * ratio))
205
+ le = limit - start
206
+ if start == 0:
207
+ le += pad
208
+ frames = round((le - layer.kernel_size) / layer.stride + 1)
209
+ limit = start + (frames - 1) * layer.stride + layer.kernel_size
210
+ if start == 0:
211
+ limit -= pad
212
+ assert limit - start > 0, (limit, start)
213
+ assert limit <= Fr, (limit, Fr)
214
+ y = x[:, :, start:limit, :]
215
+ if start == 0:
216
+ y = F.pad(y, (0, 0, pad, 0))
217
+ if ratio == 1:
218
+ y = F.pad(y, (0, 0, 0, pad))
219
+ outs.append(layer(y))
220
+ start = limit - layer.kernel_size + layer.stride
221
+ else:
222
+ if ratio == 1:
223
+ limit = Fr
224
+ else:
225
+ limit = int(round(Fr * ratio))
226
+ last = layer.last
227
+ layer.last = True
228
+
229
+ y = x[:, :, start:limit]
230
+ s = skip[:, :, start:limit]
231
+ out, _ = layer(y, s, None)
232
+ if outs:
233
+ outs[-1][:, :, -layer.stride:] += (
234
+ out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
235
+ out = out[:, :, layer.stride:]
236
+ if ratio == 1:
237
+ out = out[:, :, :-layer.stride // 2, :]
238
+ if start == 0:
239
+ out = out[:, :, layer.stride // 2:, :]
240
+ outs.append(out)
241
+ layer.last = last
242
+ start = limit
243
+ out = torch.cat(outs, dim=2)
244
+ if not self.conv and not last:
245
+ out = F.gelu(out)
246
+ if self.conv:
247
+ return out
248
+ else:
249
+ return out, None
250
+
251
+
252
+ class HDecLayer(nn.Module):
253
+ def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
254
+ freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
255
+ context_freq=True, rewrite=True):
256
+ """
257
+ Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
258
+ """
259
+ super().__init__()
260
+ norm_fn = lambda d: nn.Identity() # noqa
261
+ if norm:
262
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
263
+ if pad:
264
+ pad = kernel_size // 4
265
+ else:
266
+ pad = 0
267
+ self.pad = pad
268
+ self.last = last
269
+ self.freq = freq
270
+ self.chin = chin
271
+ self.empty = empty
272
+ self.stride = stride
273
+ self.kernel_size = kernel_size
274
+ self.norm = norm
275
+ self.context_freq = context_freq
276
+ klass = nn.Conv1d
277
+ klass_tr = nn.ConvTranspose1d
278
+ if freq:
279
+ kernel_size = [kernel_size, 1]
280
+ stride = [stride, 1]
281
+ klass = nn.Conv2d
282
+ klass_tr = nn.ConvTranspose2d
283
+ self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
284
+ self.norm2 = norm_fn(chout)
285
+ if self.empty:
286
+ return
287
+ self.rewrite = None
288
+ if rewrite:
289
+ if context_freq:
290
+ self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
291
+ else:
292
+ self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
293
+ [0, context])
294
+ self.norm1 = norm_fn(2 * chin)
295
+
296
+ self.dconv = None
297
+ if dconv:
298
+ self.dconv = DConv(chin, **dconv_kw)
299
+
300
+ def forward(self, x, skip, length):
301
+ if self.freq and x.dim() == 3:
302
+ B, C, T = x.shape
303
+ x = x.view(B, self.chin, -1, T)
304
+
305
+ if not self.empty:
306
+ x = x + skip
307
+
308
+ if self.rewrite:
309
+ y = F.glu(self.norm1(self.rewrite(x)), dim=1)
310
+ else:
311
+ y = x
312
+ if self.dconv:
313
+ if self.freq:
314
+ B, C, Fr, T = y.shape
315
+ y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
316
+ y = self.dconv(y)
317
+ if self.freq:
318
+ y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
319
+ else:
320
+ y = x
321
+ assert skip is None
322
+ z = self.norm2(self.conv_tr(y))
323
+ if self.freq:
324
+ if self.pad:
325
+ z = z[..., self.pad:-self.pad, :]
326
+ else:
327
+ z = z[..., self.pad:self.pad + length]
328
+ assert z.shape[-1] == length, (z.shape[-1], length)
329
+ if not self.last:
330
+ z = F.gelu(z)
331
+ return z, y
332
+
333
+
334
+ class HDemucs(nn.Module):
335
+ """
336
+ Spectrogram and hybrid Demucs model.
337
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
338
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
339
+ Frequency layers can still access information across time steps thanks to the DConv residual.
340
+
341
+ Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
342
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
343
+
344
+ Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
345
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
346
+ Open Unmix implementation [Stoter et al. 2019].
347
+
348
+ The loss is always on the temporal domain, by backpropagating through the above
349
+ output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
350
+ a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
351
+ contribution, without changing the one from the waveform, which will lead to worse performance.
352
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
353
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
354
+ hybrid models.
355
+
356
+ This model also uses frequency embeddings are used to improve efficiency on convolutions
357
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
358
+
359
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
360
+ """
361
+ @capture_init
362
+ def __init__(self,
363
+ sources,
364
+ # Channels
365
+ audio_channels=2,
366
+ channels=48,
367
+ channels_time=None,
368
+ growth=2,
369
+ # STFT
370
+ nfft=4096,
371
+ wiener_iters=0,
372
+ end_iters=0,
373
+ wiener_residual=False,
374
+ cac=True,
375
+ # Main structure
376
+ depth=6,
377
+ rewrite=True,
378
+ hybrid=True,
379
+ hybrid_old=False,
380
+ # Frequency branch
381
+ multi_freqs=None,
382
+ multi_freqs_depth=2,
383
+ freq_emb=0.2,
384
+ emb_scale=10,
385
+ emb_smooth=True,
386
+ # Convolutions
387
+ kernel_size=8,
388
+ time_stride=2,
389
+ stride=4,
390
+ context=1,
391
+ context_enc=0,
392
+ # Normalization
393
+ norm_starts=4,
394
+ norm_groups=4,
395
+ # DConv residual branch
396
+ dconv_mode=1,
397
+ dconv_depth=2,
398
+ dconv_comp=4,
399
+ dconv_attn=4,
400
+ dconv_lstm=4,
401
+ dconv_init=1e-4,
402
+ # Weight init
403
+ rescale=0.1,
404
+ # Metadata
405
+ samplerate=44100,
406
+ segment=4 * 10):
407
+
408
+ """
409
+ Args:
410
+ sources (list[str]): list of source names.
411
+ audio_channels (int): input/output audio channels.
412
+ channels (int): initial number of hidden channels.
413
+ channels_time: if not None, use a different `channels` value for the time branch.
414
+ growth: increase the number of hidden channels by this factor at each layer.
415
+ nfft: number of fft bins. Note that changing this require careful computation of
416
+ various shape parameters and will not work out of the box for hybrid models.
417
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
418
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
419
+ wiener_residual: add residual source before wiener filtering.
420
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
421
+ in input and output. no further processing is done before ISTFT.
422
+ depth (int): number of layers in the encoder and in the decoder.
423
+ rewrite (bool): add 1x1 convolution to each layer.
424
+ hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
425
+ hybrid_old: some models trained for MDX had a padding bug. This replicates
426
+ this bug to avoid retraining them.
427
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
428
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
429
+ layers will be wrapped.
430
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
431
+ the actual value controls the weight of the embedding.
432
+ emb_scale: equivalent to scaling the embedding learning rate
433
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
434
+ kernel_size: kernel_size for encoder and decoder layers.
435
+ stride: stride for encoder and decoder layers.
436
+ time_stride: stride for the final time layer, after the merge.
437
+ context: context for 1x1 conv in the decoder.
438
+ context_enc: context for 1x1 conv in the encoder.
439
+ norm_starts: layer at which group norm starts being used.
440
+ decoder layers are numbered in reverse order.
441
+ norm_groups: number of groups for group norm.
442
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
443
+ dconv_depth: depth of residual DConv branch.
444
+ dconv_comp: compression of DConv branch.
445
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
446
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
447
+ dconv_init: initial scale for the DConv branch LayerScale.
448
+ rescale: weight recaling trick
449
+
450
+ """
451
+ super().__init__()
452
+
453
+ self.cac = cac
454
+ self.wiener_residual = wiener_residual
455
+ self.audio_channels = audio_channels
456
+ self.sources = sources
457
+ self.kernel_size = kernel_size
458
+ self.context = context
459
+ self.stride = stride
460
+ self.depth = depth
461
+ self.channels = channels
462
+ self.samplerate = samplerate
463
+ self.segment = segment
464
+
465
+ self.nfft = nfft
466
+ self.hop_length = nfft // 4
467
+ self.wiener_iters = wiener_iters
468
+ self.end_iters = end_iters
469
+ self.freq_emb = None
470
+ self.hybrid = hybrid
471
+ self.hybrid_old = hybrid_old
472
+ if hybrid_old:
473
+ assert hybrid, "hybrid_old must come with hybrid=True"
474
+ if hybrid:
475
+ assert wiener_iters == end_iters
476
+
477
+ self.encoder = nn.ModuleList()
478
+ self.decoder = nn.ModuleList()
479
+
480
+ if hybrid:
481
+ self.tencoder = nn.ModuleList()
482
+ self.tdecoder = nn.ModuleList()
483
+
484
+ chin = audio_channels
485
+ chin_z = chin # number of channels for the freq branch
486
+ if self.cac:
487
+ chin_z *= 2
488
+ chout = channels_time or channels
489
+ chout_z = channels
490
+ freqs = nfft // 2
491
+
492
+ for index in range(depth):
493
+ lstm = index >= dconv_lstm
494
+ attn = index >= dconv_attn
495
+ norm = index >= norm_starts
496
+ freq = freqs > 1
497
+ stri = stride
498
+ ker = kernel_size
499
+ if not freq:
500
+ assert freqs == 1
501
+ ker = time_stride * 2
502
+ stri = time_stride
503
+
504
+ pad = True
505
+ last_freq = False
506
+ if freq and freqs <= kernel_size:
507
+ ker = freqs
508
+ pad = False
509
+ last_freq = True
510
+
511
+ kw = {
512
+ 'kernel_size': ker,
513
+ 'stride': stri,
514
+ 'freq': freq,
515
+ 'pad': pad,
516
+ 'norm': norm,
517
+ 'rewrite': rewrite,
518
+ 'norm_groups': norm_groups,
519
+ 'dconv_kw': {
520
+ 'lstm': lstm,
521
+ 'attn': attn,
522
+ 'depth': dconv_depth,
523
+ 'compress': dconv_comp,
524
+ 'init': dconv_init,
525
+ 'gelu': True,
526
+ }
527
+ }
528
+ kwt = dict(kw)
529
+ kwt['freq'] = 0
530
+ kwt['kernel_size'] = kernel_size
531
+ kwt['stride'] = stride
532
+ kwt['pad'] = True
533
+ kw_dec = dict(kw)
534
+ multi = False
535
+ if multi_freqs and index < multi_freqs_depth:
536
+ multi = True
537
+ kw_dec['context_freq'] = False
538
+
539
+ if last_freq:
540
+ chout_z = max(chout, chout_z)
541
+ chout = chout_z
542
+
543
+ enc = HEncLayer(chin_z, chout_z,
544
+ dconv=dconv_mode & 1, context=context_enc, **kw)
545
+ if hybrid and freq:
546
+ tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
547
+ empty=last_freq, **kwt)
548
+ self.tencoder.append(tenc)
549
+
550
+ if multi:
551
+ enc = MultiWrap(enc, multi_freqs)
552
+ self.encoder.append(enc)
553
+ if index == 0:
554
+ chin = self.audio_channels * len(self.sources)
555
+ chin_z = chin
556
+ if self.cac:
557
+ chin_z *= 2
558
+ dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
559
+ last=index == 0, context=context, **kw_dec)
560
+ if multi:
561
+ dec = MultiWrap(dec, multi_freqs)
562
+ if hybrid and freq:
563
+ tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
564
+ last=index == 0, context=context, **kwt)
565
+ self.tdecoder.insert(0, tdec)
566
+ self.decoder.insert(0, dec)
567
+
568
+ chin = chout
569
+ chin_z = chout_z
570
+ chout = int(growth * chout)
571
+ chout_z = int(growth * chout_z)
572
+ if freq:
573
+ if freqs <= kernel_size:
574
+ freqs = 1
575
+ else:
576
+ freqs //= stride
577
+ if index == 0 and freq_emb:
578
+ self.freq_emb = ScaledEmbedding(
579
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
580
+ self.freq_emb_scale = freq_emb
581
+
582
+ if rescale:
583
+ rescale_module(self, reference=rescale)
584
+
585
+ def _spec(self, x):
586
+ hl = self.hop_length
587
+ nfft = self.nfft
588
+ x0 = x # noqa
589
+
590
+ if self.hybrid:
591
+ # We re-pad the signal in order to keep the property
592
+ # that the size of the output is exactly the size of the input
593
+ # divided by the stride (here hop_length), when divisible.
594
+ # This is achieved by padding by 1/4th of the kernel size (here nfft).
595
+ # which is not supported by torch.stft.
596
+ # Having all convolution operations follow this convention allow to easily
597
+ # align the time and frequency branches later on.
598
+ assert hl == nfft // 4
599
+ le = int(math.ceil(x.shape[-1] / hl))
600
+ pad = hl // 2 * 3
601
+ if not self.hybrid_old:
602
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
603
+ else:
604
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))
605
+
606
+ z = spectro(x, nfft, hl)[..., :-1, :]
607
+ if self.hybrid:
608
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
609
+ z = z[..., 2:2+le]
610
+ return z
611
+
612
+ def _ispec(self, z, length=None, scale=0):
613
+ hl = self.hop_length // (4 ** scale)
614
+ z = F.pad(z, (0, 0, 0, 1))
615
+ if self.hybrid:
616
+ z = F.pad(z, (2, 2))
617
+ pad = hl // 2 * 3
618
+ if not self.hybrid_old:
619
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
620
+ else:
621
+ le = hl * int(math.ceil(length / hl))
622
+ x = ispectro(z, hl, length=le)
623
+ if not self.hybrid_old:
624
+ x = x[..., pad:pad + length]
625
+ else:
626
+ x = x[..., :length]
627
+ else:
628
+ x = ispectro(z, hl, length)
629
+ return x
630
+
631
+ def _magnitude(self, z):
632
+ # return the magnitude of the spectrogram, except when cac is True,
633
+ # in which case we just move the complex dimension to the channel one.
634
+ if self.cac:
635
+ B, C, Fr, T = z.shape
636
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
637
+ m = m.reshape(B, C * 2, Fr, T)
638
+ else:
639
+ m = z.abs()
640
+ return m
641
+
642
+ def _mask(self, z, m):
643
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
644
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
645
+ niters = self.wiener_iters
646
+ if self.cac:
647
+ B, S, C, Fr, T = m.shape
648
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
649
+ out = torch.view_as_complex(out.contiguous())
650
+ return out
651
+ if self.training:
652
+ niters = self.end_iters
653
+ if niters < 0:
654
+ z = z[:, None]
655
+ return z / (1e-8 + z.abs()) * m
656
+ else:
657
+ return self._wiener(m, z, niters)
658
+
659
+ def _wiener(self, mag_out, mix_stft, niters):
660
+ # apply wiener filtering from OpenUnmix.
661
+ init = mix_stft.dtype
662
+ wiener_win_len = 300
663
+ residual = self.wiener_residual
664
+
665
+ B, S, C, Fq, T = mag_out.shape
666
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
667
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
668
+
669
+ outs = []
670
+ for sample in range(B):
671
+ pos = 0
672
+ out = []
673
+ for pos in range(0, T, wiener_win_len):
674
+ frame = slice(pos, pos + wiener_win_len)
675
+ z_out = wiener(
676
+ mag_out[sample, frame], mix_stft[sample, frame], niters,
677
+ residual=residual)
678
+ out.append(z_out.transpose(-1, -2))
679
+ outs.append(torch.cat(out, dim=0))
680
+ out = torch.view_as_complex(torch.stack(outs, 0))
681
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
682
+ if residual:
683
+ out = out[:, :-1]
684
+ assert list(out.shape) == [B, S, C, Fq, T]
685
+ return out.to(init)
686
+
687
+ def forward(self, mix):
688
+ x = mix
689
+ length = x.shape[-1]
690
+
691
+ z = self._spec(mix)
692
+ mag = self._magnitude(z)
693
+ x = mag
694
+
695
+ B, C, Fq, T = x.shape
696
+
697
+ # unlike previous Demucs, we always normalize because it is easier.
698
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
699
+ std = x.std(dim=(1, 2, 3), keepdim=True)
700
+ x = (x - mean) / (1e-5 + std)
701
+ # x will be the freq. branch input.
702
+
703
+ if self.hybrid:
704
+ # Prepare the time branch input.
705
+ xt = mix
706
+ meant = xt.mean(dim=(1, 2), keepdim=True)
707
+ stdt = xt.std(dim=(1, 2), keepdim=True)
708
+ xt = (xt - meant) / (1e-5 + stdt)
709
+
710
+ # okay, this is a giant mess I know...
711
+ saved = [] # skip connections, freq.
712
+ saved_t = [] # skip connections, time.
713
+ lengths = [] # saved lengths to properly remove padding, freq branch.
714
+ lengths_t = [] # saved lengths for time branch.
715
+ for idx, encode in enumerate(self.encoder):
716
+ lengths.append(x.shape[-1])
717
+ inject = None
718
+ if self.hybrid and idx < len(self.tencoder):
719
+ # we have not yet merged branches.
720
+ lengths_t.append(xt.shape[-1])
721
+ tenc = self.tencoder[idx]
722
+ xt = tenc(xt)
723
+ if not tenc.empty:
724
+ # save for skip connection
725
+ saved_t.append(xt)
726
+ else:
727
+ # tenc contains just the first conv., so that now time and freq.
728
+ # branches have the same shape and can be merged.
729
+ inject = xt
730
+ x = encode(x, inject)
731
+ if idx == 0 and self.freq_emb is not None:
732
+ # add frequency embedding to allow for non equivariant convolutions
733
+ # over the frequency axis.
734
+ frs = torch.arange(x.shape[-2], device=x.device)
735
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
736
+ x = x + self.freq_emb_scale * emb
737
+
738
+ saved.append(x)
739
+
740
+ x = torch.zeros_like(x)
741
+ if self.hybrid:
742
+ xt = torch.zeros_like(x)
743
+ # initialize everything to zero (signal will go through u-net skips).
744
+
745
+ for idx, decode in enumerate(self.decoder):
746
+ skip = saved.pop(-1)
747
+ x, pre = decode(x, skip, lengths.pop(-1))
748
+ # `pre` contains the output just before final transposed convolution,
749
+ # which is used when the freq. and time branch separate.
750
+
751
+ if self.hybrid:
752
+ offset = self.depth - len(self.tdecoder)
753
+ if self.hybrid and idx >= offset:
754
+ tdec = self.tdecoder[idx - offset]
755
+ length_t = lengths_t.pop(-1)
756
+ if tdec.empty:
757
+ assert pre.shape[2] == 1, pre.shape
758
+ pre = pre[:, :, 0]
759
+ xt, _ = tdec(pre, None, length_t)
760
+ else:
761
+ skip = saved_t.pop(-1)
762
+ xt, _ = tdec(xt, skip, length_t)
763
+
764
+ # Let's make sure we used all stored skip connections.
765
+ assert len(saved) == 0
766
+ assert len(lengths_t) == 0
767
+ assert len(saved_t) == 0
768
+
769
+ S = len(self.sources)
770
+ x = x.view(B, S, -1, Fq, T)
771
+ x = x * std[:, None] + mean[:, None]
772
+
773
+ zout = self._mask(z, x)
774
+ x = self._ispec(zout, length)
775
+
776
+ if self.hybrid:
777
+ xt = xt.view(B, S, -1, length)
778
+ xt = xt * stdt[:, None] + meant[:, None]
779
+ x = xt + x
780
+ return x
781
+
782
+
demucs/htdemucs.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+ """
8
+ This code contains the spectrogram and Hybrid version of Demucs.
9
+ """
10
+ import math
11
+
12
+ from .filtering import wiener
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from fractions import Fraction
17
+ from einops import rearrange
18
+
19
+ from .transformer import CrossTransformerEncoder
20
+
21
+ from .demucs import rescale_module
22
+ from .states import capture_init
23
+ from .spec import spectro, ispectro
24
+ from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
25
+
26
+
27
+ class HTDemucs(nn.Module):
28
+ """
29
+ Spectrogram and hybrid Demucs model.
30
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
31
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
32
+ Frequency layers can still access information across time steps thanks to the DConv residual.
33
+
34
+ Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
35
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
36
+
37
+ Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
38
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
39
+ Open Unmix implementation [Stoter et al. 2019].
40
+
41
+ The loss is always on the temporal domain, by backpropagating through the above
42
+ output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
43
+ a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
44
+ contribution, without changing the one from the waveform, which will lead to worse performance.
45
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
46
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
47
+ hybrid models.
48
+
49
+ This model also uses frequency embeddings are used to improve efficiency on convolutions
50
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
51
+
52
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
53
+ """
54
+
55
+ @capture_init
56
+ def __init__(
57
+ self,
58
+ sources,
59
+ # Channels
60
+ audio_channels=2,
61
+ channels=48,
62
+ channels_time=None,
63
+ growth=2,
64
+ # STFT
65
+ nfft=4096,
66
+ wiener_iters=0,
67
+ end_iters=0,
68
+ wiener_residual=False,
69
+ cac=True,
70
+ # Main structure
71
+ depth=4,
72
+ rewrite=True,
73
+ # Frequency branch
74
+ multi_freqs=None,
75
+ multi_freqs_depth=3,
76
+ freq_emb=0.2,
77
+ emb_scale=10,
78
+ emb_smooth=True,
79
+ # Convolutions
80
+ kernel_size=8,
81
+ time_stride=2,
82
+ stride=4,
83
+ context=1,
84
+ context_enc=0,
85
+ # Normalization
86
+ norm_starts=4,
87
+ norm_groups=4,
88
+ # DConv residual branch
89
+ dconv_mode=1,
90
+ dconv_depth=2,
91
+ dconv_comp=8,
92
+ dconv_init=1e-3,
93
+ # Before the Transformer
94
+ bottom_channels=0,
95
+ # Transformer
96
+ t_layers=5,
97
+ t_emb="sin",
98
+ t_hidden_scale=4.0,
99
+ t_heads=8,
100
+ t_dropout=0.0,
101
+ t_max_positions=10000,
102
+ t_norm_in=True,
103
+ t_norm_in_group=False,
104
+ t_group_norm=False,
105
+ t_norm_first=True,
106
+ t_norm_out=True,
107
+ t_max_period=10000.0,
108
+ t_weight_decay=0.0,
109
+ t_lr=None,
110
+ t_layer_scale=True,
111
+ t_gelu=True,
112
+ t_weight_pos_embed=1.0,
113
+ t_sin_random_shift=0,
114
+ t_cape_mean_normalize=True,
115
+ t_cape_augment=True,
116
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
117
+ t_sparse_self_attn=False,
118
+ t_sparse_cross_attn=False,
119
+ t_mask_type="diag",
120
+ t_mask_random_seed=42,
121
+ t_sparse_attn_window=500,
122
+ t_global_window=100,
123
+ t_sparsity=0.95,
124
+ t_auto_sparsity=False,
125
+ # ------ Particuliar parameters
126
+ t_cross_first=False,
127
+ # Weight init
128
+ rescale=0.1,
129
+ # Metadata
130
+ samplerate=44100,
131
+ segment=10,
132
+ use_train_segment=True,
133
+ ):
134
+ """
135
+ Args:
136
+ sources (list[str]): list of source names.
137
+ audio_channels (int): input/output audio channels.
138
+ channels (int): initial number of hidden channels.
139
+ channels_time: if not None, use a different `channels` value for the time branch.
140
+ growth: increase the number of hidden channels by this factor at each layer.
141
+ nfft: number of fft bins. Note that changing this require careful computation of
142
+ various shape parameters and will not work out of the box for hybrid models.
143
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
144
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
145
+ wiener_residual: add residual source before wiener filtering.
146
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
147
+ in input and output. no further processing is done before ISTFT.
148
+ depth (int): number of layers in the encoder and in the decoder.
149
+ rewrite (bool): add 1x1 convolution to each layer.
150
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
151
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
152
+ layers will be wrapped.
153
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
154
+ the actual value controls the weight of the embedding.
155
+ emb_scale: equivalent to scaling the embedding learning rate
156
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
157
+ kernel_size: kernel_size for encoder and decoder layers.
158
+ stride: stride for encoder and decoder layers.
159
+ time_stride: stride for the final time layer, after the merge.
160
+ context: context for 1x1 conv in the decoder.
161
+ context_enc: context for 1x1 conv in the encoder.
162
+ norm_starts: layer at which group norm starts being used.
163
+ decoder layers are numbered in reverse order.
164
+ norm_groups: number of groups for group norm.
165
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
166
+ dconv_depth: depth of residual DConv branch.
167
+ dconv_comp: compression of DConv branch.
168
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
169
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
170
+ dconv_init: initial scale for the DConv branch LayerScale.
171
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
172
+ transformer in order to change the number of channels
173
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
174
+ t_emb: "sin", "cape" or "scaled"
175
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
176
+ for instance if C = 384 (the number of channels in the transformer) and
177
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
178
+ 384 * 4 = 1536
179
+ t_heads: number of heads for the transformer
180
+ t_dropout: dropout in the transformer
181
+ t_max_positions: max_positions for the "scaled" positional embedding, only
182
+ useful if t_emb="scaled"
183
+ t_norm_in: (bool) norm before addinf positional embedding and getting into the
184
+ transformer layers
185
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
186
+ timesteps (GroupNorm with group=1)
187
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
188
+ timesteps (GroupNorm with group=1)
189
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
190
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
191
+ t_max_period: (float) denominator in the sinusoidal embedding expression
192
+ t_weight_decay: (float) weight decay for the transformer
193
+ t_lr: (float) specific learning rate for the transformer
194
+ t_layer_scale: (bool) Layer Scale for the transformer
195
+ t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
196
+ t_weight_pos_embed: (float) weighting of the positional embedding
197
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
198
+ see: https://arxiv.org/abs/2106.03143
199
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
200
+ during the inference, see: https://arxiv.org/abs/2106.03143
201
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
202
+ see: https://arxiv.org/abs/2106.03143
203
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
204
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
205
+ unless you designed really specific masks)
206
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
207
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
208
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
209
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
210
+ that generated the random part of the mask
211
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
212
+ a key (j), the mask is True id |i-j|<=t_sparse_attn_window
213
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
214
+ and mask[:, :t_global_window] will be True
215
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
216
+ level of the random part of the mask.
217
+ t_cross_first: (bool) if True cross attention is the first layer of the
218
+ transformer (False seems to be better)
219
+ rescale: weight rescaling trick
220
+ use_train_segment: (bool) if True, the actual size that is used during the
221
+ training is used during inference.
222
+ """
223
+ super().__init__()
224
+ self.cac = cac
225
+ self.wiener_residual = wiener_residual
226
+ self.audio_channels = audio_channels
227
+ self.sources = sources
228
+ self.kernel_size = kernel_size
229
+ self.context = context
230
+ self.stride = stride
231
+ self.depth = depth
232
+ self.bottom_channels = bottom_channels
233
+ self.channels = channels
234
+ self.samplerate = samplerate
235
+ self.segment = segment
236
+ self.use_train_segment = use_train_segment
237
+ self.nfft = nfft
238
+ self.hop_length = nfft // 4
239
+ self.wiener_iters = wiener_iters
240
+ self.end_iters = end_iters
241
+ self.freq_emb = None
242
+ assert wiener_iters == end_iters
243
+
244
+ self.encoder = nn.ModuleList()
245
+ self.decoder = nn.ModuleList()
246
+
247
+ self.tencoder = nn.ModuleList()
248
+ self.tdecoder = nn.ModuleList()
249
+
250
+ chin = audio_channels
251
+ chin_z = chin # number of channels for the freq branch
252
+ if self.cac:
253
+ chin_z *= 2
254
+ chout = channels_time or channels
255
+ chout_z = channels
256
+ freqs = nfft // 2
257
+
258
+ for index in range(depth):
259
+ norm = index >= norm_starts
260
+ freq = freqs > 1
261
+ stri = stride
262
+ ker = kernel_size
263
+ if not freq:
264
+ assert freqs == 1
265
+ ker = time_stride * 2
266
+ stri = time_stride
267
+
268
+ pad = True
269
+ last_freq = False
270
+ if freq and freqs <= kernel_size:
271
+ ker = freqs
272
+ pad = False
273
+ last_freq = True
274
+
275
+ kw = {
276
+ "kernel_size": ker,
277
+ "stride": stri,
278
+ "freq": freq,
279
+ "pad": pad,
280
+ "norm": norm,
281
+ "rewrite": rewrite,
282
+ "norm_groups": norm_groups,
283
+ "dconv_kw": {
284
+ "depth": dconv_depth,
285
+ "compress": dconv_comp,
286
+ "init": dconv_init,
287
+ "gelu": True,
288
+ },
289
+ }
290
+ kwt = dict(kw)
291
+ kwt["freq"] = 0
292
+ kwt["kernel_size"] = kernel_size
293
+ kwt["stride"] = stride
294
+ kwt["pad"] = True
295
+ kw_dec = dict(kw)
296
+ multi = False
297
+ if multi_freqs and index < multi_freqs_depth:
298
+ multi = True
299
+ kw_dec["context_freq"] = False
300
+
301
+ if last_freq:
302
+ chout_z = max(chout, chout_z)
303
+ chout = chout_z
304
+
305
+ enc = HEncLayer(
306
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
307
+ )
308
+ if freq:
309
+ tenc = HEncLayer(
310
+ chin,
311
+ chout,
312
+ dconv=dconv_mode & 1,
313
+ context=context_enc,
314
+ empty=last_freq,
315
+ **kwt
316
+ )
317
+ self.tencoder.append(tenc)
318
+
319
+ if multi:
320
+ enc = MultiWrap(enc, multi_freqs)
321
+ self.encoder.append(enc)
322
+ if index == 0:
323
+ chin = self.audio_channels * len(self.sources)
324
+ chin_z = chin
325
+ if self.cac:
326
+ chin_z *= 2
327
+ dec = HDecLayer(
328
+ chout_z,
329
+ chin_z,
330
+ dconv=dconv_mode & 2,
331
+ last=index == 0,
332
+ context=context,
333
+ **kw_dec
334
+ )
335
+ if multi:
336
+ dec = MultiWrap(dec, multi_freqs)
337
+ if freq:
338
+ tdec = HDecLayer(
339
+ chout,
340
+ chin,
341
+ dconv=dconv_mode & 2,
342
+ empty=last_freq,
343
+ last=index == 0,
344
+ context=context,
345
+ **kwt
346
+ )
347
+ self.tdecoder.insert(0, tdec)
348
+ self.decoder.insert(0, dec)
349
+
350
+ chin = chout
351
+ chin_z = chout_z
352
+ chout = int(growth * chout)
353
+ chout_z = int(growth * chout_z)
354
+ if freq:
355
+ if freqs <= kernel_size:
356
+ freqs = 1
357
+ else:
358
+ freqs //= stride
359
+ if index == 0 and freq_emb:
360
+ self.freq_emb = ScaledEmbedding(
361
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
362
+ )
363
+ self.freq_emb_scale = freq_emb
364
+
365
+ if rescale:
366
+ rescale_module(self, reference=rescale)
367
+
368
+ transformer_channels = channels * growth ** (depth - 1)
369
+ if bottom_channels:
370
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
371
+ self.channel_downsampler = nn.Conv1d(
372
+ bottom_channels, transformer_channels, 1
373
+ )
374
+ self.channel_upsampler_t = nn.Conv1d(
375
+ transformer_channels, bottom_channels, 1
376
+ )
377
+ self.channel_downsampler_t = nn.Conv1d(
378
+ bottom_channels, transformer_channels, 1
379
+ )
380
+
381
+ transformer_channels = bottom_channels
382
+
383
+ if t_layers > 0:
384
+ self.crosstransformer = CrossTransformerEncoder(
385
+ dim=transformer_channels,
386
+ emb=t_emb,
387
+ hidden_scale=t_hidden_scale,
388
+ num_heads=t_heads,
389
+ num_layers=t_layers,
390
+ cross_first=t_cross_first,
391
+ dropout=t_dropout,
392
+ max_positions=t_max_positions,
393
+ norm_in=t_norm_in,
394
+ norm_in_group=t_norm_in_group,
395
+ group_norm=t_group_norm,
396
+ norm_first=t_norm_first,
397
+ norm_out=t_norm_out,
398
+ max_period=t_max_period,
399
+ weight_decay=t_weight_decay,
400
+ lr=t_lr,
401
+ layer_scale=t_layer_scale,
402
+ gelu=t_gelu,
403
+ sin_random_shift=t_sin_random_shift,
404
+ weight_pos_embed=t_weight_pos_embed,
405
+ cape_mean_normalize=t_cape_mean_normalize,
406
+ cape_augment=t_cape_augment,
407
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
408
+ sparse_self_attn=t_sparse_self_attn,
409
+ sparse_cross_attn=t_sparse_cross_attn,
410
+ mask_type=t_mask_type,
411
+ mask_random_seed=t_mask_random_seed,
412
+ sparse_attn_window=t_sparse_attn_window,
413
+ global_window=t_global_window,
414
+ sparsity=t_sparsity,
415
+ auto_sparsity=t_auto_sparsity,
416
+ )
417
+ else:
418
+ self.crosstransformer = None
419
+
420
+ def _spec(self, x):
421
+ hl = self.hop_length
422
+ nfft = self.nfft
423
+ x0 = x # noqa
424
+
425
+ # We re-pad the signal in order to keep the property
426
+ # that the size of the output is exactly the size of the input
427
+ # divided by the stride (here hop_length), when divisible.
428
+ # This is achieved by padding by 1/4th of the kernel size (here nfft).
429
+ # which is not supported by torch.stft.
430
+ # Having all convolution operations follow this convention allow to easily
431
+ # align the time and frequency branches later on.
432
+ assert hl == nfft // 4
433
+ le = int(math.ceil(x.shape[-1] / hl))
434
+ pad = hl // 2 * 3
435
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
436
+
437
+ z = spectro(x, nfft, hl)[..., :-1, :]
438
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
439
+ z = z[..., 2: 2 + le]
440
+ return z
441
+
442
+ def _ispec(self, z, length=None, scale=0):
443
+ hl = self.hop_length // (4**scale)
444
+ z = F.pad(z, (0, 0, 0, 1))
445
+ z = F.pad(z, (2, 2))
446
+ pad = hl // 2 * 3
447
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
448
+ x = ispectro(z, hl, length=le)
449
+ x = x[..., pad: pad + length]
450
+ return x
451
+
452
+ def _magnitude(self, z):
453
+ # return the magnitude of the spectrogram, except when cac is True,
454
+ # in which case we just move the complex dimension to the channel one.
455
+ if self.cac:
456
+ B, C, Fr, T = z.shape
457
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
458
+ m = m.reshape(B, C * 2, Fr, T)
459
+ else:
460
+ m = z.abs()
461
+ return m
462
+
463
+ def _mask(self, z, m):
464
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
465
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
466
+ niters = self.wiener_iters
467
+ if self.cac:
468
+ B, S, C, Fr, T = m.shape
469
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
470
+ out = torch.view_as_complex(out.contiguous())
471
+ return out
472
+ if self.training:
473
+ niters = self.end_iters
474
+ if niters < 0:
475
+ z = z[:, None]
476
+ return z / (1e-8 + z.abs()) * m
477
+ else:
478
+ return self._wiener(m, z, niters)
479
+
480
+ def _wiener(self, mag_out, mix_stft, niters):
481
+ # apply wiener filtering from OpenUnmix.
482
+ init = mix_stft.dtype
483
+ wiener_win_len = 300
484
+ residual = self.wiener_residual
485
+
486
+ B, S, C, Fq, T = mag_out.shape
487
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
488
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
489
+
490
+ outs = []
491
+ for sample in range(B):
492
+ pos = 0
493
+ out = []
494
+ for pos in range(0, T, wiener_win_len):
495
+ frame = slice(pos, pos + wiener_win_len)
496
+ z_out = wiener(
497
+ mag_out[sample, frame],
498
+ mix_stft[sample, frame],
499
+ niters,
500
+ residual=residual,
501
+ )
502
+ out.append(z_out.transpose(-1, -2))
503
+ outs.append(torch.cat(out, dim=0))
504
+ out = torch.view_as_complex(torch.stack(outs, 0))
505
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
506
+ if residual:
507
+ out = out[:, :-1]
508
+ assert list(out.shape) == [B, S, C, Fq, T]
509
+ return out.to(init)
510
+
511
+ def valid_length(self, length: int):
512
+ """
513
+ Return a length that is appropriate for evaluation.
514
+ In our case, always return the training length, unless
515
+ it is smaller than the given length, in which case this
516
+ raises an error.
517
+ """
518
+ if not self.use_train_segment:
519
+ return length
520
+ training_length = int(self.segment * self.samplerate)
521
+ if training_length < length:
522
+ raise ValueError(
523
+ f"Given length {length} is longer than "
524
+ f"training length {training_length}")
525
+ return training_length
526
+
527
+ def forward(self, mix):
528
+ length = mix.shape[-1]
529
+ length_pre_pad = None
530
+ if self.use_train_segment:
531
+ if self.training:
532
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
533
+ else:
534
+ training_length = int(self.segment * self.samplerate)
535
+ if mix.shape[-1] < training_length:
536
+ length_pre_pad = mix.shape[-1]
537
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
538
+ z = self._spec(mix)
539
+ mag = self._magnitude(z)
540
+ x = mag
541
+
542
+ B, C, Fq, T = x.shape
543
+
544
+ # unlike previous Demucs, we always normalize because it is easier.
545
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
546
+ std = x.std(dim=(1, 2, 3), keepdim=True)
547
+ x = (x - mean) / (1e-5 + std)
548
+ # x will be the freq. branch input.
549
+
550
+ # Prepare the time branch input.
551
+ xt = mix
552
+ meant = xt.mean(dim=(1, 2), keepdim=True)
553
+ stdt = xt.std(dim=(1, 2), keepdim=True)
554
+ xt = (xt - meant) / (1e-5 + stdt)
555
+
556
+ # okay, this is a giant mess I know...
557
+ saved = [] # skip connections, freq.
558
+ saved_t = [] # skip connections, time.
559
+ lengths = [] # saved lengths to properly remove padding, freq branch.
560
+ lengths_t = [] # saved lengths for time branch.
561
+ for idx, encode in enumerate(self.encoder):
562
+ lengths.append(x.shape[-1])
563
+ inject = None
564
+ if idx < len(self.tencoder):
565
+ # we have not yet merged branches.
566
+ lengths_t.append(xt.shape[-1])
567
+ tenc = self.tencoder[idx]
568
+ xt = tenc(xt)
569
+ if not tenc.empty:
570
+ # save for skip connection
571
+ saved_t.append(xt)
572
+ else:
573
+ # tenc contains just the first conv., so that now time and freq.
574
+ # branches have the same shape and can be merged.
575
+ inject = xt
576
+ x = encode(x, inject)
577
+ if idx == 0 and self.freq_emb is not None:
578
+ # add frequency embedding to allow for non equivariant convolutions
579
+ # over the frequency axis.
580
+ frs = torch.arange(x.shape[-2], device=x.device)
581
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
582
+ x = x + self.freq_emb_scale * emb
583
+
584
+ saved.append(x)
585
+ if self.crosstransformer:
586
+ if self.bottom_channels:
587
+ b, c, f, t = x.shape
588
+ x = rearrange(x, "b c f t-> b c (f t)")
589
+ x = self.channel_upsampler(x)
590
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
591
+ xt = self.channel_upsampler_t(xt)
592
+
593
+ x, xt = self.crosstransformer(x, xt)
594
+
595
+ if self.bottom_channels:
596
+ x = rearrange(x, "b c f t-> b c (f t)")
597
+ x = self.channel_downsampler(x)
598
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
599
+ xt = self.channel_downsampler_t(xt)
600
+
601
+ for idx, decode in enumerate(self.decoder):
602
+ skip = saved.pop(-1)
603
+ x, pre = decode(x, skip, lengths.pop(-1))
604
+ # `pre` contains the output just before final transposed convolution,
605
+ # which is used when the freq. and time branch separate.
606
+
607
+ offset = self.depth - len(self.tdecoder)
608
+ if idx >= offset:
609
+ tdec = self.tdecoder[idx - offset]
610
+ length_t = lengths_t.pop(-1)
611
+ if tdec.empty:
612
+ assert pre.shape[2] == 1, pre.shape
613
+ pre = pre[:, :, 0]
614
+ xt, _ = tdec(pre, None, length_t)
615
+ else:
616
+ skip = saved_t.pop(-1)
617
+ xt, _ = tdec(xt, skip, length_t)
618
+
619
+ # Let's make sure we used all stored skip connections.
620
+ assert len(saved) == 0
621
+ assert len(lengths_t) == 0
622
+ assert len(saved_t) == 0
623
+
624
+ S = len(self.sources)
625
+ x = x.view(B, S, -1, Fq, T)
626
+ x = x * std[:, None] + mean[:, None]
627
+
628
+ zout = self._mask(z, x)
629
+ if self.use_train_segment:
630
+ if self.training:
631
+ x = self._ispec(zout, length)
632
+ else:
633
+ x = self._ispec(zout, training_length)
634
+ else:
635
+ x = self._ispec(zout, length)
636
+
637
+ if self.use_train_segment:
638
+ if self.training:
639
+ xt = xt.view(B, S, -1, length)
640
+ else:
641
+ xt = xt.view(B, S, -1, training_length)
642
+ else:
643
+ xt = xt.view(B, S, -1, length)
644
+ xt = xt * stdt[:, None] + meant[:, None]
645
+ x = xt + x
646
+ if length_pre_pad:
647
+ x = x[..., :length_pre_pad]
648
+ return x
demucs/model.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch as th
10
+ from torch import nn
11
+
12
+ from .utils import capture_init, center_trim
13
+
14
+
15
+ class BLSTM(nn.Module):
16
+ def __init__(self, dim, layers=1):
17
+ super().__init__()
18
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
19
+ self.linear = nn.Linear(2 * dim, dim)
20
+
21
+ def forward(self, x):
22
+ x = x.permute(2, 0, 1)
23
+ x = self.lstm(x)[0]
24
+ x = self.linear(x)
25
+ x = x.permute(1, 2, 0)
26
+ return x
27
+
28
+
29
+ def rescale_conv(conv, reference):
30
+ std = conv.weight.std().detach()
31
+ scale = (std / reference)**0.5
32
+ conv.weight.data /= scale
33
+ if conv.bias is not None:
34
+ conv.bias.data /= scale
35
+
36
+
37
+ def rescale_module(module, reference):
38
+ for sub in module.modules():
39
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
40
+ rescale_conv(sub, reference)
41
+
42
+
43
+ def upsample(x, stride):
44
+ """
45
+ Linear upsampling, the output will be `stride` times longer.
46
+ """
47
+ batch, channels, time = x.size()
48
+ weight = th.arange(stride, device=x.device, dtype=th.float) / stride
49
+ x = x.view(batch, channels, time, 1)
50
+ out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight
51
+ return out.reshape(batch, channels, -1)
52
+
53
+
54
+ def downsample(x, stride):
55
+ """
56
+ Downsample x by decimation.
57
+ """
58
+ return x[:, :, ::stride]
59
+
60
+
61
+ class Demucs(nn.Module):
62
+ @capture_init
63
+ def __init__(self,
64
+ sources=4,
65
+ audio_channels=2,
66
+ channels=64,
67
+ depth=6,
68
+ rewrite=True,
69
+ glu=True,
70
+ upsample=False,
71
+ rescale=0.1,
72
+ kernel_size=8,
73
+ stride=4,
74
+ growth=2.,
75
+ lstm_layers=2,
76
+ context=3,
77
+ samplerate=44100):
78
+ """
79
+ Args:
80
+ sources (int): number of sources to separate
81
+ audio_channels (int): stereo or mono
82
+ channels (int): first convolution channels
83
+ depth (int): number of encoder/decoder layers
84
+ rewrite (bool): add 1x1 convolution to each encoder layer
85
+ and a convolution to each decoder layer.
86
+ For the decoder layer, `context` gives the kernel size.
87
+ glu (bool): use glu instead of ReLU
88
+ upsample (bool): use linear upsampling with convolutions
89
+ Wave-U-Net style, instead of transposed convolutions
90
+ rescale (int): rescale initial weights of convolutions
91
+ to get their standard deviation closer to `rescale`
92
+ kernel_size (int): kernel size for convolutions
93
+ stride (int): stride for convolutions
94
+ growth (float): multiply (resp divide) number of channels by that
95
+ for each layer of the encoder (resp decoder)
96
+ lstm_layers (int): number of lstm layers, 0 = no lstm
97
+ context (int): kernel size of the convolution in the
98
+ decoder before the transposed convolution. If > 1,
99
+ will provide some context from neighboring time
100
+ steps.
101
+ """
102
+
103
+ super().__init__()
104
+ self.audio_channels = audio_channels
105
+ self.sources = sources
106
+ self.kernel_size = kernel_size
107
+ self.context = context
108
+ self.stride = stride
109
+ self.depth = depth
110
+ self.upsample = upsample
111
+ self.channels = channels
112
+ self.samplerate = samplerate
113
+
114
+ self.encoder = nn.ModuleList()
115
+ self.decoder = nn.ModuleList()
116
+
117
+ self.final = None
118
+ if upsample:
119
+ self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1)
120
+ stride = 1
121
+
122
+ if glu:
123
+ activation = nn.GLU(dim=1)
124
+ ch_scale = 2
125
+ else:
126
+ activation = nn.ReLU()
127
+ ch_scale = 1
128
+ in_channels = audio_channels
129
+ for index in range(depth):
130
+ encode = []
131
+ encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
132
+ if rewrite:
133
+ encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
134
+ self.encoder.append(nn.Sequential(*encode))
135
+
136
+ decode = []
137
+ if index > 0:
138
+ out_channels = in_channels
139
+ else:
140
+ if upsample:
141
+ out_channels = channels
142
+ else:
143
+ out_channels = sources * audio_channels
144
+ if rewrite:
145
+ decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
146
+ if upsample:
147
+ decode += [
148
+ nn.Conv1d(channels, out_channels, kernel_size, stride=1),
149
+ ]
150
+ else:
151
+ decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
152
+ if index > 0:
153
+ decode.append(nn.ReLU())
154
+ self.decoder.insert(0, nn.Sequential(*decode))
155
+ in_channels = channels
156
+ channels = int(growth * channels)
157
+
158
+ channels = in_channels
159
+
160
+ if lstm_layers:
161
+ self.lstm = BLSTM(channels, lstm_layers)
162
+ else:
163
+ self.lstm = None
164
+
165
+ if rescale:
166
+ rescale_module(self, reference=rescale)
167
+
168
+ def valid_length(self, length):
169
+ """
170
+ Return the nearest valid length to use with the model so that
171
+ there is no time steps left over in a convolutions, e.g. for all
172
+ layers, size of the input - kernel_size % stride = 0.
173
+
174
+ If the mixture has a valid length, the estimated sources
175
+ will have exactly the same length when context = 1. If context > 1,
176
+ the two signals can be center trimmed to match.
177
+
178
+ For training, extracts should have a valid length.For evaluation
179
+ on full tracks we recommend passing `pad = True` to :method:`forward`.
180
+ """
181
+ for _ in range(self.depth):
182
+ if self.upsample:
183
+ length = math.ceil(length / self.stride) + self.kernel_size - 1
184
+ else:
185
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
186
+ length = max(1, length)
187
+ length += self.context - 1
188
+ for _ in range(self.depth):
189
+ if self.upsample:
190
+ length = length * self.stride + self.kernel_size - 1
191
+ else:
192
+ length = (length - 1) * self.stride + self.kernel_size
193
+
194
+ return int(length)
195
+
196
+ def forward(self, mix):
197
+ x = mix
198
+ saved = [x]
199
+ for encode in self.encoder:
200
+ x = encode(x)
201
+ saved.append(x)
202
+ if self.upsample:
203
+ x = downsample(x, self.stride)
204
+ if self.lstm:
205
+ x = self.lstm(x)
206
+ for decode in self.decoder:
207
+ if self.upsample:
208
+ x = upsample(x, stride=self.stride)
209
+ skip = center_trim(saved.pop(-1), x)
210
+ x = x + skip
211
+ x = decode(x)
212
+ if self.final:
213
+ skip = center_trim(saved.pop(-1), x)
214
+ x = th.cat([x, skip], dim=1)
215
+ x = self.final(x)
216
+
217
+ x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1))
218
+ return x
demucs/model_v2.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import julius
10
+ from torch import nn
11
+ from .tasnet_v2 import ConvTasNet
12
+
13
+ from .utils import capture_init, center_trim
14
+
15
+
16
+ class BLSTM(nn.Module):
17
+ def __init__(self, dim, layers=1):
18
+ super().__init__()
19
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
20
+ self.linear = nn.Linear(2 * dim, dim)
21
+
22
+ def forward(self, x):
23
+ x = x.permute(2, 0, 1)
24
+ x = self.lstm(x)[0]
25
+ x = self.linear(x)
26
+ x = x.permute(1, 2, 0)
27
+ return x
28
+
29
+
30
+ def rescale_conv(conv, reference):
31
+ std = conv.weight.std().detach()
32
+ scale = (std / reference)**0.5
33
+ conv.weight.data /= scale
34
+ if conv.bias is not None:
35
+ conv.bias.data /= scale
36
+
37
+
38
+ def rescale_module(module, reference):
39
+ for sub in module.modules():
40
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
41
+ rescale_conv(sub, reference)
42
+
43
+ def auto_load_demucs_model_v2(sources, demucs_model_name):
44
+
45
+ if '48' in demucs_model_name:
46
+ channels=48
47
+ elif 'unittest' in demucs_model_name:
48
+ channels=4
49
+ else:
50
+ channels=64
51
+
52
+ if 'tasnet' in demucs_model_name:
53
+ init_demucs_model = ConvTasNet(sources, X=10)
54
+ else:
55
+ init_demucs_model = Demucs(sources, channels=channels)
56
+
57
+ return init_demucs_model
58
+
59
+ class Demucs(nn.Module):
60
+ @capture_init
61
+ def __init__(self,
62
+ sources,
63
+ audio_channels=2,
64
+ channels=64,
65
+ depth=6,
66
+ rewrite=True,
67
+ glu=True,
68
+ rescale=0.1,
69
+ resample=True,
70
+ kernel_size=8,
71
+ stride=4,
72
+ growth=2.,
73
+ lstm_layers=2,
74
+ context=3,
75
+ normalize=False,
76
+ samplerate=44100,
77
+ segment_length=4 * 10 * 44100):
78
+ """
79
+ Args:
80
+ sources (list[str]): list of source names
81
+ audio_channels (int): stereo or mono
82
+ channels (int): first convolution channels
83
+ depth (int): number of encoder/decoder layers
84
+ rewrite (bool): add 1x1 convolution to each encoder layer
85
+ and a convolution to each decoder layer.
86
+ For the decoder layer, `context` gives the kernel size.
87
+ glu (bool): use glu instead of ReLU
88
+ resample_input (bool): upsample x2 the input and downsample /2 the output.
89
+ rescale (int): rescale initial weights of convolutions
90
+ to get their standard deviation closer to `rescale`
91
+ kernel_size (int): kernel size for convolutions
92
+ stride (int): stride for convolutions
93
+ growth (float): multiply (resp divide) number of channels by that
94
+ for each layer of the encoder (resp decoder)
95
+ lstm_layers (int): number of lstm layers, 0 = no lstm
96
+ context (int): kernel size of the convolution in the
97
+ decoder before the transposed convolution. If > 1,
98
+ will provide some context from neighboring time
99
+ steps.
100
+ samplerate (int): stored as meta information for easing
101
+ future evaluations of the model.
102
+ segment_length (int): stored as meta information for easing
103
+ future evaluations of the model. Length of the segments on which
104
+ the model was trained.
105
+ """
106
+
107
+ super().__init__()
108
+ self.audio_channels = audio_channels
109
+ self.sources = sources
110
+ self.kernel_size = kernel_size
111
+ self.context = context
112
+ self.stride = stride
113
+ self.depth = depth
114
+ self.resample = resample
115
+ self.channels = channels
116
+ self.normalize = normalize
117
+ self.samplerate = samplerate
118
+ self.segment_length = segment_length
119
+
120
+ self.encoder = nn.ModuleList()
121
+ self.decoder = nn.ModuleList()
122
+
123
+ if glu:
124
+ activation = nn.GLU(dim=1)
125
+ ch_scale = 2
126
+ else:
127
+ activation = nn.ReLU()
128
+ ch_scale = 1
129
+ in_channels = audio_channels
130
+ for index in range(depth):
131
+ encode = []
132
+ encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
133
+ if rewrite:
134
+ encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
135
+ self.encoder.append(nn.Sequential(*encode))
136
+
137
+ decode = []
138
+ if index > 0:
139
+ out_channels = in_channels
140
+ else:
141
+ out_channels = len(self.sources) * audio_channels
142
+ if rewrite:
143
+ decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
144
+ decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
145
+ if index > 0:
146
+ decode.append(nn.ReLU())
147
+ self.decoder.insert(0, nn.Sequential(*decode))
148
+ in_channels = channels
149
+ channels = int(growth * channels)
150
+
151
+ channels = in_channels
152
+
153
+ if lstm_layers:
154
+ self.lstm = BLSTM(channels, lstm_layers)
155
+ else:
156
+ self.lstm = None
157
+
158
+ if rescale:
159
+ rescale_module(self, reference=rescale)
160
+
161
+ def valid_length(self, length):
162
+ """
163
+ Return the nearest valid length to use with the model so that
164
+ there is no time steps left over in a convolutions, e.g. for all
165
+ layers, size of the input - kernel_size % stride = 0.
166
+
167
+ If the mixture has a valid length, the estimated sources
168
+ will have exactly the same length when context = 1. If context > 1,
169
+ the two signals can be center trimmed to match.
170
+
171
+ For training, extracts should have a valid length.For evaluation
172
+ on full tracks we recommend passing `pad = True` to :method:`forward`.
173
+ """
174
+ if self.resample:
175
+ length *= 2
176
+ for _ in range(self.depth):
177
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
178
+ length = max(1, length)
179
+ length += self.context - 1
180
+ for _ in range(self.depth):
181
+ length = (length - 1) * self.stride + self.kernel_size
182
+
183
+ if self.resample:
184
+ length = math.ceil(length / 2)
185
+ return int(length)
186
+
187
+ def forward(self, mix):
188
+ x = mix
189
+
190
+ if self.normalize:
191
+ mono = mix.mean(dim=1, keepdim=True)
192
+ mean = mono.mean(dim=-1, keepdim=True)
193
+ std = mono.std(dim=-1, keepdim=True)
194
+ else:
195
+ mean = 0
196
+ std = 1
197
+
198
+ x = (x - mean) / (1e-5 + std)
199
+
200
+ if self.resample:
201
+ x = julius.resample_frac(x, 1, 2)
202
+
203
+ saved = []
204
+ for encode in self.encoder:
205
+ x = encode(x)
206
+ saved.append(x)
207
+ if self.lstm:
208
+ x = self.lstm(x)
209
+ for decode in self.decoder:
210
+ skip = center_trim(saved.pop(-1), x)
211
+ x = x + skip
212
+ x = decode(x)
213
+
214
+ if self.resample:
215
+ x = julius.resample_frac(x, 2, 1)
216
+ x = x * std + mean
217
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
218
+ return x
demucs/pretrained.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Loading pretrained models.
7
+ """
8
+
9
+ import logging
10
+ from pathlib import Path
11
+ import typing as tp
12
+
13
+ #from dora.log import fatal
14
+
15
+ import logging
16
+
17
+ from diffq import DiffQuantizer
18
+ import torch.hub
19
+
20
+ from .model import Demucs
21
+ from .tasnet_v2 import ConvTasNet
22
+ from .utils import set_state
23
+
24
+ from .hdemucs import HDemucs
25
+ from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
26
+
27
+ logger = logging.getLogger(__name__)
28
+ ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
29
+ REMOTE_ROOT = Path(__file__).parent / 'remote'
30
+
31
+ SOURCES = ["drums", "bass", "other", "vocals"]
32
+
33
+
34
+ def demucs_unittest():
35
+ model = HDemucs(channels=4, sources=SOURCES)
36
+ return model
37
+
38
+
39
+ def add_model_flags(parser):
40
+ group = parser.add_mutually_exclusive_group(required=False)
41
+ group.add_argument("-s", "--sig", help="Locally trained XP signature.")
42
+ group.add_argument("-n", "--name", default="mdx_extra_q",
43
+ help="Pretrained model name or signature. Default is mdx_extra_q.")
44
+ parser.add_argument("--repo", type=Path,
45
+ help="Folder containing all pre-trained models for use with -n.")
46
+
47
+
48
+ def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
49
+ root: str = ''
50
+ models: tp.Dict[str, str] = {}
51
+ for line in remote_file_list.read_text().split('\n'):
52
+ line = line.strip()
53
+ if line.startswith('#'):
54
+ continue
55
+ elif line.startswith('root:'):
56
+ root = line.split(':', 1)[1].strip()
57
+ else:
58
+ sig = line.split('-', 1)[0]
59
+ assert sig not in models
60
+ models[sig] = ROOT_URL + root + line
61
+ return models
62
+
63
+ def get_model(name: str,
64
+ repo: tp.Optional[Path] = None):
65
+ """`name` must be a bag of models name or a pretrained signature
66
+ from the remote AWS model repo or the specified local repo if `repo` is not None.
67
+ """
68
+ if name == 'demucs_unittest':
69
+ return demucs_unittest()
70
+ model_repo: ModelOnlyRepo
71
+ if repo is None:
72
+ models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
73
+ model_repo = RemoteRepo(models)
74
+ bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
75
+ else:
76
+ if not repo.is_dir():
77
+ fatal(f"{repo} must exist and be a directory.")
78
+ model_repo = LocalRepo(repo)
79
+ bag_repo = BagOnlyRepo(repo, model_repo)
80
+ any_repo = AnyModelRepo(model_repo, bag_repo)
81
+ model = any_repo.get_model(name)
82
+ model.eval()
83
+ return model
84
+
85
+ def get_model_from_args(args):
86
+ """
87
+ Load local model package or pre-trained model.
88
+ """
89
+ return get_model(name=args.name, repo=args.repo)
90
+
91
+ logger = logging.getLogger(__name__)
92
+ ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
93
+
94
+ PRETRAINED_MODELS = {
95
+ 'demucs': 'e07c671f',
96
+ 'demucs48_hq': '28a1282c',
97
+ 'demucs_extra': '3646af93',
98
+ 'demucs_quantized': '07afea75',
99
+ 'tasnet': 'beb46fac',
100
+ 'tasnet_extra': 'df3777b2',
101
+ 'demucs_unittest': '09ebc15f',
102
+ }
103
+
104
+ SOURCES = ["drums", "bass", "other", "vocals"]
105
+
106
+
107
+ def get_url(name):
108
+ sig = PRETRAINED_MODELS[name]
109
+ return ROOT + name + "-" + sig[:8] + ".th"
110
+
111
+ def is_pretrained(name):
112
+ return name in PRETRAINED_MODELS
113
+
114
+
115
+ def load_pretrained(name):
116
+ if name == "demucs":
117
+ return demucs(pretrained=True)
118
+ elif name == "demucs48_hq":
119
+ return demucs(pretrained=True, hq=True, channels=48)
120
+ elif name == "demucs_extra":
121
+ return demucs(pretrained=True, extra=True)
122
+ elif name == "demucs_quantized":
123
+ return demucs(pretrained=True, quantized=True)
124
+ elif name == "demucs_unittest":
125
+ return demucs_unittest(pretrained=True)
126
+ elif name == "tasnet":
127
+ return tasnet(pretrained=True)
128
+ elif name == "tasnet_extra":
129
+ return tasnet(pretrained=True, extra=True)
130
+ else:
131
+ raise ValueError(f"Invalid pretrained name {name}")
132
+
133
+
134
+ def _load_state(name, model, quantizer=None):
135
+ url = get_url(name)
136
+ state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
137
+ set_state(model, quantizer, state)
138
+ if quantizer:
139
+ quantizer.detach()
140
+
141
+
142
+ def demucs_unittest(pretrained=True):
143
+ model = Demucs(channels=4, sources=SOURCES)
144
+ if pretrained:
145
+ _load_state('demucs_unittest', model)
146
+ return model
147
+
148
+
149
+ def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
150
+ if not pretrained and (extra or quantized or hq):
151
+ raise ValueError("if extra or quantized is True, pretrained must be True.")
152
+ model = Demucs(sources=SOURCES, channels=channels)
153
+ if pretrained:
154
+ name = 'demucs'
155
+ if channels != 64:
156
+ name += str(channels)
157
+ quantizer = None
158
+ if sum([extra, quantized, hq]) > 1:
159
+ raise ValueError("Only one of extra, quantized, hq, can be True.")
160
+ if quantized:
161
+ quantizer = DiffQuantizer(model, group_size=8, min_size=1)
162
+ name += '_quantized'
163
+ if extra:
164
+ name += '_extra'
165
+ if hq:
166
+ name += '_hq'
167
+ _load_state(name, model, quantizer)
168
+ return model
169
+
170
+
171
+ def tasnet(pretrained=True, extra=False):
172
+ if not pretrained and extra:
173
+ raise ValueError("if extra is True, pretrained must be True.")
174
+ model = ConvTasNet(X=10, sources=SOURCES)
175
+ if pretrained:
176
+ name = 'tasnet'
177
+ if extra:
178
+ name = 'tasnet_extra'
179
+ _load_state(name, model)
180
+ return model
demucs/repo.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Represents a model repository, including pre-trained models and bags of models.
7
+ A repo can either be the main remote repository stored in AWS, or a local repository
8
+ with your own models.
9
+ """
10
+
11
+ from hashlib import sha256
12
+ from pathlib import Path
13
+ import typing as tp
14
+
15
+ import torch
16
+ import yaml
17
+
18
+ from .apply import BagOfModels, Model
19
+ from .states import load_model
20
+
21
+
22
+ AnyModel = tp.Union[Model, BagOfModels]
23
+
24
+
25
+ class ModelLoadingError(RuntimeError):
26
+ pass
27
+
28
+
29
+ def check_checksum(path: Path, checksum: str):
30
+ sha = sha256()
31
+ with open(path, 'rb') as file:
32
+ while True:
33
+ buf = file.read(2**20)
34
+ if not buf:
35
+ break
36
+ sha.update(buf)
37
+ actual_checksum = sha.hexdigest()[:len(checksum)]
38
+ if actual_checksum != checksum:
39
+ raise ModelLoadingError(f'Invalid checksum for file {path}, '
40
+ f'expected {checksum} but got {actual_checksum}')
41
+
42
+ class ModelOnlyRepo:
43
+ """Base class for all model only repos.
44
+ """
45
+ def has_model(self, sig: str) -> bool:
46
+ raise NotImplementedError()
47
+
48
+ def get_model(self, sig: str) -> Model:
49
+ raise NotImplementedError()
50
+
51
+
52
+ class RemoteRepo(ModelOnlyRepo):
53
+ def __init__(self, models: tp.Dict[str, str]):
54
+ self._models = models
55
+
56
+ def has_model(self, sig: str) -> bool:
57
+ return sig in self._models
58
+
59
+ def get_model(self, sig: str) -> Model:
60
+ try:
61
+ url = self._models[sig]
62
+ except KeyError:
63
+ raise ModelLoadingError(f'Could not find a pre-trained model with signature {sig}.')
64
+ pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
65
+ return load_model(pkg)
66
+
67
+
68
+ class LocalRepo(ModelOnlyRepo):
69
+ def __init__(self, root: Path):
70
+ self.root = root
71
+ self.scan()
72
+
73
+ def scan(self):
74
+ self._models = {}
75
+ self._checksums = {}
76
+ for file in self.root.iterdir():
77
+ if file.suffix == '.th':
78
+ if '-' in file.stem:
79
+ xp_sig, checksum = file.stem.split('-')
80
+ self._checksums[xp_sig] = checksum
81
+ else:
82
+ xp_sig = file.stem
83
+ if xp_sig in self._models:
84
+ print('Whats xp? ', xp_sig)
85
+ raise ModelLoadingError(
86
+ f'Duplicate pre-trained model exist for signature {xp_sig}. '
87
+ 'Please delete all but one.')
88
+ self._models[xp_sig] = file
89
+
90
+ def has_model(self, sig: str) -> bool:
91
+ return sig in self._models
92
+
93
+ def get_model(self, sig: str) -> Model:
94
+ try:
95
+ file = self._models[sig]
96
+ except KeyError:
97
+ raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.')
98
+ if sig in self._checksums:
99
+ check_checksum(file, self._checksums[sig])
100
+ return load_model(file)
101
+
102
+
103
+ class BagOnlyRepo:
104
+ """Handles only YAML files containing bag of models, leaving the actual
105
+ model loading to some Repo.
106
+ """
107
+ def __init__(self, root: Path, model_repo: ModelOnlyRepo):
108
+ self.root = root
109
+ self.model_repo = model_repo
110
+ self.scan()
111
+
112
+ def scan(self):
113
+ self._bags = {}
114
+ for file in self.root.iterdir():
115
+ if file.suffix == '.yaml':
116
+ self._bags[file.stem] = file
117
+
118
+ def has_model(self, name: str) -> bool:
119
+ return name in self._bags
120
+
121
+ def get_model(self, name: str) -> BagOfModels:
122
+ try:
123
+ yaml_file = self._bags[name]
124
+ except KeyError:
125
+ raise ModelLoadingError(f'{name} is neither a single pre-trained model or '
126
+ 'a bag of models.')
127
+ bag = yaml.safe_load(open(yaml_file))
128
+ signatures = bag['models']
129
+ models = [self.model_repo.get_model(sig) for sig in signatures]
130
+ weights = bag.get('weights')
131
+ segment = bag.get('segment')
132
+ return BagOfModels(models, weights, segment)
133
+
134
+
135
+ class AnyModelRepo:
136
+ def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
137
+ self.model_repo = model_repo
138
+ self.bag_repo = bag_repo
139
+
140
+ def has_model(self, name_or_sig: str) -> bool:
141
+ return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)
142
+
143
+ def get_model(self, name_or_sig: str) -> AnyModel:
144
+ print('name_or_sig: ', name_or_sig)
145
+ if self.model_repo.has_model(name_or_sig):
146
+ return self.model_repo.get_model(name_or_sig)
147
+ else:
148
+ return self.bag_repo.get_model(name_or_sig)
demucs/spec.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Conveniance wrapper to perform STFT and iSTFT"""
7
+
8
+ import torch as th
9
+
10
+
11
+ def spectro(x, n_fft=512, hop_length=None, pad=0):
12
+ *other, length = x.shape
13
+ x = x.reshape(-1, length)
14
+ z = th.stft(x,
15
+ n_fft * (1 + pad),
16
+ hop_length or n_fft // 4,
17
+ window=th.hann_window(n_fft).to(x),
18
+ win_length=n_fft,
19
+ normalized=True,
20
+ center=True,
21
+ return_complex=True,
22
+ pad_mode='reflect')
23
+ _, freqs, frame = z.shape
24
+ return z.view(*other, freqs, frame)
25
+
26
+
27
+ def ispectro(z, hop_length=None, length=None, pad=0):
28
+ *other, freqs, frames = z.shape
29
+ n_fft = 2 * freqs - 2
30
+ z = z.view(-1, freqs, frames)
31
+ win_length = n_fft // (1 + pad)
32
+ x = th.istft(z,
33
+ n_fft,
34
+ hop_length,
35
+ window=th.hann_window(win_length).to(z.real),
36
+ win_length=win_length,
37
+ normalized=True,
38
+ length=length,
39
+ center=True)
40
+ _, length = x.shape
41
+ return x.view(*other, length)
demucs/states.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Utilities to save and load models.
8
+ """
9
+ from contextlib import contextmanager
10
+
11
+ import functools
12
+ import hashlib
13
+ import inspect
14
+ import io
15
+ from pathlib import Path
16
+ import warnings
17
+
18
+ from omegaconf import OmegaConf
19
+ from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state
20
+ import torch
21
+
22
+
23
+ def get_quantizer(model, args, optimizer=None):
24
+ """Return the quantizer given the XP quantization args."""
25
+ quantizer = None
26
+ if args.diffq:
27
+ quantizer = DiffQuantizer(
28
+ model, min_size=args.min_size, group_size=args.group_size)
29
+ if optimizer is not None:
30
+ quantizer.setup_optimizer(optimizer)
31
+ elif args.qat:
32
+ quantizer = UniformQuantizer(
33
+ model, bits=args.qat, min_size=args.min_size)
34
+ return quantizer
35
+
36
+
37
+ def load_model(path_or_package, strict=False):
38
+ """Load a model from the given serialized model, either given as a dict (already loaded)
39
+ or a path to a file on disk."""
40
+ if isinstance(path_or_package, dict):
41
+ package = path_or_package
42
+ elif isinstance(path_or_package, (str, Path)):
43
+ with warnings.catch_warnings():
44
+ warnings.simplefilter("ignore")
45
+ path = path_or_package
46
+ package = torch.load(path, 'cpu')
47
+ else:
48
+ raise ValueError(f"Invalid type for {path_or_package}.")
49
+
50
+ klass = package["klass"]
51
+ args = package["args"]
52
+ kwargs = package["kwargs"]
53
+
54
+ if strict:
55
+ model = klass(*args, **kwargs)
56
+ else:
57
+ sig = inspect.signature(klass)
58
+ for key in list(kwargs):
59
+ if key not in sig.parameters:
60
+ warnings.warn("Dropping inexistant parameter " + key)
61
+ del kwargs[key]
62
+ model = klass(*args, **kwargs)
63
+
64
+ state = package["state"]
65
+
66
+ set_state(model, state)
67
+ return model
68
+
69
+
70
+ def get_state(model, quantizer, half=False):
71
+ """Get the state from a model, potentially with quantization applied.
72
+ If `half` is True, model are stored as half precision, which shouldn't impact performance
73
+ but half the state size."""
74
+ if quantizer is None:
75
+ dtype = torch.half if half else None
76
+ state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
77
+ else:
78
+ state = quantizer.get_quantized_state()
79
+ state['__quantized'] = True
80
+ return state
81
+
82
+
83
+ def set_state(model, state, quantizer=None):
84
+ """Set the state on a given model."""
85
+ if state.get('__quantized'):
86
+ if quantizer is not None:
87
+ quantizer.restore_quantized_state(model, state['quantized'])
88
+ else:
89
+ restore_quantized_state(model, state)
90
+ else:
91
+ model.load_state_dict(state)
92
+ return state
93
+
94
+
95
+ def save_with_checksum(content, path):
96
+ """Save the given value on disk, along with a sha256 hash.
97
+ Should be used with the output of either `serialize_model` or `get_state`."""
98
+ buf = io.BytesIO()
99
+ torch.save(content, buf)
100
+ sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
101
+
102
+ path = path.parent / (path.stem + "-" + sig + path.suffix)
103
+ path.write_bytes(buf.getvalue())
104
+
105
+
106
+ def serialize_model(model, training_args, quantizer=None, half=True):
107
+ args, kwargs = model._init_args_kwargs
108
+ klass = model.__class__
109
+
110
+ state = get_state(model, quantizer, half)
111
+ return {
112
+ 'klass': klass,
113
+ 'args': args,
114
+ 'kwargs': kwargs,
115
+ 'state': state,
116
+ 'training_args': OmegaConf.to_container(training_args, resolve=True),
117
+ }
118
+
119
+
120
+ def copy_state(state):
121
+ return {k: v.cpu().clone() for k, v in state.items()}
122
+
123
+
124
+ @contextmanager
125
+ def swap_state(model, state):
126
+ """
127
+ Context manager that swaps the state of a model, e.g:
128
+
129
+ # model is in old state
130
+ with swap_state(model, new_state):
131
+ # model in new state
132
+ # model back to old state
133
+ """
134
+ old_state = copy_state(model.state_dict())
135
+ model.load_state_dict(state, strict=False)
136
+ try:
137
+ yield
138
+ finally:
139
+ model.load_state_dict(old_state)
140
+
141
+
142
+ def capture_init(init):
143
+ @functools.wraps(init)
144
+ def __init__(self, *args, **kwargs):
145
+ self._init_args_kwargs = (args, kwargs)
146
+ init(self, *args, **kwargs)
147
+
148
+ return __init__
demucs/tasnet.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Created on 2018/12
8
+ # Author: Kaituo XU
9
+ # Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
10
+ # Here is the original license:
11
+ # The MIT License (MIT)
12
+ #
13
+ # Copyright (c) 2018 Kaituo XU
14
+ #
15
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
16
+ # of this software and associated documentation files (the "Software"), to deal
17
+ # in the Software without restriction, including without limitation the rights
18
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
19
+ # copies of the Software, and to permit persons to whom the Software is
20
+ # furnished to do so, subject to the following conditions:
21
+ #
22
+ # The above copyright notice and this permission notice shall be included in all
23
+ # copies or substantial portions of the Software.
24
+ #
25
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31
+ # SOFTWARE.
32
+
33
+ import math
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+ from .utils import capture_init
40
+
41
+ EPS = 1e-8
42
+
43
+
44
+ def overlap_and_add(signal, frame_step):
45
+ outer_dimensions = signal.size()[:-2]
46
+ frames, frame_length = signal.size()[-2:]
47
+
48
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
49
+ subframe_step = frame_step // subframe_length
50
+ subframes_per_frame = frame_length // subframe_length
51
+ output_size = frame_step * (frames - 1) + frame_length
52
+ output_subframes = output_size // subframe_length
53
+
54
+ subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
55
+
56
+ frame = torch.arange(0, output_subframes,
57
+ device=signal.device).unfold(0, subframes_per_frame, subframe_step)
58
+ frame = frame.long() # signal may in GPU or CPU
59
+ frame = frame.contiguous().view(-1)
60
+
61
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
62
+ result.index_add_(-2, frame, subframe_signal)
63
+ result = result.view(*outer_dimensions, -1)
64
+ return result
65
+
66
+
67
+ class ConvTasNet(nn.Module):
68
+ @capture_init
69
+ def __init__(self,
70
+ N=256,
71
+ L=20,
72
+ B=256,
73
+ H=512,
74
+ P=3,
75
+ X=8,
76
+ R=4,
77
+ C=4,
78
+ audio_channels=1,
79
+ samplerate=44100,
80
+ norm_type="gLN",
81
+ causal=False,
82
+ mask_nonlinear='relu'):
83
+ """
84
+ Args:
85
+ N: Number of filters in autoencoder
86
+ L: Length of the filters (in samples)
87
+ B: Number of channels in bottleneck 1 × 1-conv block
88
+ H: Number of channels in convolutional blocks
89
+ P: Kernel size in convolutional blocks
90
+ X: Number of convolutional blocks in each repeat
91
+ R: Number of repeats
92
+ C: Number of speakers
93
+ norm_type: BN, gLN, cLN
94
+ causal: causal or non-causal
95
+ mask_nonlinear: use which non-linear function to generate mask
96
+ """
97
+ super(ConvTasNet, self).__init__()
98
+ # Hyper-parameter
99
+ self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C
100
+ self.norm_type = norm_type
101
+ self.causal = causal
102
+ self.mask_nonlinear = mask_nonlinear
103
+ self.audio_channels = audio_channels
104
+ self.samplerate = samplerate
105
+ # Components
106
+ self.encoder = Encoder(L, N, audio_channels)
107
+ self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
108
+ self.decoder = Decoder(N, L, audio_channels)
109
+ # init
110
+ for p in self.parameters():
111
+ if p.dim() > 1:
112
+ nn.init.xavier_normal_(p)
113
+
114
+ def valid_length(self, length):
115
+ return length
116
+
117
+ def forward(self, mixture):
118
+ """
119
+ Args:
120
+ mixture: [M, T], M is batch size, T is #samples
121
+ Returns:
122
+ est_source: [M, C, T]
123
+ """
124
+ mixture_w = self.encoder(mixture)
125
+ est_mask = self.separator(mixture_w)
126
+ est_source = self.decoder(mixture_w, est_mask)
127
+
128
+ # T changed after conv1d in encoder, fix it here
129
+ T_origin = mixture.size(-1)
130
+ T_conv = est_source.size(-1)
131
+ est_source = F.pad(est_source, (0, T_origin - T_conv))
132
+ return est_source
133
+
134
+
135
+ class Encoder(nn.Module):
136
+ """Estimation of the nonnegative mixture weight by a 1-D conv layer.
137
+ """
138
+ def __init__(self, L, N, audio_channels):
139
+ super(Encoder, self).__init__()
140
+ # Hyper-parameter
141
+ self.L, self.N = L, N
142
+ # Components
143
+ # 50% overlap
144
+ self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
145
+
146
+ def forward(self, mixture):
147
+ """
148
+ Args:
149
+ mixture: [M, T], M is batch size, T is #samples
150
+ Returns:
151
+ mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
152
+ """
153
+ mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
154
+ return mixture_w
155
+
156
+
157
+ class Decoder(nn.Module):
158
+ def __init__(self, N, L, audio_channels):
159
+ super(Decoder, self).__init__()
160
+ # Hyper-parameter
161
+ self.N, self.L = N, L
162
+ self.audio_channels = audio_channels
163
+ # Components
164
+ self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
165
+
166
+ def forward(self, mixture_w, est_mask):
167
+ """
168
+ Args:
169
+ mixture_w: [M, N, K]
170
+ est_mask: [M, C, N, K]
171
+ Returns:
172
+ est_source: [M, C, T]
173
+ """
174
+ # D = W * M
175
+ source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
176
+ source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
177
+ # S = DV
178
+ est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
179
+ m, c, k, _ = est_source.size()
180
+ est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
181
+ est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
182
+ return est_source
183
+
184
+
185
+ class TemporalConvNet(nn.Module):
186
+ def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
187
+ """
188
+ Args:
189
+ N: Number of filters in autoencoder
190
+ B: Number of channels in bottleneck 1 × 1-conv block
191
+ H: Number of channels in convolutional blocks
192
+ P: Kernel size in convolutional blocks
193
+ X: Number of convolutional blocks in each repeat
194
+ R: Number of repeats
195
+ C: Number of speakers
196
+ norm_type: BN, gLN, cLN
197
+ causal: causal or non-causal
198
+ mask_nonlinear: use which non-linear function to generate mask
199
+ """
200
+ super(TemporalConvNet, self).__init__()
201
+ # Hyper-parameter
202
+ self.C = C
203
+ self.mask_nonlinear = mask_nonlinear
204
+ # Components
205
+ # [M, N, K] -> [M, N, K]
206
+ layer_norm = ChannelwiseLayerNorm(N)
207
+ # [M, N, K] -> [M, B, K]
208
+ bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
209
+ # [M, B, K] -> [M, B, K]
210
+ repeats = []
211
+ for r in range(R):
212
+ blocks = []
213
+ for x in range(X):
214
+ dilation = 2**x
215
+ padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
216
+ blocks += [
217
+ TemporalBlock(B,
218
+ H,
219
+ P,
220
+ stride=1,
221
+ padding=padding,
222
+ dilation=dilation,
223
+ norm_type=norm_type,
224
+ causal=causal)
225
+ ]
226
+ repeats += [nn.Sequential(*blocks)]
227
+ temporal_conv_net = nn.Sequential(*repeats)
228
+ # [M, B, K] -> [M, C*N, K]
229
+ mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
230
+ # Put together
231
+ self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
232
+ mask_conv1x1)
233
+
234
+ def forward(self, mixture_w):
235
+ """
236
+ Keep this API same with TasNet
237
+ Args:
238
+ mixture_w: [M, N, K], M is batch size
239
+ returns:
240
+ est_mask: [M, C, N, K]
241
+ """
242
+ M, N, K = mixture_w.size()
243
+ score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
244
+ score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
245
+ if self.mask_nonlinear == 'softmax':
246
+ est_mask = F.softmax(score, dim=1)
247
+ elif self.mask_nonlinear == 'relu':
248
+ est_mask = F.relu(score)
249
+ else:
250
+ raise ValueError("Unsupported mask non-linear function")
251
+ return est_mask
252
+
253
+
254
+ class TemporalBlock(nn.Module):
255
+ def __init__(self,
256
+ in_channels,
257
+ out_channels,
258
+ kernel_size,
259
+ stride,
260
+ padding,
261
+ dilation,
262
+ norm_type="gLN",
263
+ causal=False):
264
+ super(TemporalBlock, self).__init__()
265
+ # [M, B, K] -> [M, H, K]
266
+ conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
267
+ prelu = nn.PReLU()
268
+ norm = chose_norm(norm_type, out_channels)
269
+ # [M, H, K] -> [M, B, K]
270
+ dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
271
+ dilation, norm_type, causal)
272
+ # Put together
273
+ self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
274
+
275
+ def forward(self, x):
276
+ """
277
+ Args:
278
+ x: [M, B, K]
279
+ Returns:
280
+ [M, B, K]
281
+ """
282
+ residual = x
283
+ out = self.net(x)
284
+ # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
285
+ return out + residual # look like w/o F.relu is better than w/ F.relu
286
+ # return F.relu(out + residual)
287
+
288
+
289
+ class DepthwiseSeparableConv(nn.Module):
290
+ def __init__(self,
291
+ in_channels,
292
+ out_channels,
293
+ kernel_size,
294
+ stride,
295
+ padding,
296
+ dilation,
297
+ norm_type="gLN",
298
+ causal=False):
299
+ super(DepthwiseSeparableConv, self).__init__()
300
+ # Use `groups` option to implement depthwise convolution
301
+ # [M, H, K] -> [M, H, K]
302
+ depthwise_conv = nn.Conv1d(in_channels,
303
+ in_channels,
304
+ kernel_size,
305
+ stride=stride,
306
+ padding=padding,
307
+ dilation=dilation,
308
+ groups=in_channels,
309
+ bias=False)
310
+ if causal:
311
+ chomp = Chomp1d(padding)
312
+ prelu = nn.PReLU()
313
+ norm = chose_norm(norm_type, in_channels)
314
+ # [M, H, K] -> [M, B, K]
315
+ pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
316
+ # Put together
317
+ if causal:
318
+ self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
319
+ else:
320
+ self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
321
+
322
+ def forward(self, x):
323
+ """
324
+ Args:
325
+ x: [M, H, K]
326
+ Returns:
327
+ result: [M, B, K]
328
+ """
329
+ return self.net(x)
330
+
331
+
332
+ class Chomp1d(nn.Module):
333
+ """To ensure the output length is the same as the input.
334
+ """
335
+ def __init__(self, chomp_size):
336
+ super(Chomp1d, self).__init__()
337
+ self.chomp_size = chomp_size
338
+
339
+ def forward(self, x):
340
+ """
341
+ Args:
342
+ x: [M, H, Kpad]
343
+ Returns:
344
+ [M, H, K]
345
+ """
346
+ return x[:, :, :-self.chomp_size].contiguous()
347
+
348
+
349
+ def chose_norm(norm_type, channel_size):
350
+ """The input of normlization will be (M, C, K), where M is batch size,
351
+ C is channel size and K is sequence length.
352
+ """
353
+ if norm_type == "gLN":
354
+ return GlobalLayerNorm(channel_size)
355
+ elif norm_type == "cLN":
356
+ return ChannelwiseLayerNorm(channel_size)
357
+ elif norm_type == "id":
358
+ return nn.Identity()
359
+ else: # norm_type == "BN":
360
+ # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
361
+ # along M and K, so this BN usage is right.
362
+ return nn.BatchNorm1d(channel_size)
363
+
364
+
365
+ # TODO: Use nn.LayerNorm to impl cLN to speed up
366
+ class ChannelwiseLayerNorm(nn.Module):
367
+ """Channel-wise Layer Normalization (cLN)"""
368
+ def __init__(self, channel_size):
369
+ super(ChannelwiseLayerNorm, self).__init__()
370
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
371
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
372
+ self.reset_parameters()
373
+
374
+ def reset_parameters(self):
375
+ self.gamma.data.fill_(1)
376
+ self.beta.data.zero_()
377
+
378
+ def forward(self, y):
379
+ """
380
+ Args:
381
+ y: [M, N, K], M is batch size, N is channel size, K is length
382
+ Returns:
383
+ cLN_y: [M, N, K]
384
+ """
385
+ mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
386
+ var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
387
+ cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
388
+ return cLN_y
389
+
390
+
391
+ class GlobalLayerNorm(nn.Module):
392
+ """Global Layer Normalization (gLN)"""
393
+ def __init__(self, channel_size):
394
+ super(GlobalLayerNorm, self).__init__()
395
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
396
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
397
+ self.reset_parameters()
398
+
399
+ def reset_parameters(self):
400
+ self.gamma.data.fill_(1)
401
+ self.beta.data.zero_()
402
+
403
+ def forward(self, y):
404
+ """
405
+ Args:
406
+ y: [M, N, K], M is batch size, N is channel size, K is length
407
+ Returns:
408
+ gLN_y: [M, N, K]
409
+ """
410
+ # TODO: in torch 1.0, torch.mean() support dim list
411
+ mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
412
+ var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
413
+ gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
414
+ return gLN_y
415
+
416
+
417
+ if __name__ == "__main__":
418
+ torch.manual_seed(123)
419
+ M, N, L, T = 2, 3, 4, 12
420
+ K = 2 * T // L - 1
421
+ B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
422
+ mixture = torch.randint(3, (M, T))
423
+ # test Encoder
424
+ encoder = Encoder(L, N)
425
+ encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
426
+ mixture_w = encoder(mixture)
427
+ print('mixture', mixture)
428
+ print('U', encoder.conv1d_U.weight)
429
+ print('mixture_w', mixture_w)
430
+ print('mixture_w size', mixture_w.size())
431
+
432
+ # test TemporalConvNet
433
+ separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
434
+ est_mask = separator(mixture_w)
435
+ print('est_mask', est_mask)
436
+
437
+ # test Decoder
438
+ decoder = Decoder(N, L)
439
+ est_mask = torch.randint(2, (B, K, C, N))
440
+ est_source = decoder(mixture_w, est_mask)
441
+ print('est_source', est_source)
442
+
443
+ # test Conv-TasNet
444
+ conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
445
+ est_source = conv_tasnet(mixture)
446
+ print('est_source', est_source)
447
+ print('est_source size', est_source.size())
demucs/tasnet_v2.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # Created on 2018/12
8
+ # Author: Kaituo XU
9
+ # Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
10
+ # Here is the original license:
11
+ # The MIT License (MIT)
12
+ #
13
+ # Copyright (c) 2018 Kaituo XU
14
+ #
15
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
16
+ # of this software and associated documentation files (the "Software"), to deal
17
+ # in the Software without restriction, including without limitation the rights
18
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
19
+ # copies of the Software, and to permit persons to whom the Software is
20
+ # furnished to do so, subject to the following conditions:
21
+ #
22
+ # The above copyright notice and this permission notice shall be included in all
23
+ # copies or substantial portions of the Software.
24
+ #
25
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31
+ # SOFTWARE.
32
+
33
+ import math
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+ from .utils import capture_init
40
+
41
+ EPS = 1e-8
42
+
43
+
44
+ def overlap_and_add(signal, frame_step):
45
+ outer_dimensions = signal.size()[:-2]
46
+ frames, frame_length = signal.size()[-2:]
47
+
48
+ subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
49
+ subframe_step = frame_step // subframe_length
50
+ subframes_per_frame = frame_length // subframe_length
51
+ output_size = frame_step * (frames - 1) + frame_length
52
+ output_subframes = output_size // subframe_length
53
+
54
+ subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
55
+
56
+ frame = torch.arange(0, output_subframes,
57
+ device=signal.device).unfold(0, subframes_per_frame, subframe_step)
58
+ frame = frame.long() # signal may in GPU or CPU
59
+ frame = frame.contiguous().view(-1)
60
+
61
+ result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
62
+ result.index_add_(-2, frame, subframe_signal)
63
+ result = result.view(*outer_dimensions, -1)
64
+ return result
65
+
66
+
67
+ class ConvTasNet(nn.Module):
68
+ @capture_init
69
+ def __init__(self,
70
+ sources,
71
+ N=256,
72
+ L=20,
73
+ B=256,
74
+ H=512,
75
+ P=3,
76
+ X=8,
77
+ R=4,
78
+ audio_channels=2,
79
+ norm_type="gLN",
80
+ causal=False,
81
+ mask_nonlinear='relu',
82
+ samplerate=44100,
83
+ segment_length=44100 * 2 * 4):
84
+ """
85
+ Args:
86
+ sources: list of sources
87
+ N: Number of filters in autoencoder
88
+ L: Length of the filters (in samples)
89
+ B: Number of channels in bottleneck 1 × 1-conv block
90
+ H: Number of channels in convolutional blocks
91
+ P: Kernel size in convolutional blocks
92
+ X: Number of convolutional blocks in each repeat
93
+ R: Number of repeats
94
+ norm_type: BN, gLN, cLN
95
+ causal: causal or non-causal
96
+ mask_nonlinear: use which non-linear function to generate mask
97
+ """
98
+ super(ConvTasNet, self).__init__()
99
+ # Hyper-parameter
100
+ self.sources = sources
101
+ self.C = len(sources)
102
+ self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
103
+ self.norm_type = norm_type
104
+ self.causal = causal
105
+ self.mask_nonlinear = mask_nonlinear
106
+ self.audio_channels = audio_channels
107
+ self.samplerate = samplerate
108
+ self.segment_length = segment_length
109
+ # Components
110
+ self.encoder = Encoder(L, N, audio_channels)
111
+ self.separator = TemporalConvNet(
112
+ N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
113
+ self.decoder = Decoder(N, L, audio_channels)
114
+ # init
115
+ for p in self.parameters():
116
+ if p.dim() > 1:
117
+ nn.init.xavier_normal_(p)
118
+
119
+ def valid_length(self, length):
120
+ return length
121
+
122
+ def forward(self, mixture):
123
+ """
124
+ Args:
125
+ mixture: [M, T], M is batch size, T is #samples
126
+ Returns:
127
+ est_source: [M, C, T]
128
+ """
129
+ mixture_w = self.encoder(mixture)
130
+ est_mask = self.separator(mixture_w)
131
+ est_source = self.decoder(mixture_w, est_mask)
132
+
133
+ # T changed after conv1d in encoder, fix it here
134
+ T_origin = mixture.size(-1)
135
+ T_conv = est_source.size(-1)
136
+ est_source = F.pad(est_source, (0, T_origin - T_conv))
137
+ return est_source
138
+
139
+
140
+ class Encoder(nn.Module):
141
+ """Estimation of the nonnegative mixture weight by a 1-D conv layer.
142
+ """
143
+ def __init__(self, L, N, audio_channels):
144
+ super(Encoder, self).__init__()
145
+ # Hyper-parameter
146
+ self.L, self.N = L, N
147
+ # Components
148
+ # 50% overlap
149
+ self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
150
+
151
+ def forward(self, mixture):
152
+ """
153
+ Args:
154
+ mixture: [M, T], M is batch size, T is #samples
155
+ Returns:
156
+ mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
157
+ """
158
+ mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
159
+ return mixture_w
160
+
161
+
162
+ class Decoder(nn.Module):
163
+ def __init__(self, N, L, audio_channels):
164
+ super(Decoder, self).__init__()
165
+ # Hyper-parameter
166
+ self.N, self.L = N, L
167
+ self.audio_channels = audio_channels
168
+ # Components
169
+ self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
170
+
171
+ def forward(self, mixture_w, est_mask):
172
+ """
173
+ Args:
174
+ mixture_w: [M, N, K]
175
+ est_mask: [M, C, N, K]
176
+ Returns:
177
+ est_source: [M, C, T]
178
+ """
179
+ # D = W * M
180
+ source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
181
+ source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
182
+ # S = DV
183
+ est_source = self.basis_signals(source_w) # [M, C, K, ac * L]
184
+ m, c, k, _ = est_source.size()
185
+ est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
186
+ est_source = overlap_and_add(est_source, self.L // 2) # M x C x ac x T
187
+ return est_source
188
+
189
+
190
+ class TemporalConvNet(nn.Module):
191
+ def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
192
+ """
193
+ Args:
194
+ N: Number of filters in autoencoder
195
+ B: Number of channels in bottleneck 1 × 1-conv block
196
+ H: Number of channels in convolutional blocks
197
+ P: Kernel size in convolutional blocks
198
+ X: Number of convolutional blocks in each repeat
199
+ R: Number of repeats
200
+ C: Number of speakers
201
+ norm_type: BN, gLN, cLN
202
+ causal: causal or non-causal
203
+ mask_nonlinear: use which non-linear function to generate mask
204
+ """
205
+ super(TemporalConvNet, self).__init__()
206
+ # Hyper-parameter
207
+ self.C = C
208
+ self.mask_nonlinear = mask_nonlinear
209
+ # Components
210
+ # [M, N, K] -> [M, N, K]
211
+ layer_norm = ChannelwiseLayerNorm(N)
212
+ # [M, N, K] -> [M, B, K]
213
+ bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
214
+ # [M, B, K] -> [M, B, K]
215
+ repeats = []
216
+ for r in range(R):
217
+ blocks = []
218
+ for x in range(X):
219
+ dilation = 2**x
220
+ padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
221
+ blocks += [
222
+ TemporalBlock(B,
223
+ H,
224
+ P,
225
+ stride=1,
226
+ padding=padding,
227
+ dilation=dilation,
228
+ norm_type=norm_type,
229
+ causal=causal)
230
+ ]
231
+ repeats += [nn.Sequential(*blocks)]
232
+ temporal_conv_net = nn.Sequential(*repeats)
233
+ # [M, B, K] -> [M, C*N, K]
234
+ mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
235
+ # Put together
236
+ self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
237
+ mask_conv1x1)
238
+
239
+ def forward(self, mixture_w):
240
+ """
241
+ Keep this API same with TasNet
242
+ Args:
243
+ mixture_w: [M, N, K], M is batch size
244
+ returns:
245
+ est_mask: [M, C, N, K]
246
+ """
247
+ M, N, K = mixture_w.size()
248
+ score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K]
249
+ score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
250
+ if self.mask_nonlinear == 'softmax':
251
+ est_mask = F.softmax(score, dim=1)
252
+ elif self.mask_nonlinear == 'relu':
253
+ est_mask = F.relu(score)
254
+ else:
255
+ raise ValueError("Unsupported mask non-linear function")
256
+ return est_mask
257
+
258
+
259
+ class TemporalBlock(nn.Module):
260
+ def __init__(self,
261
+ in_channels,
262
+ out_channels,
263
+ kernel_size,
264
+ stride,
265
+ padding,
266
+ dilation,
267
+ norm_type="gLN",
268
+ causal=False):
269
+ super(TemporalBlock, self).__init__()
270
+ # [M, B, K] -> [M, H, K]
271
+ conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
272
+ prelu = nn.PReLU()
273
+ norm = chose_norm(norm_type, out_channels)
274
+ # [M, H, K] -> [M, B, K]
275
+ dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
276
+ dilation, norm_type, causal)
277
+ # Put together
278
+ self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
279
+
280
+ def forward(self, x):
281
+ """
282
+ Args:
283
+ x: [M, B, K]
284
+ Returns:
285
+ [M, B, K]
286
+ """
287
+ residual = x
288
+ out = self.net(x)
289
+ # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
290
+ return out + residual # look like w/o F.relu is better than w/ F.relu
291
+ # return F.relu(out + residual)
292
+
293
+
294
+ class DepthwiseSeparableConv(nn.Module):
295
+ def __init__(self,
296
+ in_channels,
297
+ out_channels,
298
+ kernel_size,
299
+ stride,
300
+ padding,
301
+ dilation,
302
+ norm_type="gLN",
303
+ causal=False):
304
+ super(DepthwiseSeparableConv, self).__init__()
305
+ # Use `groups` option to implement depthwise convolution
306
+ # [M, H, K] -> [M, H, K]
307
+ depthwise_conv = nn.Conv1d(in_channels,
308
+ in_channels,
309
+ kernel_size,
310
+ stride=stride,
311
+ padding=padding,
312
+ dilation=dilation,
313
+ groups=in_channels,
314
+ bias=False)
315
+ if causal:
316
+ chomp = Chomp1d(padding)
317
+ prelu = nn.PReLU()
318
+ norm = chose_norm(norm_type, in_channels)
319
+ # [M, H, K] -> [M, B, K]
320
+ pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
321
+ # Put together
322
+ if causal:
323
+ self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
324
+ else:
325
+ self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
326
+
327
+ def forward(self, x):
328
+ """
329
+ Args:
330
+ x: [M, H, K]
331
+ Returns:
332
+ result: [M, B, K]
333
+ """
334
+ return self.net(x)
335
+
336
+
337
+ class Chomp1d(nn.Module):
338
+ """To ensure the output length is the same as the input.
339
+ """
340
+ def __init__(self, chomp_size):
341
+ super(Chomp1d, self).__init__()
342
+ self.chomp_size = chomp_size
343
+
344
+ def forward(self, x):
345
+ """
346
+ Args:
347
+ x: [M, H, Kpad]
348
+ Returns:
349
+ [M, H, K]
350
+ """
351
+ return x[:, :, :-self.chomp_size].contiguous()
352
+
353
+
354
+ def chose_norm(norm_type, channel_size):
355
+ """The input of normlization will be (M, C, K), where M is batch size,
356
+ C is channel size and K is sequence length.
357
+ """
358
+ if norm_type == "gLN":
359
+ return GlobalLayerNorm(channel_size)
360
+ elif norm_type == "cLN":
361
+ return ChannelwiseLayerNorm(channel_size)
362
+ elif norm_type == "id":
363
+ return nn.Identity()
364
+ else: # norm_type == "BN":
365
+ # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
366
+ # along M and K, so this BN usage is right.
367
+ return nn.BatchNorm1d(channel_size)
368
+
369
+
370
+ # TODO: Use nn.LayerNorm to impl cLN to speed up
371
+ class ChannelwiseLayerNorm(nn.Module):
372
+ """Channel-wise Layer Normalization (cLN)"""
373
+ def __init__(self, channel_size):
374
+ super(ChannelwiseLayerNorm, self).__init__()
375
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
376
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
377
+ self.reset_parameters()
378
+
379
+ def reset_parameters(self):
380
+ self.gamma.data.fill_(1)
381
+ self.beta.data.zero_()
382
+
383
+ def forward(self, y):
384
+ """
385
+ Args:
386
+ y: [M, N, K], M is batch size, N is channel size, K is length
387
+ Returns:
388
+ cLN_y: [M, N, K]
389
+ """
390
+ mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
391
+ var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
392
+ cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
393
+ return cLN_y
394
+
395
+
396
+ class GlobalLayerNorm(nn.Module):
397
+ """Global Layer Normalization (gLN)"""
398
+ def __init__(self, channel_size):
399
+ super(GlobalLayerNorm, self).__init__()
400
+ self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
401
+ self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
402
+ self.reset_parameters()
403
+
404
+ def reset_parameters(self):
405
+ self.gamma.data.fill_(1)
406
+ self.beta.data.zero_()
407
+
408
+ def forward(self, y):
409
+ """
410
+ Args:
411
+ y: [M, N, K], M is batch size, N is channel size, K is length
412
+ Returns:
413
+ gLN_y: [M, N, K]
414
+ """
415
+ # TODO: in torch 1.0, torch.mean() support dim list
416
+ mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1]
417
+ var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
418
+ gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
419
+ return gLN_y
420
+
421
+
422
+ if __name__ == "__main__":
423
+ torch.manual_seed(123)
424
+ M, N, L, T = 2, 3, 4, 12
425
+ K = 2 * T // L - 1
426
+ B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
427
+ mixture = torch.randint(3, (M, T))
428
+ # test Encoder
429
+ encoder = Encoder(L, N)
430
+ encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
431
+ mixture_w = encoder(mixture)
432
+ print('mixture', mixture)
433
+ print('U', encoder.conv1d_U.weight)
434
+ print('mixture_w', mixture_w)
435
+ print('mixture_w size', mixture_w.size())
436
+
437
+ # test TemporalConvNet
438
+ separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
439
+ est_mask = separator(mixture_w)
440
+ print('est_mask', est_mask)
441
+
442
+ # test Decoder
443
+ decoder = Decoder(N, L)
444
+ est_mask = torch.randint(2, (B, K, C, N))
445
+ est_source = decoder(mixture_w, est_mask)
446
+ print('est_source', est_source)
447
+
448
+ # test Conv-TasNet
449
+ conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
450
+ est_source = conv_tasnet(mixture)
451
+ print('est_source', est_source)
452
+ print('est_source size', est_source.size())
demucs/transformer.py ADDED
@@ -0,0 +1,839 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019-present, Meta, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # First author is Simon Rouard.
7
+
8
+ import random
9
+ import typing as tp
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import math
16
+ from einops import rearrange
17
+
18
+
19
+ def create_sin_embedding(
20
+ length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
21
+ ):
22
+ # We aim for TBC format
23
+ assert dim % 2 == 0
24
+ pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
25
+ half_dim = dim // 2
26
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
27
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
28
+ return torch.cat(
29
+ [
30
+ torch.cos(phase),
31
+ torch.sin(phase),
32
+ ],
33
+ dim=-1,
34
+ )
35
+
36
+
37
+ def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
38
+ """
39
+ :param d_model: dimension of the model
40
+ :param height: height of the positions
41
+ :param width: width of the positions
42
+ :return: d_model*height*width position matrix
43
+ """
44
+ if d_model % 4 != 0:
45
+ raise ValueError(
46
+ "Cannot use sin/cos positional encoding with "
47
+ "odd dimension (got dim={:d})".format(d_model)
48
+ )
49
+ pe = torch.zeros(d_model, height, width)
50
+ # Each dimension use half of d_model
51
+ d_model = int(d_model / 2)
52
+ div_term = torch.exp(
53
+ torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model)
54
+ )
55
+ pos_w = torch.arange(0.0, width).unsqueeze(1)
56
+ pos_h = torch.arange(0.0, height).unsqueeze(1)
57
+ pe[0:d_model:2, :, :] = (
58
+ torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
59
+ )
60
+ pe[1:d_model:2, :, :] = (
61
+ torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
62
+ )
63
+ pe[d_model::2, :, :] = (
64
+ torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
65
+ )
66
+ pe[d_model + 1:: 2, :, :] = (
67
+ torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
68
+ )
69
+
70
+ return pe[None, :].to(device)
71
+
72
+
73
+ def create_sin_embedding_cape(
74
+ length: int,
75
+ dim: int,
76
+ batch_size: int,
77
+ mean_normalize: bool,
78
+ augment: bool, # True during training
79
+ max_global_shift: float = 0.0, # delta max
80
+ max_local_shift: float = 0.0, # epsilon max
81
+ max_scale: float = 1.0,
82
+ device: str = "cpu",
83
+ max_period: float = 10000.0,
84
+ ):
85
+ # We aim for TBC format
86
+ assert dim % 2 == 0
87
+ pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1)
88
+ pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1)
89
+ if mean_normalize:
90
+ pos -= torch.nanmean(pos, dim=0, keepdim=True)
91
+
92
+ if augment:
93
+ delta = np.random.uniform(
94
+ -max_global_shift, +max_global_shift, size=[1, batch_size, 1]
95
+ )
96
+ delta_local = np.random.uniform(
97
+ -max_local_shift, +max_local_shift, size=[length, batch_size, 1]
98
+ )
99
+ log_lambdas = np.random.uniform(
100
+ -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
101
+ )
102
+ pos = (pos + delta + delta_local) * np.exp(log_lambdas)
103
+
104
+ pos = pos.to(device)
105
+
106
+ half_dim = dim // 2
107
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
108
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
109
+ return torch.cat(
110
+ [
111
+ torch.cos(phase),
112
+ torch.sin(phase),
113
+ ],
114
+ dim=-1,
115
+ ).float()
116
+
117
+
118
+ def get_causal_mask(length):
119
+ pos = torch.arange(length)
120
+ return pos > pos[:, None]
121
+
122
+
123
+ def get_elementary_mask(
124
+ T1,
125
+ T2,
126
+ mask_type,
127
+ sparse_attn_window,
128
+ global_window,
129
+ mask_random_seed,
130
+ sparsity,
131
+ device,
132
+ ):
133
+ """
134
+ When the input of the Decoder has length T1 and the output T2
135
+ The mask matrix has shape (T2, T1)
136
+ """
137
+ assert mask_type in ["diag", "jmask", "random", "global"]
138
+
139
+ if mask_type == "global":
140
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
141
+ mask[:, :global_window] = True
142
+ line_window = int(global_window * T2 / T1)
143
+ mask[:line_window, :] = True
144
+
145
+ if mask_type == "diag":
146
+
147
+ mask = torch.zeros(T2, T1, dtype=torch.bool)
148
+ rows = torch.arange(T2)[:, None]
149
+ cols = (
150
+ (T1 / T2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1))
151
+ .long()
152
+ .clamp(0, T1 - 1)
153
+ )
154
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
155
+
156
+ elif mask_type == "jmask":
157
+ mask = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
158
+ rows = torch.arange(T2 + 2)[:, None]
159
+ t = torch.arange(0, int((2 * T1) ** 0.5 + 1))
160
+ t = (t * (t + 1) / 2).int()
161
+ t = torch.cat([-t.flip(0)[:-1], t])
162
+ cols = (T1 / T2 * rows + t).long().clamp(0, T1 + 1)
163
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
164
+ mask = mask[1:-1, 1:-1]
165
+
166
+ elif mask_type == "random":
167
+ gene = torch.Generator(device=device)
168
+ gene.manual_seed(mask_random_seed)
169
+ mask = (
170
+ torch.rand(T1 * T2, generator=gene, device=device).reshape(T2, T1)
171
+ > sparsity
172
+ )
173
+
174
+ mask = mask.to(device)
175
+ return mask
176
+
177
+
178
+ def get_mask(
179
+ T1,
180
+ T2,
181
+ mask_type,
182
+ sparse_attn_window,
183
+ global_window,
184
+ mask_random_seed,
185
+ sparsity,
186
+ device,
187
+ ):
188
+ """
189
+ Return a SparseCSRTensor mask that is a combination of elementary masks
190
+ mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
191
+ """
192
+ from xformers.sparse import SparseCSRTensor
193
+ # create a list
194
+ mask_types = mask_type.split("_")
195
+
196
+ all_masks = [
197
+ get_elementary_mask(
198
+ T1,
199
+ T2,
200
+ mask,
201
+ sparse_attn_window,
202
+ global_window,
203
+ mask_random_seed,
204
+ sparsity,
205
+ device,
206
+ )
207
+ for mask in mask_types
208
+ ]
209
+
210
+ final_mask = torch.stack(all_masks).sum(axis=0) > 0
211
+
212
+ return SparseCSRTensor.from_dense(final_mask[None])
213
+
214
+
215
+ class ScaledEmbedding(nn.Module):
216
+ def __init__(
217
+ self,
218
+ num_embeddings: int,
219
+ embedding_dim: int,
220
+ scale: float = 1.0,
221
+ boost: float = 3.0,
222
+ ):
223
+ super().__init__()
224
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
225
+ self.embedding.weight.data *= scale / boost
226
+ self.boost = boost
227
+
228
+ @property
229
+ def weight(self):
230
+ return self.embedding.weight * self.boost
231
+
232
+ def forward(self, x):
233
+ return self.embedding(x) * self.boost
234
+
235
+
236
+ class LayerScale(nn.Module):
237
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
238
+ This rescales diagonaly residual outputs close to 0 initially, then learnt.
239
+ """
240
+
241
+ def __init__(self, channels: int, init: float = 0, channel_last=False):
242
+ """
243
+ channel_last = False corresponds to (B, C, T) tensors
244
+ channel_last = True corresponds to (T, B, C) tensors
245
+ """
246
+ super().__init__()
247
+ self.channel_last = channel_last
248
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
249
+ self.scale.data[:] = init
250
+
251
+ def forward(self, x):
252
+ if self.channel_last:
253
+ return self.scale * x
254
+ else:
255
+ return self.scale[:, None] * x
256
+
257
+
258
+ class MyGroupNorm(nn.GroupNorm):
259
+ def __init__(self, *args, **kwargs):
260
+ super().__init__(*args, **kwargs)
261
+
262
+ def forward(self, x):
263
+ """
264
+ x: (B, T, C)
265
+ if num_groups=1: Normalisation on all T and C together for each B
266
+ """
267
+ x = x.transpose(1, 2)
268
+ return super().forward(x).transpose(1, 2)
269
+
270
+
271
+ class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
272
+ def __init__(
273
+ self,
274
+ d_model,
275
+ nhead,
276
+ dim_feedforward=2048,
277
+ dropout=0.1,
278
+ activation=F.relu,
279
+ group_norm=0,
280
+ norm_first=False,
281
+ norm_out=False,
282
+ layer_norm_eps=1e-5,
283
+ layer_scale=False,
284
+ init_values=1e-4,
285
+ device=None,
286
+ dtype=None,
287
+ sparse=False,
288
+ mask_type="diag",
289
+ mask_random_seed=42,
290
+ sparse_attn_window=500,
291
+ global_window=50,
292
+ auto_sparsity=False,
293
+ sparsity=0.95,
294
+ batch_first=False,
295
+ ):
296
+ factory_kwargs = {"device": device, "dtype": dtype}
297
+ super().__init__(
298
+ d_model=d_model,
299
+ nhead=nhead,
300
+ dim_feedforward=dim_feedforward,
301
+ dropout=dropout,
302
+ activation=activation,
303
+ layer_norm_eps=layer_norm_eps,
304
+ batch_first=batch_first,
305
+ norm_first=norm_first,
306
+ device=device,
307
+ dtype=dtype,
308
+ )
309
+ self.sparse = sparse
310
+ self.auto_sparsity = auto_sparsity
311
+ if sparse:
312
+ if not auto_sparsity:
313
+ self.mask_type = mask_type
314
+ self.sparse_attn_window = sparse_attn_window
315
+ self.global_window = global_window
316
+ self.sparsity = sparsity
317
+ if group_norm:
318
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
319
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
320
+
321
+ self.norm_out = None
322
+ if self.norm_first & norm_out:
323
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
324
+ self.gamma_1 = (
325
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
326
+ )
327
+ self.gamma_2 = (
328
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
329
+ )
330
+
331
+ if sparse:
332
+ self.self_attn = MultiheadAttention(
333
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
334
+ auto_sparsity=sparsity if auto_sparsity else 0,
335
+ )
336
+ self.__setattr__("src_mask", torch.zeros(1, 1))
337
+ self.mask_random_seed = mask_random_seed
338
+
339
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
340
+ """
341
+ if batch_first = False, src shape is (T, B, C)
342
+ the case where batch_first=True is not covered
343
+ """
344
+ device = src.device
345
+ x = src
346
+ T, B, C = x.shape
347
+ if self.sparse and not self.auto_sparsity:
348
+ assert src_mask is None
349
+ src_mask = self.src_mask
350
+ if src_mask.shape[-1] != T:
351
+ src_mask = get_mask(
352
+ T,
353
+ T,
354
+ self.mask_type,
355
+ self.sparse_attn_window,
356
+ self.global_window,
357
+ self.mask_random_seed,
358
+ self.sparsity,
359
+ device,
360
+ )
361
+ self.__setattr__("src_mask", src_mask)
362
+
363
+ if self.norm_first:
364
+ x = x + self.gamma_1(
365
+ self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
366
+ )
367
+ x = x + self.gamma_2(self._ff_block(self.norm2(x)))
368
+
369
+ if self.norm_out:
370
+ x = self.norm_out(x)
371
+ else:
372
+ x = self.norm1(
373
+ x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
374
+ )
375
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
376
+
377
+ return x
378
+
379
+
380
+ class CrossTransformerEncoderLayer(nn.Module):
381
+ def __init__(
382
+ self,
383
+ d_model: int,
384
+ nhead: int,
385
+ dim_feedforward: int = 2048,
386
+ dropout: float = 0.1,
387
+ activation=F.relu,
388
+ layer_norm_eps: float = 1e-5,
389
+ layer_scale: bool = False,
390
+ init_values: float = 1e-4,
391
+ norm_first: bool = False,
392
+ group_norm: bool = False,
393
+ norm_out: bool = False,
394
+ sparse=False,
395
+ mask_type="diag",
396
+ mask_random_seed=42,
397
+ sparse_attn_window=500,
398
+ global_window=50,
399
+ sparsity=0.95,
400
+ auto_sparsity=None,
401
+ device=None,
402
+ dtype=None,
403
+ batch_first=False,
404
+ ):
405
+ factory_kwargs = {"device": device, "dtype": dtype}
406
+ super().__init__()
407
+
408
+ self.sparse = sparse
409
+ self.auto_sparsity = auto_sparsity
410
+ if sparse:
411
+ if not auto_sparsity:
412
+ self.mask_type = mask_type
413
+ self.sparse_attn_window = sparse_attn_window
414
+ self.global_window = global_window
415
+ self.sparsity = sparsity
416
+
417
+ self.cross_attn: nn.Module
418
+ self.cross_attn = nn.MultiheadAttention(
419
+ d_model, nhead, dropout=dropout, batch_first=batch_first)
420
+ # Implementation of Feedforward model
421
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
422
+ self.dropout = nn.Dropout(dropout)
423
+ self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
424
+
425
+ self.norm_first = norm_first
426
+ self.norm1: nn.Module
427
+ self.norm2: nn.Module
428
+ self.norm3: nn.Module
429
+ if group_norm:
430
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
431
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
432
+ self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
433
+ else:
434
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
435
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
436
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
437
+
438
+ self.norm_out = None
439
+ if self.norm_first & norm_out:
440
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
441
+
442
+ self.gamma_1 = (
443
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
444
+ )
445
+ self.gamma_2 = (
446
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
447
+ )
448
+
449
+ self.dropout1 = nn.Dropout(dropout)
450
+ self.dropout2 = nn.Dropout(dropout)
451
+
452
+ # Legacy string support for activation function.
453
+ if isinstance(activation, str):
454
+ self.activation = self._get_activation_fn(activation)
455
+ else:
456
+ self.activation = activation
457
+
458
+ if sparse:
459
+ self.cross_attn = MultiheadAttention(
460
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
461
+ auto_sparsity=sparsity if auto_sparsity else 0)
462
+ if not auto_sparsity:
463
+ self.__setattr__("mask", torch.zeros(1, 1))
464
+ self.mask_random_seed = mask_random_seed
465
+
466
+ def forward(self, q, k, mask=None):
467
+ """
468
+ Args:
469
+ q: tensor of shape (T, B, C)
470
+ k: tensor of shape (S, B, C)
471
+ mask: tensor of shape (T, S)
472
+
473
+ """
474
+ device = q.device
475
+ T, B, C = q.shape
476
+ S, B, C = k.shape
477
+ if self.sparse and not self.auto_sparsity:
478
+ assert mask is None
479
+ mask = self.mask
480
+ if mask.shape[-1] != S or mask.shape[-2] != T:
481
+ mask = get_mask(
482
+ S,
483
+ T,
484
+ self.mask_type,
485
+ self.sparse_attn_window,
486
+ self.global_window,
487
+ self.mask_random_seed,
488
+ self.sparsity,
489
+ device,
490
+ )
491
+ self.__setattr__("mask", mask)
492
+
493
+ if self.norm_first:
494
+ x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
495
+ x = x + self.gamma_2(self._ff_block(self.norm3(x)))
496
+ if self.norm_out:
497
+ x = self.norm_out(x)
498
+ else:
499
+ x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
500
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
501
+
502
+ return x
503
+
504
+ # self-attention block
505
+ def _ca_block(self, q, k, attn_mask=None):
506
+ x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
507
+ return self.dropout1(x)
508
+
509
+ # feed forward block
510
+ def _ff_block(self, x):
511
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
512
+ return self.dropout2(x)
513
+
514
+ def _get_activation_fn(self, activation):
515
+ if activation == "relu":
516
+ return F.relu
517
+ elif activation == "gelu":
518
+ return F.gelu
519
+
520
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
521
+
522
+
523
+ # ----------------- MULTI-BLOCKS MODELS: -----------------------
524
+
525
+
526
+ class CrossTransformerEncoder(nn.Module):
527
+ def __init__(
528
+ self,
529
+ dim: int,
530
+ emb: str = "sin",
531
+ hidden_scale: float = 4.0,
532
+ num_heads: int = 8,
533
+ num_layers: int = 6,
534
+ cross_first: bool = False,
535
+ dropout: float = 0.0,
536
+ max_positions: int = 1000,
537
+ norm_in: bool = True,
538
+ norm_in_group: bool = False,
539
+ group_norm: int = False,
540
+ norm_first: bool = False,
541
+ norm_out: bool = False,
542
+ max_period: float = 10000.0,
543
+ weight_decay: float = 0.0,
544
+ lr: tp.Optional[float] = None,
545
+ layer_scale: bool = False,
546
+ gelu: bool = True,
547
+ sin_random_shift: int = 0,
548
+ weight_pos_embed: float = 1.0,
549
+ cape_mean_normalize: bool = True,
550
+ cape_augment: bool = True,
551
+ cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
552
+ sparse_self_attn: bool = False,
553
+ sparse_cross_attn: bool = False,
554
+ mask_type: str = "diag",
555
+ mask_random_seed: int = 42,
556
+ sparse_attn_window: int = 500,
557
+ global_window: int = 50,
558
+ auto_sparsity: bool = False,
559
+ sparsity: float = 0.95,
560
+ ):
561
+ super().__init__()
562
+ """
563
+ """
564
+ assert dim % num_heads == 0
565
+
566
+ hidden_dim = int(dim * hidden_scale)
567
+
568
+ self.num_layers = num_layers
569
+ # classic parity = 1 means that if idx%2 == 1 there is a
570
+ # classical encoder else there is a cross encoder
571
+ self.classic_parity = 1 if cross_first else 0
572
+ self.emb = emb
573
+ self.max_period = max_period
574
+ self.weight_decay = weight_decay
575
+ self.weight_pos_embed = weight_pos_embed
576
+ self.sin_random_shift = sin_random_shift
577
+ if emb == "cape":
578
+ self.cape_mean_normalize = cape_mean_normalize
579
+ self.cape_augment = cape_augment
580
+ self.cape_glob_loc_scale = cape_glob_loc_scale
581
+ if emb == "scaled":
582
+ self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
583
+
584
+ self.lr = lr
585
+
586
+ activation: tp.Any = F.gelu if gelu else F.relu
587
+
588
+ self.norm_in: nn.Module
589
+ self.norm_in_t: nn.Module
590
+ if norm_in:
591
+ self.norm_in = nn.LayerNorm(dim)
592
+ self.norm_in_t = nn.LayerNorm(dim)
593
+ elif norm_in_group:
594
+ self.norm_in = MyGroupNorm(int(norm_in_group), dim)
595
+ self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
596
+ else:
597
+ self.norm_in = nn.Identity()
598
+ self.norm_in_t = nn.Identity()
599
+
600
+ # spectrogram layers
601
+ self.layers = nn.ModuleList()
602
+ # temporal layers
603
+ self.layers_t = nn.ModuleList()
604
+
605
+ kwargs_common = {
606
+ "d_model": dim,
607
+ "nhead": num_heads,
608
+ "dim_feedforward": hidden_dim,
609
+ "dropout": dropout,
610
+ "activation": activation,
611
+ "group_norm": group_norm,
612
+ "norm_first": norm_first,
613
+ "norm_out": norm_out,
614
+ "layer_scale": layer_scale,
615
+ "mask_type": mask_type,
616
+ "mask_random_seed": mask_random_seed,
617
+ "sparse_attn_window": sparse_attn_window,
618
+ "global_window": global_window,
619
+ "sparsity": sparsity,
620
+ "auto_sparsity": auto_sparsity,
621
+ "batch_first": True,
622
+ }
623
+
624
+ kwargs_classic_encoder = dict(kwargs_common)
625
+ kwargs_classic_encoder.update({
626
+ "sparse": sparse_self_attn,
627
+ })
628
+ kwargs_cross_encoder = dict(kwargs_common)
629
+ kwargs_cross_encoder.update({
630
+ "sparse": sparse_cross_attn,
631
+ })
632
+
633
+ for idx in range(num_layers):
634
+ if idx % 2 == self.classic_parity:
635
+
636
+ self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
637
+ self.layers_t.append(
638
+ MyTransformerEncoderLayer(**kwargs_classic_encoder)
639
+ )
640
+
641
+ else:
642
+ self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
643
+
644
+ self.layers_t.append(
645
+ CrossTransformerEncoderLayer(**kwargs_cross_encoder)
646
+ )
647
+
648
+ def forward(self, x, xt):
649
+ B, C, Fr, T1 = x.shape
650
+ pos_emb_2d = create_2d_sin_embedding(
651
+ C, Fr, T1, x.device, self.max_period
652
+ ) # (1, C, Fr, T1)
653
+ pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
654
+ x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
655
+ x = self.norm_in(x)
656
+ x = x + self.weight_pos_embed * pos_emb_2d
657
+
658
+ B, C, T2 = xt.shape
659
+ xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C
660
+ pos_emb = self._get_pos_embedding(T2, B, C, x.device)
661
+ pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
662
+ xt = self.norm_in_t(xt)
663
+ xt = xt + self.weight_pos_embed * pos_emb
664
+
665
+ for idx in range(self.num_layers):
666
+ if idx % 2 == self.classic_parity:
667
+ x = self.layers[idx](x)
668
+ xt = self.layers_t[idx](xt)
669
+ else:
670
+ old_x = x
671
+ x = self.layers[idx](x, xt)
672
+ xt = self.layers_t[idx](xt, old_x)
673
+
674
+ x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
675
+ xt = rearrange(xt, "b t2 c -> b c t2")
676
+ return x, xt
677
+
678
+ def _get_pos_embedding(self, T, B, C, device):
679
+ if self.emb == "sin":
680
+ shift = random.randrange(self.sin_random_shift + 1)
681
+ pos_emb = create_sin_embedding(
682
+ T, C, shift=shift, device=device, max_period=self.max_period
683
+ )
684
+ elif self.emb == "cape":
685
+ if self.training:
686
+ pos_emb = create_sin_embedding_cape(
687
+ T,
688
+ C,
689
+ B,
690
+ device=device,
691
+ max_period=self.max_period,
692
+ mean_normalize=self.cape_mean_normalize,
693
+ augment=self.cape_augment,
694
+ max_global_shift=self.cape_glob_loc_scale[0],
695
+ max_local_shift=self.cape_glob_loc_scale[1],
696
+ max_scale=self.cape_glob_loc_scale[2],
697
+ )
698
+ else:
699
+ pos_emb = create_sin_embedding_cape(
700
+ T,
701
+ C,
702
+ B,
703
+ device=device,
704
+ max_period=self.max_period,
705
+ mean_normalize=self.cape_mean_normalize,
706
+ augment=False,
707
+ )
708
+
709
+ elif self.emb == "scaled":
710
+ pos = torch.arange(T, device=device)
711
+ pos_emb = self.position_embeddings(pos)[:, None]
712
+
713
+ return pos_emb
714
+
715
+ def make_optim_group(self):
716
+ group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
717
+ if self.lr is not None:
718
+ group["lr"] = self.lr
719
+ return group
720
+
721
+
722
+ # Attention Modules
723
+
724
+
725
+ class MultiheadAttention(nn.Module):
726
+ def __init__(
727
+ self,
728
+ embed_dim,
729
+ num_heads,
730
+ dropout=0.0,
731
+ bias=True,
732
+ add_bias_kv=False,
733
+ add_zero_attn=False,
734
+ kdim=None,
735
+ vdim=None,
736
+ batch_first=False,
737
+ auto_sparsity=None,
738
+ ):
739
+ super().__init__()
740
+ assert auto_sparsity is not None, "sanity check"
741
+ self.num_heads = num_heads
742
+ self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
743
+ self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
744
+ self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
745
+ self.attn_drop = torch.nn.Dropout(dropout)
746
+ self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
747
+ self.proj_drop = torch.nn.Dropout(dropout)
748
+ self.batch_first = batch_first
749
+ self.auto_sparsity = auto_sparsity
750
+
751
+ def forward(
752
+ self,
753
+ query,
754
+ key,
755
+ value,
756
+ key_padding_mask=None,
757
+ need_weights=True,
758
+ attn_mask=None,
759
+ average_attn_weights=True,
760
+ ):
761
+
762
+ if not self.batch_first: # N, B, C
763
+ query = query.permute(1, 0, 2) # B, N_q, C
764
+ key = key.permute(1, 0, 2) # B, N_k, C
765
+ value = value.permute(1, 0, 2) # B, N_k, C
766
+ B, N_q, C = query.shape
767
+ B, N_k, C = key.shape
768
+
769
+ q = (
770
+ self.q(query)
771
+ .reshape(B, N_q, self.num_heads, C // self.num_heads)
772
+ .permute(0, 2, 1, 3)
773
+ )
774
+ q = q.flatten(0, 1)
775
+ k = (
776
+ self.k(key)
777
+ .reshape(B, N_k, self.num_heads, C // self.num_heads)
778
+ .permute(0, 2, 1, 3)
779
+ )
780
+ k = k.flatten(0, 1)
781
+ v = (
782
+ self.v(value)
783
+ .reshape(B, N_k, self.num_heads, C // self.num_heads)
784
+ .permute(0, 2, 1, 3)
785
+ )
786
+ v = v.flatten(0, 1)
787
+
788
+ if self.auto_sparsity:
789
+ assert attn_mask is None
790
+ x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
791
+ else:
792
+ x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
793
+ x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)
794
+
795
+ x = x.transpose(1, 2).reshape(B, N_q, C)
796
+ x = self.proj(x)
797
+ x = self.proj_drop(x)
798
+ if not self.batch_first:
799
+ x = x.permute(1, 0, 2)
800
+ return x, None
801
+
802
+
803
+ def scaled_query_key_softmax(q, k, att_mask):
804
+ from xformers.ops import masked_matmul
805
+ q = q / (k.size(-1)) ** 0.5
806
+ att = masked_matmul(q, k.transpose(-2, -1), att_mask)
807
+ att = torch.nn.functional.softmax(att, -1)
808
+ return att
809
+
810
+
811
+ def scaled_dot_product_attention(q, k, v, att_mask, dropout):
812
+ att = scaled_query_key_softmax(q, k, att_mask=att_mask)
813
+ att = dropout(att)
814
+ y = att @ v
815
+ return y
816
+
817
+
818
+ def _compute_buckets(x, R):
819
+ qq = torch.einsum('btf,bfhi->bhti', x, R)
820
+ qq = torch.cat([qq, -qq], dim=-1)
821
+ buckets = qq.argmax(dim=-1)
822
+
823
+ return buckets.permute(0, 2, 1).byte().contiguous()
824
+
825
+
826
+ def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
827
+ # assert False, "The code for the custom sparse kernel is not ready for release yet."
828
+ from xformers.ops import find_locations, sparse_memory_efficient_attention
829
+ n_hashes = 32
830
+ proj_size = 4
831
+ query, key, value = [x.contiguous() for x in [query, key, value]]
832
+ with torch.no_grad():
833
+ R = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
834
+ bucket_query = _compute_buckets(query, R)
835
+ bucket_key = _compute_buckets(key, R)
836
+ row_offsets, column_indices = find_locations(
837
+ bucket_query, bucket_key, sparsity, infer_sparsity)
838
+ return sparse_memory_efficient_attention(
839
+ query, key, value, row_offsets, column_indices, attn_bias)
demucs/utils.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import defaultdict
8
+ from contextlib import contextmanager
9
+ import math
10
+ import os
11
+ import tempfile
12
+ import typing as tp
13
+
14
+ import errno
15
+ import functools
16
+ import hashlib
17
+ import inspect
18
+ import io
19
+ import os
20
+ import random
21
+ import socket
22
+ import tempfile
23
+ import warnings
24
+ import zlib
25
+ import tkinter as tk
26
+
27
+ from diffq import UniformQuantizer, DiffQuantizer
28
+ import torch as th
29
+ import tqdm
30
+ from torch import distributed
31
+ from torch.nn import functional as F
32
+
33
+ import torch
34
+
35
+ def unfold(a, kernel_size, stride):
36
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
37
+ with K the kernel size, by extracting frames with the given stride.
38
+
39
+ This will pad the input so that `F = ceil(T / K)`.
40
+
41
+ see https://github.com/pytorch/pytorch/issues/60466
42
+ """
43
+ *shape, length = a.shape
44
+ n_frames = math.ceil(length / stride)
45
+ tgt_length = (n_frames - 1) * stride + kernel_size
46
+ a = F.pad(a, (0, tgt_length - length))
47
+ strides = list(a.stride())
48
+ assert strides[-1] == 1, 'data should be contiguous'
49
+ strides = strides[:-1] + [stride, 1]
50
+ return a.as_strided([*shape, n_frames, kernel_size], strides)
51
+
52
+
53
+ def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
54
+ """
55
+ Center trim `tensor` with respect to `reference`, along the last dimension.
56
+ `reference` can also be a number, representing the length to trim to.
57
+ If the size difference != 0 mod 2, the extra sample is removed on the right side.
58
+ """
59
+ ref_size: int
60
+ if isinstance(reference, torch.Tensor):
61
+ ref_size = reference.size(-1)
62
+ else:
63
+ ref_size = reference
64
+ delta = tensor.size(-1) - ref_size
65
+ if delta < 0:
66
+ raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
67
+ if delta:
68
+ tensor = tensor[..., delta // 2:-(delta - delta // 2)]
69
+ return tensor
70
+
71
+
72
+ def pull_metric(history: tp.List[dict], name: str):
73
+ out = []
74
+ for metrics in history:
75
+ metric = metrics
76
+ for part in name.split("."):
77
+ metric = metric[part]
78
+ out.append(metric)
79
+ return out
80
+
81
+
82
+ def EMA(beta: float = 1):
83
+ """
84
+ Exponential Moving Average callback.
85
+ Returns a single function that can be called to repeatidly update the EMA
86
+ with a dict of metrics. The callback will return
87
+ the new averaged dict of metrics.
88
+
89
+ Note that for `beta=1`, this is just plain averaging.
90
+ """
91
+ fix: tp.Dict[str, float] = defaultdict(float)
92
+ total: tp.Dict[str, float] = defaultdict(float)
93
+
94
+ def _update(metrics: dict, weight: float = 1) -> dict:
95
+ nonlocal total, fix
96
+ for key, value in metrics.items():
97
+ total[key] = total[key] * beta + weight * float(value)
98
+ fix[key] = fix[key] * beta + weight
99
+ return {key: tot / fix[key] for key, tot in total.items()}
100
+ return _update
101
+
102
+
103
+ def sizeof_fmt(num: float, suffix: str = 'B'):
104
+ """
105
+ Given `num` bytes, return human readable size.
106
+ Taken from https://stackoverflow.com/a/1094933
107
+ """
108
+ for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
109
+ if abs(num) < 1024.0:
110
+ return "%3.1f%s%s" % (num, unit, suffix)
111
+ num /= 1024.0
112
+ return "%.1f%s%s" % (num, 'Yi', suffix)
113
+
114
+
115
+ @contextmanager
116
+ def temp_filenames(count: int, delete=True):
117
+ names = []
118
+ try:
119
+ for _ in range(count):
120
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
121
+ yield names
122
+ finally:
123
+ if delete:
124
+ for name in names:
125
+ os.unlink(name)
126
+
127
+ def average_metric(metric, count=1.):
128
+ """
129
+ Average `metric` which should be a float across all hosts. `count` should be
130
+ the weight for this particular host (i.e. number of examples).
131
+ """
132
+ metric = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
133
+ distributed.all_reduce(metric, op=distributed.ReduceOp.SUM)
134
+ return metric[1].item() / metric[0].item()
135
+
136
+
137
+ def free_port(host='', low=20000, high=40000):
138
+ """
139
+ Return a port number that is most likely free.
140
+ This could suffer from a race condition although
141
+ it should be quite rare.
142
+ """
143
+ sock = socket.socket()
144
+ while True:
145
+ port = random.randint(low, high)
146
+ try:
147
+ sock.bind((host, port))
148
+ except OSError as error:
149
+ if error.errno == errno.EADDRINUSE:
150
+ continue
151
+ raise
152
+ return port
153
+
154
+
155
+ def sizeof_fmt(num, suffix='B'):
156
+ """
157
+ Given `num` bytes, return human readable size.
158
+ Taken from https://stackoverflow.com/a/1094933
159
+ """
160
+ for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
161
+ if abs(num) < 1024.0:
162
+ return "%3.1f%s%s" % (num, unit, suffix)
163
+ num /= 1024.0
164
+ return "%.1f%s%s" % (num, 'Yi', suffix)
165
+
166
+
167
+ def human_seconds(seconds, display='.2f'):
168
+ """
169
+ Given `seconds` seconds, return human readable duration.
170
+ """
171
+ value = seconds * 1e6
172
+ ratios = [1e3, 1e3, 60, 60, 24]
173
+ names = ['us', 'ms', 's', 'min', 'hrs', 'days']
174
+ last = names.pop(0)
175
+ for name, ratio in zip(names, ratios):
176
+ if value / ratio < 0.3:
177
+ break
178
+ value /= ratio
179
+ last = name
180
+ return f"{format(value, display)} {last}"
181
+
182
+
183
+ class TensorChunk:
184
+ def __init__(self, tensor, offset=0, length=None):
185
+ total_length = tensor.shape[-1]
186
+ assert offset >= 0
187
+ assert offset < total_length
188
+
189
+ if length is None:
190
+ length = total_length - offset
191
+ else:
192
+ length = min(total_length - offset, length)
193
+
194
+ self.tensor = tensor
195
+ self.offset = offset
196
+ self.length = length
197
+ self.device = tensor.device
198
+
199
+ @property
200
+ def shape(self):
201
+ shape = list(self.tensor.shape)
202
+ shape[-1] = self.length
203
+ return shape
204
+
205
+ def padded(self, target_length):
206
+ delta = target_length - self.length
207
+ total_length = self.tensor.shape[-1]
208
+ assert delta >= 0
209
+
210
+ start = self.offset - delta // 2
211
+ end = start + target_length
212
+
213
+ correct_start = max(0, start)
214
+ correct_end = min(total_length, end)
215
+
216
+ pad_left = correct_start - start
217
+ pad_right = end - correct_end
218
+
219
+ out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
220
+ assert out.shape[-1] == target_length
221
+ return out
222
+
223
+
224
+ def tensor_chunk(tensor_or_chunk):
225
+ if isinstance(tensor_or_chunk, TensorChunk):
226
+ return tensor_or_chunk
227
+ else:
228
+ assert isinstance(tensor_or_chunk, th.Tensor)
229
+ return TensorChunk(tensor_or_chunk)
230
+
231
+
232
+ def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None):
233
+ """
234
+ Apply model to a given mixture.
235
+
236
+ Args:
237
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
238
+ and apply the oppositve shift to the output. This is repeated `shifts` time and
239
+ all predictions are averaged. This effectively makes the model time equivariant
240
+ and improves SDR by up to 0.2 points.
241
+ split (bool): if True, the input will be broken down in 8 seconds extracts
242
+ and predictions will be performed individually on each and concatenated.
243
+ Useful for model with large memory footprint like Tasnet.
244
+ progress (bool): if True, show a progress bar (requires split=True)
245
+ """
246
+
247
+ channels, length = mix.size()
248
+ device = mix.device
249
+ progress_value = 0
250
+
251
+ if split:
252
+ out = th.zeros(4, channels, length, device=device)
253
+ shift = model.samplerate * 10
254
+ offsets = range(0, length, shift)
255
+ scale = 10
256
+ if progress:
257
+ offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
258
+ for offset in offsets:
259
+ chunk = mix[..., offset:offset + shift]
260
+ if set_progress_bar:
261
+ progress_value += 1
262
+ set_progress_bar(0.1, (0.8/len(offsets)*progress_value))
263
+ chunk_out = apply_model_v1(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
264
+ else:
265
+ chunk_out = apply_model_v1(model, chunk, shifts=shifts)
266
+ out[..., offset:offset + shift] = chunk_out
267
+ offset += shift
268
+ return out
269
+ elif shifts:
270
+ max_shift = int(model.samplerate / 2)
271
+ mix = F.pad(mix, (max_shift, max_shift))
272
+ offsets = list(range(max_shift))
273
+ random.shuffle(offsets)
274
+ out = 0
275
+ for offset in offsets[:shifts]:
276
+ shifted = mix[..., offset:offset + length + max_shift]
277
+ if set_progress_bar:
278
+ shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar)
279
+ else:
280
+ shifted_out = apply_model_v1(model, shifted)
281
+ out += shifted_out[..., max_shift - offset:max_shift - offset + length]
282
+ out /= shifts
283
+ return out
284
+ else:
285
+ valid_length = model.valid_length(length)
286
+ delta = valid_length - length
287
+ padded = F.pad(mix, (delta // 2, delta - delta // 2))
288
+ with th.no_grad():
289
+ out = model(padded.unsqueeze(0))[0]
290
+ return center_trim(out, mix)
291
+
292
+ def apply_model_v2(model, mix, shifts=None, split=False,
293
+ overlap=0.25, transition_power=1., progress=False, set_progress_bar=None):
294
+ """
295
+ Apply model to a given mixture.
296
+
297
+ Args:
298
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
299
+ and apply the oppositve shift to the output. This is repeated `shifts` time and
300
+ all predictions are averaged. This effectively makes the model time equivariant
301
+ and improves SDR by up to 0.2 points.
302
+ split (bool): if True, the input will be broken down in 8 seconds extracts
303
+ and predictions will be performed individually on each and concatenated.
304
+ Useful for model with large memory footprint like Tasnet.
305
+ progress (bool): if True, show a progress bar (requires split=True)
306
+ """
307
+
308
+ assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
309
+ device = mix.device
310
+ channels, length = mix.shape
311
+ progress_value = 0
312
+
313
+ if split:
314
+ out = th.zeros(len(model.sources), channels, length, device=device)
315
+ sum_weight = th.zeros(length, device=device)
316
+ segment = model.segment_length
317
+ stride = int((1 - overlap) * segment)
318
+ offsets = range(0, length, stride)
319
+ scale = stride / model.samplerate
320
+ if progress:
321
+ offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
322
+ # We start from a triangle shaped weight, with maximal weight in the middle
323
+ # of the segment. Then we normalize and take to the power `transition_power`.
324
+ # Large values of transition power will lead to sharper transitions.
325
+ weight = th.cat([th.arange(1, segment // 2 + 1),
326
+ th.arange(segment - segment // 2, 0, -1)]).to(device)
327
+ assert len(weight) == segment
328
+ # If the overlap < 50%, this will translate to linear transition when
329
+ # transition_power is 1.
330
+ weight = (weight / weight.max())**transition_power
331
+ for offset in offsets:
332
+ chunk = TensorChunk(mix, offset, segment)
333
+ if set_progress_bar:
334
+ progress_value += 1
335
+ set_progress_bar(0.1, (0.8/len(offsets)*progress_value))
336
+ chunk_out = apply_model_v2(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
337
+ else:
338
+ chunk_out = apply_model_v2(model, chunk, shifts=shifts)
339
+ chunk_length = chunk_out.shape[-1]
340
+ out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
341
+ sum_weight[offset:offset + segment] += weight[:chunk_length]
342
+ offset += segment
343
+ assert sum_weight.min() > 0
344
+ out /= sum_weight
345
+ return out
346
+ elif shifts:
347
+ max_shift = int(0.5 * model.samplerate)
348
+ mix = tensor_chunk(mix)
349
+ padded_mix = mix.padded(length + 2 * max_shift)
350
+ out = 0
351
+ for _ in range(shifts):
352
+ offset = random.randint(0, max_shift)
353
+ shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
354
+
355
+ if set_progress_bar:
356
+ progress_value += 1
357
+ shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar)
358
+ else:
359
+ shifted_out = apply_model_v2(model, shifted)
360
+ out += shifted_out[..., max_shift - offset:]
361
+ out /= shifts
362
+ return out
363
+ else:
364
+ valid_length = model.valid_length(length)
365
+ mix = tensor_chunk(mix)
366
+ padded_mix = mix.padded(valid_length)
367
+ with th.no_grad():
368
+ out = model(padded_mix.unsqueeze(0))[0]
369
+ return center_trim(out, length)
370
+
371
+
372
+ @contextmanager
373
+ def temp_filenames(count, delete=True):
374
+ names = []
375
+ try:
376
+ for _ in range(count):
377
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
378
+ yield names
379
+ finally:
380
+ if delete:
381
+ for name in names:
382
+ os.unlink(name)
383
+
384
+
385
+ def get_quantizer(model, args, optimizer=None):
386
+ quantizer = None
387
+ if args.diffq:
388
+ quantizer = DiffQuantizer(
389
+ model, min_size=args.q_min_size, group_size=8)
390
+ if optimizer is not None:
391
+ quantizer.setup_optimizer(optimizer)
392
+ elif args.qat:
393
+ quantizer = UniformQuantizer(
394
+ model, bits=args.qat, min_size=args.q_min_size)
395
+ return quantizer
396
+
397
+
398
+ def load_model(path, strict=False):
399
+ with warnings.catch_warnings():
400
+ warnings.simplefilter("ignore")
401
+ load_from = path
402
+ package = th.load(load_from, 'cpu')
403
+
404
+ klass = package["klass"]
405
+ args = package["args"]
406
+ kwargs = package["kwargs"]
407
+
408
+ if strict:
409
+ model = klass(*args, **kwargs)
410
+ else:
411
+ sig = inspect.signature(klass)
412
+ for key in list(kwargs):
413
+ if key not in sig.parameters:
414
+ warnings.warn("Dropping inexistant parameter " + key)
415
+ del kwargs[key]
416
+ model = klass(*args, **kwargs)
417
+
418
+ state = package["state"]
419
+ training_args = package["training_args"]
420
+ quantizer = get_quantizer(model, training_args)
421
+
422
+ set_state(model, quantizer, state)
423
+ return model
424
+
425
+
426
+ def get_state(model, quantizer):
427
+ if quantizer is None:
428
+ state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
429
+ else:
430
+ state = quantizer.get_quantized_state()
431
+ buf = io.BytesIO()
432
+ th.save(state, buf)
433
+ state = {'compressed': zlib.compress(buf.getvalue())}
434
+ return state
435
+
436
+
437
+ def set_state(model, quantizer, state):
438
+ if quantizer is None:
439
+ model.load_state_dict(state)
440
+ else:
441
+ buf = io.BytesIO(zlib.decompress(state["compressed"]))
442
+ state = th.load(buf, "cpu")
443
+ quantizer.restore_quantized_state(state)
444
+
445
+ return state
446
+
447
+
448
+ def save_state(state, path):
449
+ buf = io.BytesIO()
450
+ th.save(state, buf)
451
+ sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8]
452
+
453
+ path = path.parent / (path.stem + "-" + sig + path.suffix)
454
+ path.write_bytes(buf.getvalue())
455
+
456
+
457
+ def save_model(model, quantizer, training_args, path):
458
+ args, kwargs = model._init_args_kwargs
459
+ klass = model.__class__
460
+
461
+ state = get_state(model, quantizer)
462
+
463
+ save_to = path
464
+ package = {
465
+ 'klass': klass,
466
+ 'args': args,
467
+ 'kwargs': kwargs,
468
+ 'state': state,
469
+ 'training_args': training_args,
470
+ }
471
+ th.save(package, save_to)
472
+
473
+
474
+ def capture_init(init):
475
+ @functools.wraps(init)
476
+ def __init__(self, *args, **kwargs):
477
+ self._init_args_kwargs = (args, kwargs)
478
+ init(self, *args, **kwargs)
479
+
480
+ return __init__
481
+
482
+ class DummyPoolExecutor:
483
+ class DummyResult:
484
+ def __init__(self, func, *args, **kwargs):
485
+ self.func = func
486
+ self.args = args
487
+ self.kwargs = kwargs
488
+
489
+ def result(self):
490
+ return self.func(*self.args, **self.kwargs)
491
+
492
+ def __init__(self, workers=0):
493
+ pass
494
+
495
+ def submit(self, func, *args, **kwargs):
496
+ return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
497
+
498
+ def __enter__(self):
499
+ return self
500
+
501
+ def __exit__(self, exc_type, exc_value, exc_tb):
502
+ return