Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from collections import defaultdict | |
from contextlib import contextmanager | |
import math | |
import os | |
import tempfile | |
import typing as tp | |
import errno | |
import functools | |
import hashlib | |
import inspect | |
import io | |
import os | |
import random | |
import socket | |
import tempfile | |
import warnings | |
import zlib | |
import tkinter as tk | |
from diffq import UniformQuantizer, DiffQuantizer | |
import torch as th | |
import tqdm | |
from torch import distributed | |
from torch.nn import functional as F | |
import torch | |
def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    The last dimension is zero-padded on the right (or trimmed, when
    `kernel_size < stride`) so that `F = ceil(T / stride)` frames fit exactly,
    the last one starting at `(F - 1) * stride`.
    see https://github.com/pytorch/pytorch/issues/60466

    Args:
        a: input tensor; frames are extracted along its last dimension.
        kernel_size: number of samples per frame (K).
        stride: hop, in samples, between the starts of consecutive frames.

    Returns:
        A strided view of shape [*OT, F, K] over the padded tensor. Frames
        share memory when `stride < kernel_size`, so do not write into the
        result in place.
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    # Length required so that the last frame (starting at (n_frames - 1) * stride)
    # contains a full `kernel_size` samples.
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    # Jump `stride` elements between frames and 1 element inside a frame.
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Trim `tensor` symmetrically along its last dimension down to a target length.

    The target is `reference.size(-1)` when `reference` is a tensor; otherwise
    `reference` itself is the length. When the number of samples to drop is
    odd, the extra sample is removed from the right-hand side.

    Raises:
        ValueError: if `tensor` is shorter than the target length.
    """
    target = reference.size(-1) if isinstance(reference, torch.Tensor) else reference
    excess = tensor.size(-1) - target
    if excess < 0:
        raise ValueError("tensor must be larger than reference. " f"Delta is {excess}.")
    if excess == 0:
        return tensor
    left = excess // 2
    right = excess - left
    return tensor[..., left:-right]
def pull_metric(history: tp.List[dict], name: str):
    """
    Extract one metric from every entry of `history`.

    `name` is a dot-separated path (for example "a.b") followed through the
    nested dicts of each entry. The values found are returned, in order, as a list.
    """
    def lookup(entry):
        # Descend one nesting level per dotted component.
        for key in name.split("."):
            entry = entry[key]
        return entry

    return [lookup(entry) for entry in history]
def EMA(beta: float = 1):
    """
    Build an exponential-moving-average accumulator for dicts of metrics.

    The returned callable, `update(metrics, weight=1)`, folds `metrics` into the
    running state and returns the current averaged dict. For every key present
    in `metrics`, both a weighted sum and a total weight are first multiplied
    by `beta` and then incremented; the average reported is their ratio.
    With `beta=1` nothing decays, so this is a plain weighted mean.
    """
    weighted_sums: tp.Dict[str, float] = defaultdict(float)
    weights: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        for key, value in metrics.items():
            weighted_sums[key] = weighted_sums[key] * beta + weight * float(value)
            weights[key] = weights[key] * beta + weight
        return {key: acc / weights[key] for key, acc in weighted_sums.items()}

    return _update
def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Format a byte count `num` as a human-readable string with binary (IEC)
    prefixes, e.g. 1536 -> "1.5KiB". Values of 1024**8 and above use "Yi".

    Adapted from https://stackoverflow.com/a/1094933
    """
    # NOTE(review): `sizeof_fmt` is defined again later in this module; that
    # later definition is the one bound once the module finishes importing.
    value = num
    for prefix in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'):
        if not abs(value) < 1024.0:
            value /= 1024.0
            continue
        return "%3.1f%s%s" % (value, prefix, suffix)
    return "%.1f%s%s" % (value, 'Yi', suffix)
@contextmanager
def temp_filenames(count: int, delete=True):
    """
    Context manager yielding a list of `count` fresh temporary file paths.

    Each file is created on disk, so its name is reserved, and its handle is
    closed immediately; only the path is handed to the caller. On exit the
    files are removed unless `delete` is False.

    Usage:
        with temp_filenames(2) as (first, second):
            ...
    """
    names = []
    try:
        for _ in range(count):
            # Close the handle right away rather than relying on garbage
            # collection: only the reserved path is needed.
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                names.append(tmp.name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)
def average_metric(metric, count=1.):
    """
    Return the `count`-weighted mean of the float `metric` over all processes
    taking part in a `torch.distributed` all-reduce.

    Each process contributes `count` (its weight, e.g. number of examples) and
    `count * metric`; both are summed across processes and their ratio is returned.
    The packed tensor is allocated on 'cuda', so a CUDA device and an
    initialized process group are required.
    """
    packed = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
    distributed.all_reduce(packed, op=distributed.ReduceOp.SUM)
    total_weight, weighted_sum = packed.tolist()
    return weighted_sum / total_weight
def free_port(host='', low=20000, high=40000):
    """
    Return a port number that is most likely free.

    Random ports in [low, high] (inclusive) are tried by binding a probe socket
    to (host, port) until a bind succeeds. The probe socket is always closed
    before returning, so the port is released for the caller. Another process
    may still grab it in between (a race condition), although that should be
    quite rare.

    Raises:
        OSError: for any bind failure other than "address already in use".
    """
    # `with` guarantees the probe socket is closed on every exit path instead
    # of being left for the garbage collector.
    with socket.socket() as sock:
        while True:
            port = random.randint(low, high)
            try:
                sock.bind((host, port))
            except OSError as error:
                if error.errno == errno.EADDRINUSE:
                    continue
                raise
            return port
def sizeof_fmt(num, suffix='B'):
    """
    Format a byte count `num` as a human-readable string with binary (IEC)
    prefixes, e.g. 1536 -> "1.5KiB". Values of 1024**8 and above use "Yi".

    Adapted from https://stackoverflow.com/a/1094933
    """
    # NOTE(review): this redefines the earlier, identical `sizeof_fmt` in this
    # module; being later, this is the definition callers actually get.
    units = ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']
    idx = 0
    while idx < len(units) and not abs(num) < 1024.0:
        num /= 1024.0
        idx += 1
    if idx < len(units):
        return "%3.1f%s%s" % (num, units[idx], suffix)
    return "%.1f%s%s" % (num, 'Yi', suffix)
def human_seconds(seconds, display='.2f'):
    """
    Format a duration given in `seconds` as a human-readable string.

    The value starts in microseconds and is promoted through ms, s, min, hrs
    and days for as long as it would still be at least 0.3 of the next unit.
    `display` is the format spec applied to the number, e.g. '.2f' -> "1.00 s".
    """
    steps = (('ms', 1e3), ('s', 1e3), ('min', 60), ('hrs', 60), ('days', 24))
    amount = seconds * 1e6
    unit = 'us'
    for next_unit, factor in steps:
        converted = amount / factor
        if converted < 0.3:
            break
        amount, unit = converted, next_unit
    return f"{format(amount, display)} {unit}"
class TensorChunk:
    """
    Window over `tensor` along its last dimension.

    Stores a reference to `tensor` plus the window bounds (`offset`, `length`).
    `padded` materializes the window extended to a requested length, taking
    real neighbouring samples where they exist and zero-padding past the
    tensor's edges.
    """

    def __init__(self, tensor, offset=0, length=None):
        """
        Args:
            tensor: tensor to take the window from (along the last dimension).
            offset: index of the first sample of the window; must satisfy
                0 <= offset < tensor.shape[-1].
            length: window length; None means "up to the end". It is clipped
                so the window never extends past the end of `tensor`.
        """
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            length = min(total_length - offset, length)

        self.tensor = tensor
        self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        """Shape of the window: the tensor's shape with the last dim set to `length`.

        A property (not a method) so a chunk can stand in for a tensor in code
        such as `channels, length = mix.shape`.
        """
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        """
        Return the window extended symmetrically to `target_length` samples.

        Extra samples come from the underlying tensor around the window when
        available; anything that falls outside the tensor is zero-padded.
        Requires `target_length >= self.length`.
        """
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        start = self.offset - delta // 2
        end = start + target_length

        # Clamp to the tensor's bounds; the clipped amount becomes zero padding.
        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out
def tensor_chunk(tensor_or_chunk):
    """
    Normalize the argument to a `TensorChunk`.

    A `TensorChunk` is returned unchanged; a plain tensor is wrapped in a
    `TensorChunk` covering its whole last dimension.
    """
    if not isinstance(tensor_or_chunk, TensorChunk):
        assert isinstance(tensor_or_chunk, th.Tensor)
        tensor_or_chunk = TensorChunk(tensor_or_chunk)
    return tensor_or_chunk
def apply_model_v1(model, mix, shifts=None, split=False, progress=False, set_progress_bar=None):
    """
    Apply model to a given mixture.

    Args:
        model: separation model. This function reads `model.samplerate`, calls
            `model.valid_length(length)`, and calls `model(batch)`.
        mix: input tensor of shape [channels, length].
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down in 10 seconds extracts
            (`model.samplerate * 10` samples) and predictions will be performed
            individually on each and concatenated.
            Useful for model with large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
        set_progress_bar: optional callback; in split mode it is called once per
            extract as `set_progress_bar(0.1, 0.8 / total * done)`.

    Returns:
        In split mode a tensor of shape [4, channels, length]. Otherwise the
        model output for the (padded) batch, center-trimmed to `length` samples.
    """
    channels, length = mix.size()
    device = mix.device
    if split:
        # NOTE(review): the number of sources (4) is hard-coded here, unlike
        # apply_model_v2 which uses len(model.sources).
        out = th.zeros(4, channels, length, device=device)
        shift = model.samplerate * 10
        offsets = range(0, length, shift)
        scale = 10
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
        n_chunks = len(offsets)
        for done, offset in enumerate(offsets, start=1):
            chunk = mix[..., offset:offset + shift]
            if set_progress_bar:
                set_progress_bar(0.1, (0.8 / n_chunks * done))
            # Forwarding the callback unconditionally is equivalent: callees only
            # use it when it is truthy.
            out[..., offset:offset + shift] = apply_model_v1(
                model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
        return out
    elif shifts:
        max_shift = int(model.samplerate / 2)
        mix = F.pad(mix, (max_shift, max_shift))
        offsets = list(range(max_shift))
        random.shuffle(offsets)
        used = offsets[:shifts]
        out = 0
        for offset in used:
            shifted = mix[..., offset:offset + length + max_shift]
            shifted_out = apply_model_v1(model, shifted, set_progress_bar=set_progress_bar)
            # Undo the shift so every prediction is aligned with the original mix.
            out += shifted_out[..., max_shift - offset:max_shift - offset + length]
        # Divide by the number of shifts actually applied: `offsets[:shifts]` is
        # shorter than `shifts` when fewer distinct offsets exist.
        out /= len(used)
        return out
    else:
        valid_length = model.valid_length(length)
        delta = valid_length - length
        padded = F.pad(mix, (delta // 2, delta - delta // 2))
        with th.no_grad():
            out = model(padded.unsqueeze(0))[0]
        return center_trim(out, mix)
def apply_model_v2(model, mix, shifts=None, split=False,
                   overlap=0.25, transition_power=1., progress=False, set_progress_bar=None):
    """
    Apply model to a given mixture.

    Args:
        model: separation model. This function reads `model.sources`,
            `model.segment_length` and `model.samplerate`, calls
            `model.valid_length(length)`, and calls `model(batch)`.
        mix: input of shape [channels, length], either a tensor or a `TensorChunk`.
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input is broken into segments of
            `model.segment_length` samples, consecutive segments overlapping by a
            fraction `overlap`; predictions are made on each segment and blended
            with a triangular weight. Useful for models with a large memory
            footprint like Tasnet.
        overlap (float): fraction of a segment shared with the next one; must
            leave a positive stride, i.e. `int((1 - overlap) * segment) > 0`.
        transition_power (float): exponent (>= 1) applied to the normalized
            triangular weight; larger values give sharper transitions.
        progress (bool): if True, show a progress bar (requires split=True)
        set_progress_bar: optional callback; in split mode it is called once per
            segment as `set_progress_bar(0.1, 0.8 / total * done)`.

    Returns:
        In split mode a tensor of shape [len(model.sources), channels, length];
        otherwise the model output center-trimmed to `length` samples.

    Raises:
        ValueError: in split mode, if `overlap` leaves a non-positive stride.
    """
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    device = mix.device
    channels, length = mix.shape
    if split:
        out = th.zeros(len(model.sources), channels, length, device=device)
        sum_weight = th.zeros(length, device=device)
        segment = model.segment_length
        stride = int((1 - overlap) * segment)
        if stride <= 0:
            raise ValueError(f"overlap={overlap} leaves no stride between segments "
                             f"of length {segment}.")
        offsets = range(0, length, stride)
        scale = stride / model.samplerate
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment // 2 + 1),
                         th.arange(segment - segment // 2, 0, -1)]).to(device)
        assert len(weight) == segment
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max())**transition_power
        n_chunks = len(offsets)
        for done, offset in enumerate(offsets, start=1):
            chunk = TensorChunk(mix, offset, segment)
            if set_progress_bar:
                set_progress_bar(0.1, (0.8 / n_chunks * done))
            # Forwarding the callback unconditionally is equivalent: callees only
            # use it when it is truthy.
            chunk_out = apply_model_v2(model, chunk, shifts=shifts, set_progress_bar=set_progress_bar)
            # The last chunk may be shorter than `segment`.
            chunk_length = chunk_out.shape[-1]
            out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
            sum_weight[offset:offset + segment] += weight[:chunk_length]
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    elif shifts:
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model_v2(model, shifted, set_progress_bar=set_progress_bar)
            # Drop the leading samples introduced by the shift to realign with `mix`.
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        return out
    else:
        valid_length = model.valid_length(length)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length)
        with th.no_grad():
            out = model(padded_mix.unsqueeze(0))[0]
        return center_trim(out, length)
@contextmanager
def temp_filenames(count, delete=True):
    """
    Context manager yielding a list of `count` fresh temporary file paths.

    Each file is created on disk, so its name is reserved, and its handle is
    closed immediately; only the path is handed to the caller. On exit the
    files are removed unless `delete` is False.

    Usage:
        with temp_filenames(2) as (first, second):
            ...
    """
    # NOTE(review): this redefines the earlier `temp_filenames` in this module;
    # being later, this is the definition callers actually get.
    names = []
    try:
        for _ in range(count):
            # Close the handle right away rather than relying on garbage
            # collection: only the reserved path is needed.
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                names.append(tmp.name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)
def get_quantizer(model, args, optimizer=None):
    """
    Build the quantizer selected by the training arguments, or return None.

    - `args.diffq` truthy: a `DiffQuantizer` (group_size=8, min_size from
      `args.q_min_size`); if `optimizer` is given, the quantizer is registered
      with it via `setup_optimizer`.
    - else `args.qat` truthy: a `UniformQuantizer` with `bits=args.qat`.
    - otherwise: None (no quantization).
    """
    if args.diffq:
        quantizer = DiffQuantizer(model, min_size=args.q_min_size, group_size=8)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
        return quantizer
    if args.qat:
        return UniformQuantizer(model, bits=args.qat, min_size=args.q_min_size)
    return None
def load_model(path, strict=False):
    """
    Rebuild a model from a package file written by `save_model`.

    The package provides the class (`klass`), its constructor `args`/`kwargs`,
    the serialized `state` and the `training_args` used to recreate the
    quantizer. Unless `strict` is True, keyword arguments the class constructor
    no longer accepts are dropped, each with a warning.
    """
    # NOTE(review): `th.load` unpickles arbitrary objects (the package stores a
    # class), so only load trusted files. Recent PyTorch releases default to
    # `weights_only=True`, which may refuse such packages; confirm against the
    # installed version.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        package = th.load(path, 'cpu')

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]
    if not strict:
        accepted = inspect.signature(klass).parameters
        for key in list(kwargs):
            if key in accepted:
                continue
            warnings.warn("Dropping inexistant parameter " + key)
            del kwargs[key]
    model = klass(*args, **kwargs)

    state = package["state"]
    quantizer = get_quantizer(model, package["training_args"])
    set_state(model, quantizer, state)
    return model
def get_state(model, quantizer):
    """
    Snapshot the weights of `model` for serialization.

    Without a quantizer, returns `model.state_dict()` with each tensor's `.data`
    moved to CPU. With a quantizer, its `get_quantized_state()` is serialized
    with `th.save`, zlib-compressed, and returned as `{'compressed': bytes}`.
    """
    if quantizer is not None:
        quantized = quantizer.get_quantized_state()
        buffer = io.BytesIO()
        th.save(quantized, buffer)
        return {'compressed': zlib.compress(buffer.getvalue())}
    return {name: param.data.to('cpu') for name, param in model.state_dict().items()}
def set_state(model, quantizer, state):
    """
    Load `state` (in the format produced by `get_state`) back into `model`.

    Without a quantizer, `state` is a plain state dict passed to
    `model.load_state_dict`. With a quantizer, `state["compressed"]` is
    zlib-decompressed, read with `th.load` onto CPU, and handed to
    `quantizer.restore_quantized_state`.

    Returns the state that was applied (the decompressed one when quantized).
    """
    if quantizer is not None:
        raw = zlib.decompress(state["compressed"])
        state = th.load(io.BytesIO(raw), "cpu")
        quantizer.restore_quantized_state(state)
    else:
        model.load_state_dict(state)
    return state
def save_state(state, path):
    """
    Serialize `state` with `th.save` and write it next to `path`, with a
    content hash embedded in the file name.

    The file written is `<path.parent>/<path.stem>-<sig><path.suffix>`, where
    `sig` is the first 8 hex digits of the SHA-256 of the serialized bytes.

    Args:
        state: any object `th.save` can serialize.
        path: target path (`pathlib.Path`, or a `str` which is converted);
            only its directory, stem and suffix are used.

    Returns:
        The `pathlib.Path` that was actually written, since the final name
        depends on the content hash.
    """
    from pathlib import Path

    if isinstance(path, str):
        path = Path(path)
    buf = io.BytesIO()
    th.save(state, buf)
    data = buf.getvalue()
    sig = hashlib.sha256(data).hexdigest()[:8]
    target = path.parent / (path.stem + "-" + sig + path.suffix)
    target.write_bytes(data)
    return target
def save_model(model, quantizer, training_args, path):
    """
    Write `model` to `path` as a self-describing package that `load_model`
    can rebuild.

    The package (saved with `th.save`) holds the model class, the constructor
    arguments recorded in `model._init_args_kwargs` (as set by `capture_init`),
    the weights from `get_state(model, quantizer)`, and `training_args`.
    """
    init_args, init_kwargs = model._init_args_kwargs
    th.save({
        'klass': model.__class__,
        'args': init_args,
        'kwargs': init_kwargs,
        'state': get_state(model, quantizer),
        'training_args': training_args,
    }, path)
def capture_init(init):
    """
    Decorate a class's `__init__` so each instance records how it was built.

    The positional and keyword arguments of every construction are stored as
    `self._init_args_kwargs = (args, kwargs)` before the original `__init__`
    runs; `save_model` packages them so the object can be re-created.

    `functools.wraps` preserves the original name and docstring and sets
    `__wrapped__`. As a result, `inspect.signature(klass)` still reports the
    real constructor parameters instead of `(*args, **kwargs)`, which
    `load_model` relies on when filtering saved kwargs.
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__
class DummyPoolExecutor:
    """
    Single-process stand-in mimicking an executor's `submit(...).result()` API.

    `submit` runs nothing: it returns a `DummyResult`, and the work executes
    synchronously, in the caller's thread, when `.result()` is called. Usable
    in a `with` statement; leaving the block does nothing and does not
    suppress exceptions.
    """

    class DummyResult:
        """Deferred call `func(*args, **kwargs)`, evaluated by `result()`."""

        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            """Run the stored call and return its value (re-runs on every call)."""
            func, args, kwargs = self.func, self.args, self.kwargs
            return func(*args, **kwargs)

    def __init__(self, workers=0):
        """Accept `workers` only for interface compatibility; it is ignored."""

    def submit(self, func, *args, **kwargs):
        """Package the call for later execution; nothing runs yet."""
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # A falsy return value lets any exception from the `with` body propagate.
        return None