diff --git a/audiocraft/audiocraft/__init__.py b/audiocraft/audiocraft/__init__.py deleted file mode 100644 index 6ab346075f1b35366e7231054513097b87552c6f..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -""" -AudioCraft is a general framework for training audio generative models. -At the moment we provide the training code for: - -- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art - text-to-music and melody+text autoregressive generative model. - For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model, - `audiocraft.models.musicgen.MusicGen`. -- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art - text-to-general-audio generative model. -- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity - neural audio codec which provides an excellent tokenizer for autoregressive language models. - See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`. -- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that - improves the perceived quality and reduces the artifacts coming from adversarial decoders. -""" - -# flake8: noqa -from . import data, modules, models - -__version__ = '1.0.0' diff --git a/audiocraft/audiocraft/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 86fa306b5e9a2d1640efc285f194a4467aa43c56..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/__pycache__/environment.cpython-311.pyc b/audiocraft/audiocraft/__pycache__/environment.cpython-311.pyc deleted file mode 100644 index 1d5e9b099c329be39daed775b20be484b9eb96cc..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/__pycache__/environment.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/__pycache__/train.cpython-311.pyc b/audiocraft/audiocraft/__pycache__/train.cpython-311.pyc deleted file mode 100644 index 40e12c00bca52906ac5864e5b50eddd3008f0207..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/__pycache__/train.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/adversarial/__init__.py b/audiocraft/audiocraft/adversarial/__init__.py deleted file mode 100644 index 864058706fbfae13d7f7dc850cc411a2f27d1510..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/adversarial/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-"""Adversarial losses and discriminator architectures.""" - -# flake8: noqa -from .discriminators import ( - MultiPeriodDiscriminator, - MultiScaleDiscriminator, - MultiScaleSTFTDiscriminator -) -from .losses import ( - AdversarialLoss, - AdvLossType, - get_adv_criterion, - get_fake_criterion, - get_real_criterion, - FeatLossType, - FeatureMatchingLoss -) diff --git a/audiocraft/audiocraft/adversarial/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/adversarial/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 8351a050986017eb21b892eb45fe3048b4d9e100..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/adversarial/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/adversarial/__pycache__/losses.cpython-311.pyc b/audiocraft/audiocraft/adversarial/__pycache__/losses.cpython-311.pyc deleted file mode 100644 index fc78db435ee9a3d5ba7a14b96fb715a1e8350a6b..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/adversarial/__pycache__/losses.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/adversarial/discriminators/__init__.py b/audiocraft/audiocraft/adversarial/discriminators/__init__.py deleted file mode 100644 index f9e5ff59950ee0b1d1a67c9b3831d67d08048148..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/adversarial/discriminators/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# flake8: noqa -from .mpd import MultiPeriodDiscriminator -from .msd import MultiScaleDiscriminator -from .msstftd import MultiScaleSTFTDiscriminator diff --git a/audiocraft/audiocraft/adversarial/discriminators/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/adversarial/discriminators/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 609a6a962bb6b29c1da747c8f0d396752582776b..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/adversarial/discriminators/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/adversarial/discriminators/__pycache__/base.cpython-311.pyc b/audiocraft/audiocraft/adversarial/discriminators/__pycache__/base.cpython-311.pyc deleted file mode 100644 index 9f24c63486bef279219b1ecb0c713c1f595b55dc..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/adversarial/discriminators/__pycache__/base.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/adversarial/discriminators/__pycache__/mpd.cpython-311.pyc b/audiocraft/audiocraft/adversarial/discriminators/__pycache__/mpd.cpython-311.pyc deleted file mode 100644 index ca3f4f1eaed8ef2179ccbe1b3032a7b2f6de82f3..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/adversarial/discriminators/__pycache__/mpd.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/adversarial/discriminators/__pycache__/msd.cpython-311.pyc b/audiocraft/audiocraft/adversarial/discriminators/__pycache__/msd.cpython-311.pyc deleted file mode 100644 index 20c29000d813fbf6e4adc10d20b255b30c0b1689..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/adversarial/discriminators/__pycache__/msd.cpython-311.pyc and /dev/null differ diff --git 
a/audiocraft/audiocraft/adversarial/discriminators/__pycache__/msstftd.cpython-311.pyc b/audiocraft/audiocraft/adversarial/discriminators/__pycache__/msstftd.cpython-311.pyc deleted file mode 100644 index 0bb5f8d391ad5e488cebd3e0c63865d4bf13f127..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/adversarial/discriminators/__pycache__/msstftd.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/adversarial/discriminators/base.py b/audiocraft/audiocraft/adversarial/discriminators/base.py deleted file mode 100644 index a9d517e9f5bf0f4e18252c45c8db3a35a7255f69..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/adversarial/discriminators/base.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC, abstractmethod -import typing as tp - -import torch -import torch.nn as nn - - -FeatureMapType = tp.List[torch.Tensor] -LogitsType = torch.Tensor -MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] - - -class MultiDiscriminator(ABC, nn.Module): - """Base implementation for discriminators composed of sub-discriminators acting at different scales. - """ - def __init__(self): - super().__init__() - - @abstractmethod - def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: - ... - - @property - @abstractmethod - def num_discriminators(self) -> int: - """Number of discriminators. - """ - ... diff --git a/audiocraft/audiocraft/adversarial/discriminators/mpd.py b/audiocraft/audiocraft/adversarial/discriminators/mpd.py deleted file mode 100644 index 8debd1fa72d77ca03df680facb60bdf79638cade..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/adversarial/discriminators/mpd.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import typing as tp - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ...modules import NormConv2d -from .base import MultiDiscriminator, MultiDiscriminatorOutputType - - -def get_padding(kernel_size: int, dilation: int = 1) -> int: - return int((kernel_size * dilation - dilation) / 2) - - -class PeriodDiscriminator(nn.Module): - """Period sub-discriminator. - - Args: - period (int): Period between samples of audio. - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - n_layers (int): Number of convolutional layers. - kernel_sizes (list of int): Kernel sizes for convolutions. - stride (int): Stride for convolutions. - filters (int): Initial number of filters in convolutions. - filters_scale (int): Multiplier of number of filters as we increase depth. - max_filters (int): Maximum number of filters. - norm (str): Normalization method. - activation (str): Activation function. - activation_params (dict): Parameters to provide to the activation function. 
- """ - def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1, - n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3, - filters: int = 8, filters_scale: int = 4, max_filters: int = 1024, - norm: str = 'weight_norm', activation: str = 'LeakyReLU', - activation_params: dict = {'negative_slope': 0.2}): - super().__init__() - self.period = period - self.n_layers = n_layers - self.activation = getattr(torch.nn, activation)(**activation_params) - self.convs = nn.ModuleList() - in_chs = in_channels - for i in range(self.n_layers): - out_chs = min(filters * (filters_scale ** (i + 1)), max_filters) - eff_stride = 1 if i == self.n_layers - 1 else stride - self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1), - padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm)) - in_chs = out_chs - self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1, - padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm) - - def forward(self, x: torch.Tensor): - fmap = [] - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), 'reflect') - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for conv in self.convs: - x = conv(x) - x = self.activation(x) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - # x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiPeriodDiscriminator(MultiDiscriminator): - """Multi-Period (MPD) Discriminator. - - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - periods (Sequence[int]): Periods between samples of audio for the sub-discriminators. - **kwargs: Additional args for `PeriodDiscriminator` - """ - def __init__(self, in_channels: int = 1, out_channels: int = 1, - periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs): - super().__init__() - self.discriminators = nn.ModuleList([ - PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods - ]) - - @property - def num_discriminators(self): - return len(self.discriminators) - - def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: - logits = [] - fmaps = [] - for disc in self.discriminators: - logit, fmap = disc(x) - logits.append(logit) - fmaps.append(fmap) - return logits, fmaps diff --git a/audiocraft/audiocraft/adversarial/discriminators/msd.py b/audiocraft/audiocraft/adversarial/discriminators/msd.py deleted file mode 100644 index c4e67e29b46ab22f6ffeec85ffc64d8b99800b1b..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/adversarial/discriminators/msd.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import typing as tp - -import numpy as np -import torch -import torch.nn as nn - -from ...modules import NormConv1d -from .base import MultiDiscriminator, MultiDiscriminatorOutputType - - -class ScaleDiscriminator(nn.Module): - """Waveform sub-discriminator. - - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions. - filters (int): Number of initial filters for convolutions. - max_filters (int): Maximum number of filters. 
- downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions. - inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions. - groups (Sequence[int] or None): Groups for inner convolutions. - strides (Sequence[int] or None): Strides for inner convolutions. - paddings (Sequence[int] or None): Paddings for inner convolutions. - norm (str): Normalization method. - activation (str): Activation function. - activation_params (dict): Parameters to provide to the activation function. - pad (str): Padding for initial convolution. - pad_params (dict): Parameters to provide to the padding module. - """ - def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3], - filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4], - inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None, - strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None, - norm: str = 'weight_norm', activation: str = 'LeakyReLU', - activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d', - pad_params: dict = {}): - super().__init__() - assert len(kernel_sizes) == 2 - assert kernel_sizes[0] % 2 == 1 - assert kernel_sizes[1] % 2 == 1 - assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales)) - assert (groups is None or len(groups) == len(downsample_scales)) - assert (strides is None or len(strides) == len(downsample_scales)) - assert (paddings is None or len(paddings) == len(downsample_scales)) - self.activation = getattr(torch.nn, activation)(**activation_params) - self.convs = nn.ModuleList() - self.convs.append( - nn.Sequential( - getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params), - NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm) - ) - ) - - in_chs = filters - for i, downsample_scale in enumerate(downsample_scales): - out_chs = min(in_chs * downsample_scale, max_filters) - default_kernel_size = downsample_scale * 10 + 1 - default_stride = downsample_scale - default_padding = (default_kernel_size - 1) // 2 - default_groups = in_chs // 4 - self.convs.append( - NormConv1d(in_chs, out_chs, - kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size, - stride=strides[i] if strides else default_stride, - groups=groups[i] if groups else default_groups, - padding=paddings[i] if paddings else default_padding, - norm=norm)) - in_chs = out_chs - - out_chs = min(in_chs * 2, max_filters) - self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1, - padding=(kernel_sizes[0] - 1) // 2, norm=norm)) - self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1, - padding=(kernel_sizes[1] - 1) // 2, norm=norm) - - def forward(self, x: torch.Tensor): - fmap = [] - for layer in self.convs: - x = layer(x) - x = self.activation(x) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - # x = torch.flatten(x, 1, -1) - return x, fmap - - -class MultiScaleDiscriminator(MultiDiscriminator): - """Multi-Scale (MSD) Discriminator, - - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - downsample_factor (int): Downsampling factor between the different scales. - scale_norms (Sequence[str]): Normalization for each sub-discriminator. - **kwargs: Additional args for ScaleDiscriminator. 
- """ - def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2, - scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs): - super().__init__() - self.discriminators = nn.ModuleList([ - ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms - ]) - self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor) - - @property - def num_discriminators(self): - return len(self.discriminators) - - def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: - logits = [] - fmaps = [] - for i, disc in enumerate(self.discriminators): - if i != 0: - self.downsample(x) - logit, fmap = disc(x) - logits.append(logit) - fmaps.append(fmap) - return logits, fmaps diff --git a/audiocraft/audiocraft/adversarial/discriminators/msstftd.py b/audiocraft/audiocraft/adversarial/discriminators/msstftd.py deleted file mode 100644 index 81a9100961c7a89a39df2643b24268fb90bfeaa4..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/adversarial/discriminators/msstftd.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import typing as tp - -import torchaudio -import torch -from torch import nn -from einops import rearrange - -from ...modules import NormConv2d -from .base import MultiDiscriminator, MultiDiscriminatorOutputType - - -def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)): - return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2) - - -class DiscriminatorSTFT(nn.Module): - """STFT sub-discriminator. - - Args: - filters (int): Number of filters in convolutions. - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - n_fft (int): Size of FFT for each scale. - hop_length (int): Length of hop between STFT windows for each scale. - kernel_size (tuple of int): Inner Conv2d kernel sizes. - stride (tuple of int): Inner Conv2d strides. - dilations (list of int): Inner Conv2d dilation on the time dimension. - win_length (int): Window size for each scale. - normalized (bool): Whether to normalize by magnitude after stft. - norm (str): Normalization method. - activation (str): Activation function. - activation_params (dict): Parameters to provide to the activation function. - growth (int): Growth factor for the filters. 
- """ - def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, - n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, - filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4], - stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm', - activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}): - super().__init__() - assert len(kernel_size) == 2 - assert len(stride) == 2 - self.filters = filters - self.in_channels = in_channels - self.out_channels = out_channels - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.normalized = normalized - self.activation = getattr(torch.nn, activation)(**activation_params) - self.spec_transform = torchaudio.transforms.Spectrogram( - n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window, - normalized=self.normalized, center=False, pad_mode=None, power=None) - spec_channels = 2 * self.in_channels - self.convs = nn.ModuleList() - self.convs.append( - NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size)) - ) - in_chs = min(filters_scale * self.filters, max_filters) - for i, dilation in enumerate(dilations): - out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) - self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, - dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)), - norm=norm)) - in_chs = out_chs - out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters) - self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]), - padding=get_2d_padding((kernel_size[0], kernel_size[0])), - norm=norm)) - self.conv_post = NormConv2d(out_chs, self.out_channels, - kernel_size=(kernel_size[0], kernel_size[0]), - padding=get_2d_padding((kernel_size[0], kernel_size[0])), - norm=norm) - - def forward(self, x: torch.Tensor): - fmap = [] - z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] - z = torch.cat([z.real, z.imag], dim=1) - z = rearrange(z, 'b c w t -> b c t w') - for i, layer in enumerate(self.convs): - z = layer(z) - z = self.activation(z) - fmap.append(z) - z = self.conv_post(z) - return z, fmap - - -class MultiScaleSTFTDiscriminator(MultiDiscriminator): - """Multi-Scale STFT (MS-STFT) discriminator. - - Args: - filters (int): Number of filters in convolutions. - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. - sep_channels (bool): Separate channels to distinct samples for stereo support. - n_ffts (Sequence[int]): Size of FFT for each scale. - hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale. - win_lengths (Sequence[int]): Window size for each scale. - **kwargs: Additional args for STFTDiscriminator. 
- """ - def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False, - n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128], - win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs): - super().__init__() - assert len(n_ffts) == len(hop_lengths) == len(win_lengths) - self.sep_channels = sep_channels - self.discriminators = nn.ModuleList([ - DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels, - n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs) - for i in range(len(n_ffts)) - ]) - - @property - def num_discriminators(self): - return len(self.discriminators) - - def _separate_channels(self, x: torch.Tensor) -> torch.Tensor: - B, C, T = x.shape - return x.view(-1, 1, T) - - def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: - logits = [] - fmaps = [] - for disc in self.discriminators: - logit, fmap = disc(x) - logits.append(logit) - fmaps.append(fmap) - return logits, fmaps diff --git a/audiocraft/audiocraft/adversarial/losses.py b/audiocraft/audiocraft/adversarial/losses.py deleted file mode 100644 index be293e739bdc2d91273f30fb789befe7c8b49a43..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/adversarial/losses.py +++ /dev/null @@ -1,228 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Utility module to handle adversarial losses without requiring to mess up the main training loop. -""" - -import typing as tp - -import flashy -import torch -import torch.nn as nn -import torch.nn.functional as F - - -ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2'] - - -AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]] -FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] - - -class AdversarialLoss(nn.Module): - """Adversary training wrapper. - - Args: - adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples. - We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]`` - where the first item is a list of logits and the second item is a list of feature maps. - optimizer (torch.optim.Optimizer): Optimizer used for training the given module. - loss (AdvLossType): Loss function for generator training. - loss_real (AdvLossType): Loss function for adversarial training on logits from real samples. - loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples. - loss_feat (FeatLossType): Feature matching loss function for generator training. - normalize (bool): Whether to normalize by number of sub-discriminators. - - Example of usage: - adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake) - for real in loader: - noise = torch.randn(...) 
- fake = model(noise) - adv_loss.train_adv(fake, real) - loss, _ = adv_loss(fake, real) - loss.backward() - """ - def __init__(self, - adversary: nn.Module, - optimizer: torch.optim.Optimizer, - loss: AdvLossType, - loss_real: AdvLossType, - loss_fake: AdvLossType, - loss_feat: tp.Optional[FeatLossType] = None, - normalize: bool = True): - super().__init__() - self.adversary: nn.Module = adversary - flashy.distrib.broadcast_model(self.adversary) - self.optimizer = optimizer - self.loss = loss - self.loss_real = loss_real - self.loss_fake = loss_fake - self.loss_feat = loss_feat - self.normalize = normalize - - def _save_to_state_dict(self, destination, prefix, keep_vars): - # Add the optimizer state dict inside our own. - super()._save_to_state_dict(destination, prefix, keep_vars) - destination[prefix + 'optimizer'] = self.optimizer.state_dict() - return destination - - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): - # Load optimizer state. - self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer')) - super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) - - def get_adversary_pred(self, x): - """Run adversary model, validating expected output format.""" - logits, fmaps = self.adversary(x) - assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \ - f'Expecting a list of tensors as logits but {type(logits)} found.' - assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.' - for fmap in fmaps: - assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \ - f'Expecting a list of tensors as feature maps but {type(fmap)} found.' - return logits, fmaps - - def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor: - """Train the adversary with the given fake and real example. - - We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]]. - The first item being the logits and second item being a list of feature maps for each sub-discriminator. - - This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`) - and call the optimizer. - """ - loss = torch.tensor(0., device=fake.device) - all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach()) - all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach()) - n_sub_adversaries = len(all_logits_fake_is_fake) - for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake): - loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake) - - if self.normalize: - loss /= n_sub_adversaries - - self.optimizer.zero_grad() - with flashy.distrib.eager_sync_model(self.adversary): - loss.backward() - self.optimizer.step() - - return loss - - def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: - """Return the loss for the generator, i.e. trying to fool the adversary, - and feature matching loss if provided. 
- """ - adv = torch.tensor(0., device=fake.device) - feat = torch.tensor(0., device=fake.device) - with flashy.utils.readonly(self.adversary): - all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake) - all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real) - n_sub_adversaries = len(all_logits_fake_is_fake) - for logit_fake_is_fake in all_logits_fake_is_fake: - adv += self.loss(logit_fake_is_fake) - if self.loss_feat: - for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real): - feat += self.loss_feat(fmap_fake, fmap_real) - - if self.normalize: - adv /= n_sub_adversaries - feat /= n_sub_adversaries - - return adv, feat - - -def get_adv_criterion(loss_type: str) -> tp.Callable: - assert loss_type in ADVERSARIAL_LOSSES - if loss_type == 'mse': - return mse_loss - elif loss_type == 'hinge': - return hinge_loss - elif loss_type == 'hinge2': - return hinge2_loss - raise ValueError('Unsupported loss') - - -def get_fake_criterion(loss_type: str) -> tp.Callable: - assert loss_type in ADVERSARIAL_LOSSES - if loss_type == 'mse': - return mse_fake_loss - elif loss_type in ['hinge', 'hinge2']: - return hinge_fake_loss - raise ValueError('Unsupported loss') - - -def get_real_criterion(loss_type: str) -> tp.Callable: - assert loss_type in ADVERSARIAL_LOSSES - if loss_type == 'mse': - return mse_real_loss - elif loss_type in ['hinge', 'hinge2']: - return hinge_real_loss - raise ValueError('Unsupported loss') - - -def mse_real_loss(x: torch.Tensor) -> torch.Tensor: - return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x)) - - -def mse_fake_loss(x: torch.Tensor) -> torch.Tensor: - return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x)) - - -def hinge_real_loss(x: torch.Tensor) -> torch.Tensor: - return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x))) - - -def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor: - return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x))) - - -def mse_loss(x: torch.Tensor) -> torch.Tensor: - if x.numel() == 0: - return torch.tensor([0.0], device=x.device) - return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x)) - - -def hinge_loss(x: torch.Tensor) -> torch.Tensor: - if x.numel() == 0: - return torch.tensor([0.0], device=x.device) - return -x.mean() - - -def hinge2_loss(x: torch.Tensor) -> torch.Tensor: - if x.numel() == 0: - return torch.tensor([0.0]) - return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x))) - - -class FeatureMatchingLoss(nn.Module): - """Feature matching loss for adversarial training. - - Args: - loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1). - normalize (bool): Whether to normalize the loss. - by number of feature maps. 
- """ - def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True): - super().__init__() - self.loss = loss - self.normalize = normalize - - def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor: - assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0 - feat_loss = torch.tensor(0., device=fmap_fake[0].device) - feat_scale = torch.tensor(0., device=fmap_fake[0].device) - n_fmaps = 0 - for (feat_fake, feat_real) in zip(fmap_fake, fmap_real): - assert feat_fake.shape == feat_real.shape - n_fmaps += 1 - feat_loss += self.loss(feat_fake, feat_real) - feat_scale += torch.mean(torch.abs(feat_real)) - - if self.normalize: - feat_loss /= n_fmaps - - return feat_loss diff --git a/audiocraft/audiocraft/data/__init__.py b/audiocraft/audiocraft/data/__init__.py deleted file mode 100644 index 3c9447208f3b3e620c1ee5ea3f68e49d43b8ef33..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/data/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Audio loading and writing support. Datasets for raw audio -or also including some metadata.""" - -# flake8: noqa -from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset, btc_chords diff --git a/audiocraft/audiocraft/data/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/data/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 33b8903e49c71c8b16938b2ee0673913e7dfe698..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/data/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/data/__pycache__/audio.cpython-311.pyc b/audiocraft/audiocraft/data/__pycache__/audio.cpython-311.pyc deleted file mode 100644 index dd295262702b5a0eccc7b7389109f48d4217bb29..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/data/__pycache__/audio.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/data/__pycache__/audio_dataset.cpython-311.pyc b/audiocraft/audiocraft/data/__pycache__/audio_dataset.cpython-311.pyc deleted file mode 100644 index 6cbe6141d9db316dbc205aea721e7f24affca540..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/data/__pycache__/audio_dataset.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/data/__pycache__/audio_utils.cpython-311.pyc b/audiocraft/audiocraft/data/__pycache__/audio_utils.cpython-311.pyc deleted file mode 100644 index e78660c508e6d09fc7ba697de4e96bb6d09513f8..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/data/__pycache__/audio_utils.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/data/__pycache__/btc_chords.cpython-311.pyc b/audiocraft/audiocraft/data/__pycache__/btc_chords.cpython-311.pyc deleted file mode 100644 index f04bcb553fae5367ca4250ff45b5a35bfd59c925..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/data/__pycache__/btc_chords.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/data/__pycache__/chords.cpython-311.pyc b/audiocraft/audiocraft/data/__pycache__/chords.cpython-311.pyc deleted file mode 100644 index cfd648bd86e1b8436aa9d892a0a750655f631e6f..0000000000000000000000000000000000000000 Binary files 
a/audiocraft/audiocraft/data/__pycache__/chords.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/data/__pycache__/info_audio_dataset.cpython-311.pyc b/audiocraft/audiocraft/data/__pycache__/info_audio_dataset.cpython-311.pyc deleted file mode 100644 index 70748fcce561ab3f15f55dcb5460709484130fa2..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/data/__pycache__/info_audio_dataset.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/data/__pycache__/music_dataset.cpython-311.pyc b/audiocraft/audiocraft/data/__pycache__/music_dataset.cpython-311.pyc deleted file mode 100644 index 0c3e8ec044d21279b85d7bdefc741e1c7bcdb9ac..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/data/__pycache__/music_dataset.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/data/__pycache__/sound_dataset.cpython-311.pyc b/audiocraft/audiocraft/data/__pycache__/sound_dataset.cpython-311.pyc deleted file mode 100644 index 8ead434b04295cc05ccc5d8b669b0ffc6d2e1f67..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/data/__pycache__/sound_dataset.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/data/__pycache__/zip.cpython-311.pyc b/audiocraft/audiocraft/data/__pycache__/zip.cpython-311.pyc deleted file mode 100644 index 52ef5c4c9c7ace559374cc08fef3b865049cee8c..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/data/__pycache__/zip.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/data/audio.py b/audiocraft/audiocraft/data/audio.py deleted file mode 100644 index 8348791b63a19685f163136c0eccb7bc04e503d0..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/data/audio.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Audio IO methods are defined in this module (info, read, write), -We rely on av library for faster read when possible, otherwise on torchaudio. -""" - -from dataclasses import dataclass -from pathlib import Path -import logging -import typing as tp - -import numpy as np -import soundfile -import torch -from torch.nn import functional as F -import torchaudio as ta - -import av - -from .audio_utils import f32_pcm, i16_pcm, normalize_audio - - -_av_initialized = False - - -def _init_av(): - global _av_initialized - if _av_initialized: - return - logger = logging.getLogger('libav.mp3') - logger.setLevel(logging.ERROR) - _av_initialized = True - - -@dataclass(frozen=True) -class AudioFileInfo: - sample_rate: int - duration: float - channels: int - - -def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: - _init_av() - with av.open(str(filepath)) as af: - stream = af.streams.audio[0] - sample_rate = stream.codec_context.sample_rate - duration = float(stream.duration * stream.time_base) - channels = stream.channels - return AudioFileInfo(sample_rate, duration, channels) - - -def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: - info = soundfile.info(filepath) - return AudioFileInfo(info.samplerate, info.duration, info.channels) - - -def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: - # torchaudio no longer returns useful duration informations for some formats like mp3s. 
- filepath = Path(filepath) - if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info - # ffmpeg has some weird issue with flac. - return _soundfile_info(filepath) - else: - return _av_info(filepath) - - -def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]: - """FFMPEG-based audio file reading using PyAV bindings. - Soundfile cannot read mp3 and av_read is more efficient than torchaudio. - - Args: - filepath (str or Path): Path to audio file to read. - seek_time (float): Time at which to start reading in the file. - duration (float): Duration to read from the file. If set to -1, the whole file is read. - Returns: - tuple of torch.Tensor, int: Tuple containing audio data and sample rate - """ - _init_av() - with av.open(str(filepath)) as af: - stream = af.streams.audio[0] - sr = stream.codec_context.sample_rate - num_frames = int(sr * duration) if duration >= 0 else -1 - frame_offset = int(sr * seek_time) - # we need a small negative offset otherwise we get some edge artifact - # from the mp3 decoder. - af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream) - frames = [] - length = 0 - for frame in af.decode(streams=stream.index): - current_offset = int(frame.rate * frame.pts * frame.time_base) - strip = max(0, frame_offset - current_offset) - buf = torch.from_numpy(frame.to_ndarray()) - if buf.shape[0] != stream.channels: - buf = buf.view(-1, stream.channels).t() - buf = buf[:, strip:] - frames.append(buf) - length += buf.shape[1] - if num_frames > 0 and length >= num_frames: - break - assert frames - # If the above assert fails, it is likely because we seeked past the end of file point, - # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp. - # This will need proper debugging, in due time. - wav = torch.cat(frames, dim=1) - assert wav.shape[0] == stream.channels - if num_frames > 0: - wav = wav[:, :num_frames] - return f32_pcm(wav), sr - - -def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., - duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]: - """Read audio by picking the most appropriate backend tool based on the audio format. - - Args: - filepath (str or Path): Path to audio file to read. - seek_time (float): Time at which to start reading in the file. - duration (float): Duration to read from the file. If set to -1, the whole file is read. - pad (bool): Pad output audio if not reaching expected duration. - Returns: - tuple of torch.Tensor, int: Tuple containing audio data and sample rate. - """ - fp = Path(filepath) - if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg - # There is some bug with ffmpeg and reading flac - info = _soundfile_info(filepath) - frames = -1 if duration <= 0 else int(duration * info.sample_rate) - frame_offset = int(seek_time * info.sample_rate) - wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32) - assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}" - wav = torch.from_numpy(wav).t().contiguous() - if len(wav.shape) == 1: - wav = torch.unsqueeze(wav, 0) - elif ( - fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats() - and duration <= 0 and seek_time == 0 - ): - # Torchaudio is faster if we load an entire file at once. 
- wav, sr = ta.load(fp) - else: - wav, sr = _av_read(filepath, seek_time, duration) - if pad and duration > 0: - expected_frames = int(duration * sr) - wav = F.pad(wav, (0, expected_frames - wav.shape[-1])) - return wav, sr - - -def audio_write(stem_name: tp.Union[str, Path], - wav: torch.Tensor, sample_rate: int, - format: str = 'wav', mp3_rate: int = 320, normalize: bool = True, - strategy: str = 'peak', peak_clip_headroom_db: float = 1, - rms_headroom_db: float = 18, loudness_headroom_db: float = 14, - loudness_compressor: bool = False, - log_clipping: bool = True, make_parent_dir: bool = True, - add_suffix: bool = True) -> Path: - """Convenience function for saving audio to disk. Returns the filename the audio was written to. - - Args: - stem_name (str or Path): Filename without extension which will be added automatically. - format (str): Either "wav" or "mp3". - mp3_rate (int): kbps when using mp3s. - normalize (bool): if `True` (default), normalizes according to the prescribed - strategy (see after). If `False`, the strategy is only used in case clipping - would happen. - strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', - i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square - with extra headroom to avoid clipping. 'clip' just clips. - peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. - rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger - than the `peak_clip` one to avoid further clipping. - loudness_headroom_db (float): Target loudness for loudness normalization. - loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'. - when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still - occurs despite strategy (only for 'rms'). - make_parent_dir (bool): Make parent directory if it doesn't exist. - Returns: - Path: Path of the saved audio. - """ - assert wav.dtype.is_floating_point, "wav is not floating point" - if wav.dim() == 1: - wav = wav[None] - elif wav.dim() > 2: - raise ValueError("Input wav should be at most 2 dimension.") - assert wav.isfinite().all() - wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db, - rms_headroom_db, loudness_headroom_db, loudness_compressor, - log_clipping=log_clipping, sample_rate=sample_rate, - stem_name=str(stem_name)) - kwargs: dict = {} - if format == 'mp3': - suffix = '.mp3' - kwargs.update({"compression": mp3_rate}) - elif format == 'wav': - wav = i16_pcm(wav) - suffix = '.wav' - kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16}) - else: - raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.") - if not add_suffix: - suffix = '' - path = Path(str(stem_name) + suffix) - if make_parent_dir: - path.parent.mkdir(exist_ok=True, parents=True) - try: - ta.save(path, wav, sample_rate, **kwargs) - except Exception: - if path.exists(): - # we do not want to leave half written files around. - path.unlink() - raise - return path - -def audio_postproc(wav: torch.Tensor, sample_rate: int, normalize: bool = True, - strategy: str = 'peak', peak_clip_headroom_db: float = 1, - rms_headroom_db: float = 18, loudness_headroom_db: float = 14, - loudness_compressor: bool = False, log_clipping: bool = True) -> Path: - """Convenience function for saving audio to disk. Returns the filename the audio was written to. - - Args: - wav (torch.Tensor): Audio data to save. - sample_rate (int): Sample rate of audio data. 
- format (str): Either "wav" or "mp3". - mp3_rate (int): kbps when using mp3s. - normalize (bool): if `True` (default), normalizes according to the prescribed - strategy (see after). If `False`, the strategy is only used in case clipping - would happen. - strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', - i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square - with extra headroom to avoid clipping. 'clip' just clips. - peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. - rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger - than the `peak_clip` one to avoid further clipping. - loudness_headroom_db (float): Target loudness for loudness normalization. - loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'. - when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still - occurs despite strategy (only for 'rms'). - make_parent_dir (bool): Make parent directory if it doesn't exist. - Returns: - Path: Path of the saved audio. - """ - assert wav.dtype.is_floating_point, "wav is not floating point" - if wav.dim() == 1: - wav = wav[None] - elif wav.dim() > 2: - raise ValueError("Input wav should be at most 2 dimension.") - assert wav.isfinite().all() - wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db, - rms_headroom_db, loudness_headroom_db, loudness_compressor, - log_clipping=log_clipping, sample_rate=sample_rate, - stem_name=None) - - return wav diff --git a/audiocraft/audiocraft/data/audio_dataset.py b/audiocraft/audiocraft/data/audio_dataset.py deleted file mode 100644 index b508538f6b9cd4d0d9bd611ac24d9df36bbdba88..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/data/audio_dataset.py +++ /dev/null @@ -1,614 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""AudioDataset support. In order to handle a larger number of files -without having to scan again the folders, we precompute some metadata -(filename, sample rate, duration), and use that to efficiently sample audio segments. 
-""" -import argparse -import copy -from concurrent.futures import ThreadPoolExecutor, Future -from dataclasses import dataclass, fields -from contextlib import ExitStack -from functools import lru_cache -import gzip -import json -import logging -import os -from pathlib import Path -import random -import sys -import typing as tp - -import torch -import torch.nn.functional as F - -from .audio import audio_read, audio_info -from .audio_utils import convert_audio -from .zip import PathInZip - -try: - import dora -except ImportError: - dora = None # type: ignore - - -@dataclass(order=True) -class BaseInfo: - - @classmethod - def _dict2fields(cls, dictionary: dict): - return { - field.name: dictionary[field.name] - for field in fields(cls) if field.name in dictionary - } - - @classmethod - def from_dict(cls, dictionary: dict): - _dictionary = cls._dict2fields(dictionary) - return cls(**_dictionary) - - def to_dict(self): - return { - field.name: self.__getattribute__(field.name) - for field in fields(self) - } - - -@dataclass(order=True) -class AudioMeta(BaseInfo): - path: str - duration: float - sample_rate: int - bpm: float - # meter: int - amplitude: tp.Optional[float] = None - weight: tp.Optional[float] = None - phr_start: tp.List[tp.Optional[float]] = None - # info_path is used to load additional information about the audio file that is stored in zip files. - info_path: tp.Optional[PathInZip] = None - - @classmethod - def from_dict(cls, dictionary: dict): - base = cls._dict2fields(dictionary) - if 'info_path' in base and base['info_path'] is not None: - base['info_path'] = PathInZip(base['info_path']) - return cls(**base) - - def to_dict(self): - d = super().to_dict() - if d['info_path'] is not None: - d['info_path'] = str(d['info_path']) - return d - - -@dataclass(order=True) -class SegmentInfo(BaseInfo): - meta: AudioMeta - seek_time: float - # The following values are given once the audio is processed, e.g. - # at the target sample rate and target number of channels. - n_frames: int # actual number of frames without padding - total_frames: int # total number of frames, padding included - sample_rate: int # actual sample rate - channels: int # number of audio channels. - - -DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a'] - -logger = logging.getLogger(__name__) - - -def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta: - """AudioMeta from a path to an audio file. - - Args: - file_path (str): Resolved path of valid audio file. - minimal (bool): Whether to only load the minimal set of metadata (takes longer if not). - Returns: - AudioMeta: Audio file path and its metadata. - """ - info = audio_info(file_path) - amplitude: tp.Optional[float] = None - if not minimal: - wav, sr = audio_read(file_path) - amplitude = wav.abs().max().item() - - # load json info - json_file = file_path.replace('.wav', '.json') - with open(json_file ,'r') as f: - json_str = f.read() - info_json = json.loads(json_str) - - if "phr_start" not in info_json.keys(): - info_json["phr_start"] = None - - # return AudioMeta(file_path, info.duration, info.sample_rate, info_json["bpm"], info_json["meter"], amplitude, None, info_json["phr_start"]) - return AudioMeta(file_path, info.duration, info.sample_rate, info_json["bpm"], amplitude, None, info_json["phr_start"]) - -def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta: - """If Dora is available as a dependency, try to resolve potential relative paths - in list of AudioMeta. 
This method is expected to be used when loading meta from file. - - Args: - m (AudioMeta): Audio meta to resolve. - fast (bool): If True, uses a really fast check for determining if a file - is already absolute or not. Only valid on Linux/Mac. - Returns: - AudioMeta: Audio meta with resolved path. - """ - def is_abs(m): - if fast: - return str(m)[0] == '/' - else: - os.path.isabs(str(m)) - - if not dora: - return m - - if not is_abs(m.path): - m.path = dora.git_save.to_absolute_path(m.path) - if m.info_path is not None and not is_abs(m.info_path.zip_path): - m.info_path.zip_path = dora.git_save.to_absolute_path(m.path) - return m - - -def find_audio_files(path: tp.Union[Path, str], - exts: tp.List[str] = DEFAULT_EXTS, - resolve: bool = True, - minimal: bool = True, - progress: bool = False, - workers: int = 0) -> tp.List[AudioMeta]: - """Build a list of AudioMeta from a given path, - collecting relevant audio files and fetching meta info. - - Args: - path (str or Path): Path to folder containing audio files. - exts (list of str): List of file extensions to consider for audio files. - minimal (bool): Whether to only load the minimal set of metadata (takes longer if not). - progress (bool): Whether to log progress on audio files collection. - workers (int): number of parallel workers, if 0, use only the current thread. - Returns: - list of AudioMeta: List of audio file path and its metadata. - """ - audio_files = [] - futures: tp.List[Future] = [] - pool: tp.Optional[ThreadPoolExecutor] = None - with ExitStack() as stack: - if workers > 0: - pool = ThreadPoolExecutor(workers) - stack.enter_context(pool) - - if progress: - print("Finding audio files...") - for root, folders, files in os.walk(path, followlinks=True): - for file in files: - full_path = Path(root) / file - if full_path.suffix.lower() in exts: - audio_files.append(full_path) - if pool is not None: - futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal)) - if progress: - print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr) - - if progress: - print("Getting audio metadata...") - meta: tp.List[AudioMeta] = [] - for idx, file_path in enumerate(audio_files): - try: - if pool is None: - m = _get_audio_meta(str(file_path), minimal) - else: - m = futures[idx].result() - if resolve: - m = _resolve_audio_meta(m) - except Exception as err: - print("Error with", str(file_path), err, file=sys.stderr) - continue - meta.append(m) - if progress: - print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr) - meta.sort() - return meta - - -def load_audio_meta(path: tp.Union[str, Path], - resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]: - """Load list of AudioMeta from an optionally compressed json file. - - Args: - path (str or Path): Path to JSON file. - resolve (bool): Whether to resolve the path from AudioMeta (default=True). - fast (bool): activates some tricks to make things faster. - Returns: - list of AudioMeta: List of audio file path and its total duration. - """ - open_fn = gzip.open if str(path).lower().endswith('.gz') else open - with open_fn(path, 'rb') as fp: # type: ignore - lines = fp.readlines() - meta = [] - for line in lines: - d = json.loads(line) - m = AudioMeta.from_dict(d) - if resolve: - m = _resolve_audio_meta(m, fast=fast) - meta.append(m) - return meta - - -def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]): - """Save the audio metadata to the file pointer as json. - - Args: - path (str or Path): Path to JSON file. 
- metadata (list of BaseAudioMeta): List of audio meta to save. - """ - Path(path).parent.mkdir(exist_ok=True, parents=True) - open_fn = gzip.open if str(path).lower().endswith('.gz') else open - with open_fn(path, 'wb') as fp: # type: ignore - for m in meta: - json_str = json.dumps(m.to_dict()) + '\n' - json_bytes = json_str.encode('utf-8') - fp.write(json_bytes) - - -class AudioDataset: - """Base audio dataset. - - The dataset takes a list of AudioMeta and create a dataset composed of segments of audio - and potentially additional information, by creating random segments from the list of audio - files referenced in the metadata and applying minimal data pre-processing such as resampling, - mixing of channels, padding, etc. - - If no segment_duration value is provided, the AudioDataset will return the full wav for each - audio file. Otherwise, it will randomly sample audio files and create a segment of the specified - duration, applying padding if required. - - By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True - allows to return a tuple containing the torch Tensor and additional metadata on the segment and the - original audio meta. - - Note that you can call `start_epoch(epoch)` in order to get - a deterministic "randomization" for `shuffle=True`. - For a given epoch and dataset index, this will always return the same extract. - You can get back some diversity by setting the `shuffle_seed` param. - - Args: - meta (list of AudioMeta): List of audio files metadata. - segment_duration (float, optional): Optional segment duration of audio to load. - If not specified, the dataset will load the full audio segment from the file. - shuffle (bool): Set to `True` to have the data reshuffled at every epoch. - sample_rate (int): Target sample rate of the loaded audio samples. - channels (int): Target number of channels of the loaded audio samples. - sample_on_duration (bool): Set to `True` to sample segments with probability - dependent on audio file duration. This is only used if `segment_duration` is provided. - sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of - `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product - of the file duration and file weight. This is only used if `segment_duration` is provided. - min_segment_ratio (float): Minimum segment ratio to use when the audio file - is shorter than the desired segment. - max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset. - return_info (bool): Whether to return the wav only or return wav along with segment info and metadata. - min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided - audio shorter than this will be filtered out. - max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided - audio longer than this will be filtered out. - shuffle_seed (int): can be used to further randomize - load_wav (bool): if False, skip loading the wav but returns a tensor of 0 - with the expected segment_duration (which must be provided if load_wav is False). - permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration` - are False. Will ensure a permutation on files when going through the dataset. - In that case the epoch number must be provided in order for the model - to continue the permutation across epochs. 
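The metadata helpers above (find_audio_files, save_audio_meta, load_audio_meta) are the usual preprocessing step before building an AudioDataset. A minimal sketch, assuming the removed `audiocraft.data.audio_dataset` module is importable and the paths below are placeholders; note that this fork's _get_audio_meta additionally expects each .wav to have a sibling .json containing at least a "bpm" field (and optionally "phr_start"):

from audiocraft.data.audio_dataset import find_audio_files, save_audio_meta, load_audio_meta

meta = find_audio_files('dataset/train', exts=['.wav'], minimal=True, workers=4)  # scan the folder
save_audio_meta('egs/train/data.jsonl.gz', meta)       # one JSON object per line, gzip-compressed
reloaded = load_audio_meta('egs/train/data.jsonl.gz')  # -> list of AudioMeta
print(len(reloaded), reloaded[0].path, reloaded[0].bpm)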
In that case, it is assumed - that `num_samples = total_batch_size * num_updates_per_epoch`, with - `total_batch_size` the overall batch size accounting for all gpus. - """ - def __init__(self, - meta: tp.List[AudioMeta], - segment_duration: tp.Optional[float] = None, - shuffle: bool = True, - num_samples: int = 10_000, - sample_rate: int = 48_000, - channels: int = 2, - pad: bool = True, - sample_on_duration: bool = True, - sample_on_weight: bool = True, - min_segment_ratio: float = 1, - max_read_retry: int = 10, - return_info: bool = False, - min_audio_duration: tp.Optional[float] = None, - max_audio_duration: tp.Optional[float] = None, - shuffle_seed: int = 0, - load_wav: bool = True, - permutation_on_files: bool = False, - ): - assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta." - assert segment_duration is None or segment_duration > 0 - assert segment_duration is None or min_segment_ratio >= 0 - self.segment_duration = segment_duration - self.min_segment_ratio = min_segment_ratio - self.max_audio_duration = max_audio_duration - self.min_audio_duration = min_audio_duration - if self.min_audio_duration is not None and self.max_audio_duration is not None: - assert self.min_audio_duration <= self.max_audio_duration - self.meta: tp.List[AudioMeta] = self._filter_duration(meta) - assert len(self.meta) # Fail fast if all data has been filtered. - self.total_duration = sum(d.duration for d in self.meta) - - if segment_duration is None: - num_samples = len(self.meta) - self.num_samples = num_samples - self.shuffle = shuffle - self.sample_rate = sample_rate - self.channels = channels - self.pad = pad - self.sample_on_weight = sample_on_weight - self.sample_on_duration = sample_on_duration - self.sampling_probabilities = self._get_sampling_probabilities() - self.max_read_retry = max_read_retry - self.return_info = return_info - self.shuffle_seed = shuffle_seed - self.current_epoch: tp.Optional[int] = None - self.load_wav = load_wav - if not load_wav: - assert segment_duration is not None - self.permutation_on_files = permutation_on_files - if permutation_on_files: - assert not self.sample_on_duration - assert not self.sample_on_weight - assert self.shuffle - - def start_epoch(self, epoch: int): - self.current_epoch = epoch - - def __len__(self): - return self.num_samples - - def _get_sampling_probabilities(self, normalized: bool = True): - """Return the sampling probabilities for each file inside `self.meta`.""" - scores: tp.List[float] = [] - for file_meta in self.meta: - score = 1. - if self.sample_on_weight and file_meta.weight is not None: - score *= file_meta.weight - if self.sample_on_duration: - score *= file_meta.duration - scores.append(score) - probabilities = torch.tensor(scores) - if normalized: - probabilities /= probabilities.sum() - return probabilities - - @staticmethod - @lru_cache(16) - def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int): - # Used to keep the most recent files permutation in memory implicitely. - # will work unless someone is using a lot of Datasets in parallel. - rng = torch.Generator() - rng.manual_seed(base_seed + permutation_index) - return torch.randperm(num_files, generator=rng) - - def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta: - """Sample a given file from `self.meta`. Can be overridden in subclasses. - This is only called if `segment_duration` is not None. - - You must use the provided random number generator `rng` for reproducibility. 
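As a rough illustration of the sampling weights computed by `_get_sampling_probabilities` (numbers invented): with both sample_on_weight and sample_on_duration enabled, each file's unnormalized score is weight * duration, so longer or heavier files are drawn proportionally more often.

# Illustrative only: two files, durations 30 s and 90 s, weights 1.0 and 2.0.
scores = torch.tensor([1.0 * 30.0, 2.0 * 90.0])               # tensor([ 30., 180.])
probabilities = scores / scores.sum()                          # tensor([0.1429, 0.8571])
file_index = int(torch.multinomial(probabilities, 1).item())   # second file drawn ~6x more often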
- You can further make use of the index accessed. - """ - if self.permutation_on_files: - assert self.current_epoch is not None - total_index = self.current_epoch * len(self) + index - permutation_index = total_index // len(self.meta) - relative_index = total_index % len(self.meta) - permutation = AudioDataset._get_file_permutation( - len(self.meta), permutation_index, self.shuffle_seed) - file_index = permutation[relative_index] - return self.meta[file_index] - - if not self.sample_on_weight and not self.sample_on_duration: - file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item()) - else: - file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item()) - - return self.meta[file_index] - - def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1): - # Override this method in subclass if needed. - if self.load_wav: - return audio_read(path, seek_time, duration, pad=False) - else: - assert self.segment_duration is not None - n_frames = int(self.sample_rate * self.segment_duration) - return torch.zeros(self.channels, n_frames), self.sample_rate - - def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]: - if self.segment_duration is None: - file_meta = self.meta[index] - out, sr = audio_read(file_meta.path) - out = convert_audio(out, sr, self.sample_rate, self.channels) - n_frames = out.shape[-1] - segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames, - sample_rate=self.sample_rate, channels=out.shape[0]) - else: - rng = torch.Generator() - if self.shuffle: - # We use index, plus extra randomness, either totally random if we don't know the epoch. - # otherwise we make use of the epoch number and optional shuffle_seed. - if self.current_epoch is None: - rng.manual_seed(index + self.num_samples * random.randint(0, 2**24)) - else: - rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed)) - else: - # We only use index - rng.manual_seed(index) - - for retry in range(self.max_read_retry): - file_meta = self.sample_file(index, rng) - # We add some variance in the file position even if audio file is smaller than segment - # without ending up with empty segments - - # sample with phrase - if file_meta.phr_start is not None: - # max_seek = max(0, len(file_meta.phr_start[:-1])) - max_seek = max(0, len([start for start in file_meta.phr_start if start + self.segment_duration <= file_meta.duration])) # sample with time - seek_time = file_meta.phr_start[int(torch.rand(1, generator=rng).item() * max_seek)] # choose from phrase - - else: - max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio) - seek_time = torch.rand(1, generator=rng).item() * max_seek # can be change to choose phrase start - - if file_meta.duration == self.segment_duration: - seek_time = 0 - - # phr_dur = 60./file_meta.bpm * (file_meta.meter * 4.) 
# if meter=4 then 16 beats per phrase - try: - out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False) - # out, sr = audio_read(file_meta.path, seek_time, phr_dur, pad=False) # use phrase trunk as input - out = convert_audio(out, sr, self.sample_rate, self.channels) - n_frames = out.shape[-1] - target_frames = int(self.segment_duration * self.sample_rate) - if self.pad: - out = F.pad(out, (0, target_frames - n_frames)) - segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames, - sample_rate=self.sample_rate, channels=out.shape[0]) - except Exception as exc: - logger.warning("Error opening file %s: %r", file_meta.path, exc) - if retry == self.max_read_retry - 1: - raise - else: - break - - if self.return_info: - # Returns the wav and additional information on the wave segment - return out, segment_info - else: - return out - - def collater(self, samples): - """The collater function has to be provided to the dataloader - if AudioDataset has return_info=True in order to properly collate - the samples of a batch. - """ - if self.segment_duration is None and len(samples) > 1: - assert self.pad, "Must allow padding when batching examples of different durations." - - # In this case the audio reaching the collater is of variable length as segment_duration=None. - to_pad = self.segment_duration is None and self.pad - if to_pad: - max_len = max([wav.shape[-1] for wav, _ in samples]) - - def _pad_wav(wav): - return F.pad(wav, (0, max_len - wav.shape[-1])) - - if self.return_info: - if len(samples) > 0: - assert len(samples[0]) == 2 - assert isinstance(samples[0][0], torch.Tensor) - assert isinstance(samples[0][1], SegmentInfo) - - wavs = [wav for wav, _ in samples] - segment_infos = [copy.deepcopy(info) for _, info in samples] - - if to_pad: - # Each wav could be of a different duration as they are not segmented. - for i in range(len(samples)): - # Determines the total length of the signal with padding, so we update here as we pad. - segment_infos[i].total_frames = max_len - wavs[i] = _pad_wav(wavs[i]) - - wav = torch.stack(wavs) - return wav, segment_infos - else: - assert isinstance(samples[0], torch.Tensor) - if to_pad: - samples = [_pad_wav(s) for s in samples] - return torch.stack(samples) - - def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]: - """Filters out audio files with audio durations that will not allow to sample examples from them.""" - orig_len = len(meta) - - # Filter data that is too short. - if self.min_audio_duration is not None: - meta = [m for m in meta if m.duration >= self.min_audio_duration] - - # Filter data that is too long. - if self.max_audio_duration is not None: - meta = [m for m in meta if m.duration <= self.max_audio_duration] - - filtered_len = len(meta) - removed_percentage = 100*(1-float(filtered_len)/orig_len) - msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage - if removed_percentage < 10: - logging.debug(msg) - else: - logging.warning(msg) - return meta - - @classmethod - def from_meta(cls, root: tp.Union[str, Path], **kwargs): - """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file. - - Args: - root (str or Path): Path to root folder containing audio files. - kwargs: Additional keyword arguments for the AudioDataset. 
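A small sketch of how the collater is typically wired into a PyTorch DataLoader when return_info=True (the construction of `dataset` itself is assumed):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collater)
wav, infos = next(iter(loader))   # wav: [B, C, T], infos: list of SegmentInfo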
- """ - root = Path(root) - if root.is_dir(): - if (root / 'data.jsonl').exists(): - root = root / 'data.jsonl' - elif (root / 'data.jsonl.gz').exists(): - root = root / 'data.jsonl.gz' - else: - raise ValueError("Don't know where to read metadata from in the dir. " - "Expecting either a data.jsonl or data.jsonl.gz file but none found.") - meta = load_audio_meta(root) - return cls(meta, **kwargs) - - @classmethod - def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True, - exts: tp.List[str] = DEFAULT_EXTS, **kwargs): - """Instantiate AudioDataset from a path containing (possibly nested) audio files. - - Args: - root (str or Path): Path to root folder containing audio files. - minimal_meta (bool): Whether to only load minimal metadata or not. - exts (list of str): Extensions for audio files. - kwargs: Additional keyword arguments for the AudioDataset. - """ - root = Path(root) - if root.is_file(): - meta = load_audio_meta(root, resolve=True) - else: - meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True) - return cls(meta, **kwargs) - - -def main(): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - parser = argparse.ArgumentParser( - prog='audio_dataset', - description='Generate .jsonl files by scanning a folder.') - parser.add_argument('root', help='Root folder with all the audio files') - parser.add_argument('output_meta_file', - help='Output file to store the metadata, ') - parser.add_argument('--complete', - action='store_false', dest='minimal', default=True, - help='Retrieve all metadata, even the one that are expansive ' - 'to compute (e.g. normalization).') - parser.add_argument('--resolve', - action='store_true', default=False, - help='Resolve the paths to be absolute and with no symlinks.') - parser.add_argument('--workers', - default=10, type=int, - help='Number of workers.') - args = parser.parse_args() - meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True, - resolve=args.resolve, minimal=args.minimal, workers=args.workers) - save_audio_meta(args.output_meta_file, meta) - - -if __name__ == '__main__': - main() diff --git a/audiocraft/audiocraft/data/audio_utils.py b/audiocraft/audiocraft/data/audio_utils.py deleted file mode 100644 index e9fb715f9801ace1fb2d510f59c161f5ffbe8695..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/data/audio_utils.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Various utilities for audio convertion (pcm format, sample rate and channels), -and volume normalization.""" -import sys -import typing as tp - -import julius -import torch -import torchaudio -import numpy as np - -from .chords import Chords -chords = Chords() # initiate object - - -def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor: - """Convert audio to the given number of channels. - - Args: - wav (torch.Tensor): Audio wave of shape [B, C, T]. - channels (int): Expected number of channels as output. - Returns: - torch.Tensor: Downmixed or unchanged audio wave [B, C, T]. - """ - *shape, src_channels, length = wav.shape - if src_channels == channels: - pass - elif channels == 1: - # Case 1: - # The caller asked 1-channel audio, and the stream has multiple - # channels, downmix all channels. 
- wav = wav.mean(dim=-2, keepdim=True) - elif src_channels == 1: - # Case 2: - # The caller asked for multiple channels, but the input file has - # a single channel, replicate the audio over all channels. - wav = wav.expand(*shape, channels, length) - elif src_channels >= channels: - # Case 3: - # The caller asked for multiple channels, and the input file has - # more channels than requested. In that case return the first channels. - wav = wav[..., :channels, :] - else: - # Case 4: What is a reasonable choice here? - raise ValueError('The audio file has less channels than requested but is not mono.') - return wav - - -def convert_audio(wav: torch.Tensor, from_rate: float, - to_rate: float, to_channels: int) -> torch.Tensor: - """Convert audio to new sample rate and number of audio channels.""" - wav = julius.resample_frac(wav, int(from_rate), int(to_rate)) - wav = convert_audio_channels(wav, to_channels) - return wav - - -def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14, - loudness_compressor: bool = False, energy_floor: float = 2e-3): - """Normalize an input signal to a user loudness in dB LKFS. - Audio loudness is defined according to the ITU-R BS.1770-4 recommendation. - - Args: - wav (torch.Tensor): Input multichannel audio data. - sample_rate (int): Sample rate. - loudness_headroom_db (float): Target loudness of the output in dB LUFS. - loudness_compressor (bool): Uses tanh for soft clipping. - energy_floor (float): anything below that RMS level will not be rescaled. - Returns: - torch.Tensor: Loudness normalized output data. - """ - energy = wav.pow(2).mean().sqrt().item() - if energy < energy_floor: - return wav - transform = torchaudio.transforms.Loudness(sample_rate) - input_loudness_db = transform(wav).item() - # calculate the gain needed to scale to the desired loudness level - delta_loudness = -loudness_headroom_db - input_loudness_db - gain = 10.0 ** (delta_loudness / 20.0) - output = gain * wav - if loudness_compressor: - output = torch.tanh(output) - assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt()) - return output - - -def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None: - """Utility function to clip the audio with logging if specified.""" - max_scale = wav.abs().max() - if log_clipping and max_scale > 1: - clamp_prob = (wav.abs() > 1).float().mean().item() - print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):", - clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr) - wav.clamp_(-1, 1) - - -def normalize_audio(wav: torch.Tensor, normalize: bool = True, - strategy: str = 'peak', peak_clip_headroom_db: float = 1, - rms_headroom_db: float = 18, loudness_headroom_db: float = 14, - loudness_compressor: bool = False, log_clipping: bool = False, - sample_rate: tp.Optional[int] = None, - stem_name: tp.Optional[str] = None) -> torch.Tensor: - """Normalize the audio according to the prescribed strategy (see after). - - Args: - wav (torch.Tensor): Audio data. - normalize (bool): if `True` (default), normalizes according to the prescribed - strategy (see after). If `False`, the strategy is only used in case clipping - would happen. - strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', - i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square - with extra headroom to avoid clipping. 'clip' just clips. 
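For instance, converting one second of stereo audio at 44.1 kHz to mono at 32 kHz with convert_audio as defined above (shapes shown for orientation):

wav = torch.randn(2, 44100)                                    # [C, T], 1 s of stereo at 44.1 kHz
out = convert_audio(wav, from_rate=44100, to_rate=32000, to_channels=1)
print(out.shape)                                               # torch.Size([1, 32000])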
- peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. - rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger - than the `peak_clip` one to avoid further clipping. - loudness_headroom_db (float): Target loudness for loudness normalization. - loudness_compressor (bool): If True, uses tanh based soft clipping. - log_clipping (bool): If True, basic logging on stderr when clipping still - occurs despite strategy (only for 'rms'). - sample_rate (int): Sample rate for the audio data (required for loudness). - stem_name (str, optional): Stem name for clipping logging. - Returns: - torch.Tensor: Normalized audio. - """ - scale_peak = 10 ** (-peak_clip_headroom_db / 20) - scale_rms = 10 ** (-rms_headroom_db / 20) - if strategy == 'peak': - rescaling = (scale_peak / wav.abs().max()) - if normalize or rescaling < 1: - wav = wav * rescaling - elif strategy == 'clip': - wav = wav.clamp(-scale_peak, scale_peak) - elif strategy == 'rms': - mono = wav.mean(dim=0) - rescaling = scale_rms / mono.pow(2).mean().sqrt() - if normalize or rescaling < 1: - wav = wav * rescaling - _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) - elif strategy == 'loudness': - assert sample_rate is not None, "Loudness normalization requires sample rate." - wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor) - _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) - else: - assert wav.abs().max() < 1 - assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'" - return wav - - -def f32_pcm(wav: torch.Tensor) -> torch.Tensor: - """Convert audio to float 32 bits PCM format. - """ - if wav.dtype.is_floating_point: - return wav - elif wav.dtype == torch.int16: - return wav.float() / 2**15 - elif wav.dtype == torch.int32: - return wav.float() / 2**31 - raise ValueError(f"Unsupported wav dtype: {wav.dtype}") - - -def i16_pcm(wav: torch.Tensor) -> torch.Tensor: - """Convert audio to int 16 bits PCM format. - - ..Warning:: There exist many formula for doing this conversion. None are perfect - due to the asymmetry of the int16 range. One either have possible clipping, DC offset, - or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom, - it is possible that `i16_pcm(f32_pcm)) != Identity`. - """ - if wav.dtype.is_floating_point: - assert wav.abs().max() <= 1 - candidate = (wav * 2 ** 15).round() - if candidate.max() >= 2 ** 15: # clipping would occur - candidate = (wav * (2 ** 15 - 1)).round() - return candidate.short() - else: - assert wav.dtype == torch.int16 - return wav - -def convert_txtchord2chroma_orig(text_chords, bpms, meters, gen_sec): - chromas = [] - # total_len = int(gen_sec * 44100 / 512) - total_len = int(gen_sec * 32000 / 640) - for chord, bpm, meter in zip(text_chords, bpms, meters): - phr_len = int(60. / bpm * (meter * 4) * 32000 / 640) - # phr_len = int(60. 
/ bpm * (meter * 4) * 44100 / 2048) - chroma = torch.zeros([total_len, 12]) - count = 0 - offset = 0 - - stext = chord.split(" ") - timebin = phr_len // 4 # frames per bar - while count < total_len: - for tokens in stext: - if count >= total_len: - break - stoken = tokens.split(',') - for token in stoken: - off_timebin = timebin + offset - rounded_timebin = round(off_timebin) - offset = off_timebin - rounded_timebin - offset = offset/len(stoken) - add_step = rounded_timebin//len(stoken) - mhot = chords.chord(token) - rolled = np.roll(mhot[2], mhot[0]) - for i in range(count, count + add_step): - if count >= total_len: - break - chroma[i] = torch.Tensor(rolled) - count += 1 - chromas.append(chroma) - chroma = torch.stack(chromas) - return chroma - -def convert_txtchord2chroma(chord, bpm, meter, gen_sec): - total_len = int(gen_sec * 32000 / 640) - - phr_len = int(60. / bpm * (meter * 4) * 32000 / 640) - # phr_len = int(60. / bpm * (meter * 4) * 44100 / 2048) - chroma = torch.zeros([total_len, 12]) - count = 0 - offset = 0 - - stext = chord.split(" ") - timebin = phr_len // 4 # frames per bar - while count < total_len: - for tokens in stext: - if count >= total_len: - break - stoken = tokens.split(',') - for token in stoken: - off_timebin = timebin + offset - rounded_timebin = round(off_timebin) - offset = off_timebin - rounded_timebin - offset = offset/len(stoken) - add_step = rounded_timebin//len(stoken) - mhot = chords.chord(token) - rolled = np.roll(mhot[2], mhot[0]) - for i in range(count, count + add_step): - if count >= total_len: - break - chroma[i] = torch.Tensor(rolled) - count += 1 - return chroma - - - -def convert_txtchord2chroma_24(chord, bpm, meter, gen_sec): - total_len = int(gen_sec * 32000 / 640) - - phr_len = int(60. / bpm * (meter * 4) * 32000 / 640) - # phr_len = int(60. / bpm * (meter * 4) * 44100 / 2048) - chroma = torch.zeros([total_len, 24]) - count = 0 - offset = 0 - - stext = chord.split(" ") - timebin = phr_len // 4 # frames per bar - while count < total_len: - for tokens in stext: - if count >= total_len: - break - stoken = tokens.split(',') - for token in stoken: - off_timebin = timebin + offset - rounded_timebin = round(off_timebin) - offset = off_timebin - rounded_timebin - offset = offset/len(stoken) - add_step = rounded_timebin//len(stoken) - - root, bass, ivs_vec, _ = chords.chord(token) - root_vec = torch.zeros(12) - root_vec[root] = 1 - final_vec = np.concatenate([root_vec, ivs_vec]) # [C] - for i in range(count, count + add_step): - if count >= total_len: - break - chroma[i] = torch.Tensor(final_vec) - count += 1 - return chroma - -def get_chroma_chord_from_lab(chord_path, gen_sec): - total_len = int(gen_sec * 32000 / 640) - feat_hz = 32000/640 - intervals = [] - labels = [] - feat_chord = np.zeros((12, total_len)) # root| ivs - with open(chord_path, 'r') as f: - for line in f.readlines(): - splits = line.split() - if len(splits) == 3: - st_sec, ed_sec, ctag = splits - st_sec = float(st_sec) - ed_sec = float(ed_sec) - - st_frame = int(st_sec*feat_hz) - ed_frame = int(ed_sec*feat_hz) - - mhot = chords.chord(ctag) - final_vec = np.roll(mhot[2], mhot[0]) - - final_vec = final_vec[..., None] # [C, T] - feat_chord[:, st_frame:ed_frame] = final_vec - feat_chord = torch.from_numpy(feat_chord) - return feat_chord - - -def get_chroma_chord_from_text(text_chord, bpm, meter, gen_sec): - total_len = int(gen_sec * 32000 / 640) - - phr_len = int(60. 
/ bpm * (meter * 4) * 32000 / 640) - chroma = np.zeros([12, total_len]) - count = 0 - offset = 0 - - stext = text_chord.split(" ") - timebin = phr_len // 4 # frames per bar - while count < total_len: - for tokens in stext: - if count >= total_len: - break - stoken = tokens.split(',') - for token in stoken: - off_timebin = timebin + offset - rounded_timebin = round(off_timebin) - offset = off_timebin - rounded_timebin - offset = offset/len(stoken) - add_step = rounded_timebin//len(stoken) - mhot = chords.chord(token) - final_vec = np.roll(mhot[2], mhot[0]) # [C] - - for i in range(count, count + add_step): - if count >= total_len: - break - chroma[:, i] = final_vec - count += 1 - feat_chord = torch.from_numpy(chroma) - return feat_chord - -def get_beat_from_npy(beat_path, gen_sec): - total_len = int(gen_sec * 32000 / 640) - feat_hz = 32000 / 640 # feature frame rate, as in get_chroma_chord_from_lab - duration = gen_sec # assume the features cover the full generated duration - - beats_np = np.load(beat_path, allow_pickle=True) - feat_beats = np.zeros((2, total_len)) - meter = int(max(beats_np.T[1])) - beat_time = beats_np[:, 0] - bar_time = beats_np[np.where(beats_np[:, 1] == 1)[0], 0] - - beat_frame = [int(t * feat_hz) for t in beat_time if (t >= 0 and t < duration)] - bar_frame = [int(t * feat_hz) for t in bar_time if (t >= 0 and t < duration)] - - feat_beats[0, beat_frame] = 1 - feat_beats[1, bar_frame] = 1 - kernel = np.array([0.05, 0.1, 0.3, 0.9, 0.3, 0.1, 0.05]) - feat_beats[0] = np.convolve(feat_beats[0], kernel, 'same') # apply soft kernel - beat_events = feat_beats[0] + feat_beats[1] - beat_events = torch.tensor(beat_events).unsqueeze(0) # [T] -> [1, T] - - bpm = 60 // np.mean([j - i for i, j in zip(beat_time[:-1], beat_time[1:])]) - return beat_events, bpm, meter - -def get_beat_from_bpm(bpm, meter, gen_sec): - total_len = int(gen_sec * 32000 / 640) - feat_hz = 32000 / 640 # feature frame rate, as in get_chroma_chord_from_lab - duration = gen_sec # assume the features cover the full generated duration - n_frames_feat = total_len - - feat_beats = np.zeros((2, total_len)) - - beat_time_gap = 60 / bpm - beat_gap = 60 / bpm * feat_hz - - beat_time = np.arange(0, duration, beat_time_gap) - beat_frame = np.round(np.arange(0, n_frames_feat, beat_gap)).astype(int) - if beat_frame[-1] == n_frames_feat: - beat_frame = beat_frame[:-1] - bar_frame = beat_frame[::meter] - - feat_beats[0, beat_frame] = 1 - feat_beats[1, bar_frame] = 1 - kernel = np.array([0.05, 0.1, 0.3, 0.9, 0.3, 0.1, 0.05]) - feat_beats[0] = np.convolve(feat_beats[0], kernel, 'same') # apply soft kernel - beat_events = feat_beats[0] + feat_beats[1] - beat_events = torch.tensor(beat_events).unsqueeze(0) # [T] -> [1, T] - return beat_events, beat_time, meter \ No newline at end of file diff --git a/audiocraft/audiocraft/data/btc_chords.py b/audiocraft/audiocraft/data/btc_chords.py deleted file mode 100644 index 1208be9a2d22bb470550c3129fc930eece99ca87..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/data/btc_chords.py +++ /dev/null @@ -1,524 +0,0 @@ -# encoding: utf-8 -""" -This module contains chord evaluation functionality. - -It provides the evaluation measures used for the MIREX ACE task, and -tries to follow [1]_ and [2]_ as closely as possible. - -Notes ------ -This implementation tries to follow the references and their implementation -(e.g., https://github.com/jpauwels/MusOOEvaluator for [2]_). However, there -are some known (and possibly some unknown) differences. If you find one not -listed in the following, please file an issue: - - - Detected chord segments are adjusted to fit the length of the annotations. - In particular, this means that, if necessary, filler segments of 'no chord' - are added at beginnings and ends.
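The recurring 32000 / 640 factor in the chroma and beat helpers above is the feature frame rate (50 feature frames per second at a 32 kHz sample rate with a hop of 640 samples); a quick sanity check of the frame arithmetic:

sample_rate, hop = 32000, 640
feat_hz = sample_rate / hop          # 50.0 feature frames per second
gen_sec = 30
total_len = int(gen_sec * feat_hz)   # 1500 frames for a 30-second generation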
This can result in different segmentation - scores compared to the original implementation. - -References ----------- -.. [1] Christopher Harte, "Towards Automatic Extraction of Harmony Information - from Music Signals." Dissertation, - Department for Electronic Engineering, Queen Mary University of London, - 2010. -.. [2] Johan Pauwels and Geoffroy Peeters. - "Evaluating Automatically Estimated Chord Sequences." - In Proceedings of ICASSP 2013, Vancouver, Canada, 2013. - -""" - -import numpy as np -import pandas as pd - - -CHORD_DTYPE = [('root', np.int_), - ('bass', np.int_), - ('intervals', np.int_, (12,)), - ('is_major',np.bool_)] - -CHORD_ANN_DTYPE = [('start', np.float32), - ('end', np.float32), - ('chord', CHORD_DTYPE)] - -NO_CHORD = (-1, -1, np.zeros(12, dtype=np.int_), False) -UNKNOWN_CHORD = (-1, -1, np.ones(12, dtype=np.int_) * -1, False) - -PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] - - -def idx_to_chord(idx): - if idx == 24: - return "-" - elif idx == 25: - return u"\u03B5" - - minmaj = idx % 2 - root = idx // 2 - - return PITCH_CLASS[root] + ("M" if minmaj == 0 else "m") - -class Chords: - - def __init__(self): - self._shorthands = { - 'maj': self.interval_list('(1,3,5)'), - 'min': self.interval_list('(1,b3,5)'), - 'dim': self.interval_list('(1,b3,b5)'), - 'aug': self.interval_list('(1,3,#5)'), - 'maj7': self.interval_list('(1,3,5,7)'), - 'min7': self.interval_list('(1,b3,5,b7)'), - '7': self.interval_list('(1,3,5,b7)'), - '6': self.interval_list('(1,6)'), # custom - '5': self.interval_list('(1,5)'), - '4': self.interval_list('(1,4)'), # custom - '1': self.interval_list('(1)'), - 'dim7': self.interval_list('(1,b3,b5,bb7)'), - 'hdim7': self.interval_list('(1,b3,b5,b7)'), - 'minmaj7': self.interval_list('(1,b3,5,7)'), - 'maj6': self.interval_list('(1,3,5,6)'), - 'min6': self.interval_list('(1,b3,5,6)'), - '9': self.interval_list('(1,3,5,b7,9)'), - 'maj9': self.interval_list('(1,3,5,7,9)'), - 'min9': self.interval_list('(1,b3,5,b7,9)'), - 'add9': self.interval_list('(1,3,5,9)'), # custom - 'sus2': self.interval_list('(1,2,5)'), - 'sus4': self.interval_list('(1,4,5)'), - '7sus2': self.interval_list('(1,2,5,b7)'), # custom - '7sus4': self.interval_list('(1,4,5,b7)'), # custom - '11': self.interval_list('(1,3,5,b7,9,11)'), - 'min11': self.interval_list('(1,b3,5,b7,9,11)'), - '13': self.interval_list('(1,3,5,b7,13)'), - 'maj13': self.interval_list('(1,3,5,7,13)'), - 'min13': self.interval_list('(1,b3,5,b7,13)') - } - - def chords(self, labels): - - """ - Transform a list of chord labels into an array of internal numeric - representations. - - Parameters - ---------- - labels : list - List of chord labels (str). - - Returns - ------- - chords : numpy.array - Structured array with columns 'root', 'bass', and 'intervals', - containing a numeric representation of chords. 
- - """ - crds = np.zeros(len(labels), dtype=CHORD_DTYPE) - cache = {} - for i, lbl in enumerate(labels): - cv = cache.get(lbl, None) - if cv is None: - cv = self.chord(lbl) - cache[lbl] = cv - crds[i] = cv - - return crds - - def label_error_modify(self, label): - if label == 'Emin/4': label = 'E:min/4' - elif label == 'A7/3': label = 'A:7/3' - elif label == 'Bb7/3': label = 'Bb:7/3' - elif label == 'Bb7/5': label = 'Bb:7/5' - elif label.find(':') == -1: - if label.find('min') != -1: - label = label[:label.find('min')] + ':' + label[label.find('min'):] - return label - - def chord(self, label): - """ - Transform a chord label into the internal numeric represenation of - (root, bass, intervals array). - - Parameters - ---------- - label : str - Chord label. - - Returns - ------- - chord : tuple - Numeric representation of the chord: (root, bass, intervals array). - - """ - - - is_major = False - - if label == 'N': - return NO_CHORD - if label == 'X': - return UNKNOWN_CHORD - - label = self.label_error_modify(label) - - c_idx = label.find(':') - s_idx = label.find('/') - - if c_idx == -1: - quality_str = 'maj' - if s_idx == -1: - root_str = label - bass_str = '' - else: - root_str = label[:s_idx] - bass_str = label[s_idx + 1:] - else: - root_str = label[:c_idx] - if s_idx == -1: - quality_str = label[c_idx + 1:] - bass_str = '' - else: - quality_str = label[c_idx + 1:s_idx] - bass_str = label[s_idx + 1:] - - root = self.pitch(root_str) - bass = self.interval(bass_str) if bass_str else 0 - ivs = self.chord_intervals(quality_str) - ivs[bass] = 1 - - if 'min' in quality_str: - is_major = False - else: - is_major = True - - - return root, bass, ivs, is_major - - _l = [0, 1, 1, 0, 1, 1, 1] - _chroma_id = (np.arange(len(_l) * 2) + 1) + np.array(_l + _l).cumsum() - 1 - - def modify(self, base_pitch, modifier): - """ - Modify a pitch class in integer representation by a given modifier string. - - A modifier string can be any sequence of 'b' (one semitone down) - and '#' (one semitone up). - - Parameters - ---------- - base_pitch : int - Pitch class as integer. - modifier : str - String of modifiers ('b' or '#'). - - Returns - ------- - modified_pitch : int - Modified root note. - - """ - for m in modifier: - if m == 'b': - base_pitch -= 1 - elif m == '#': - base_pitch += 1 - else: - raise ValueError('Unknown modifier: {}'.format(m)) - return base_pitch - - def pitch(self, pitch_str): - """ - Convert a string representation of a pitch class (consisting of root - note and modifiers) to an integer representation. - - Parameters - ---------- - pitch_str : str - String representation of a pitch class. - - Returns - ------- - pitch : int - Integer representation of a pitch class. - - """ - return self.modify(self._chroma_id[(ord(pitch_str[0]) - ord('C')) % 7], - pitch_str[1:]) % 12 - - def interval(self, interval_str): - """ - Convert a string representation of a musical interval into a pitch class - (e.g. a minor seventh 'b7' into 10, because it is 10 semitones above its - base note). - - Parameters - ---------- - interval_str : str - Musical interval. - - Returns - ------- - pitch_class : int - Number of semitones to base note of interval. - - """ - for i, c in enumerate(interval_str): - if c.isdigit(): - return self.modify(self._chroma_id[int(interval_str[i:]) - 1], - interval_str[:i]) % 12 - - def interval_list(self, intervals_str, given_pitch_classes=None): - """ - Convert a list of intervals given as string to a binary pitch class - representation. 
For example, 'b3, 5' would become - [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]. - - Parameters - ---------- - intervals_str : str - List of intervals as comma-separated string (e.g. 'b3, 5'). - given_pitch_classes : None or numpy array - If None, start with empty pitch class array, if numpy array of length - 12, this array will be modified. - - Returns - ------- - pitch_classes : numpy array - Binary pitch class representation of intervals. - - """ - if given_pitch_classes is None: - given_pitch_classes = np.zeros(12, dtype=np.int_) - for int_def in intervals_str[1:-1].split(','): - int_def = int_def.strip() - if int_def[0] == '*': - given_pitch_classes[self.interval(int_def[1:])] = 0 - else: - given_pitch_classes[self.interval(int_def)] = 1 - return given_pitch_classes - - # mapping of shorthand interval notations to the actual interval representation - - def chord_intervals(self, quality_str): - """ - Convert a chord quality string to a pitch class representation. For - example, 'maj' becomes [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]. - - Parameters - ---------- - quality_str : str - String defining the chord quality. - - Returns - ------- - pitch_classes : numpy array - Binary pitch class representation of chord quality. - - """ - list_idx = quality_str.find('(') - if list_idx == -1: - return self._shorthands[quality_str].copy() - if list_idx != 0: - ivs = self._shorthands[quality_str[:list_idx]].copy() - else: - ivs = np.zeros(12, dtype=np.int_) - - - return self.interval_list(quality_str[list_idx:], ivs) - - def load_chords(self, filename): - """ - Load chords from a text file. - - The chord must follow the syntax defined in [1]_. - - Parameters - ---------- - filename : str - File containing chord segments. - - Returns - ------- - crds : numpy structured array - Structured array with columns "start", "end", and "chord", - containing the beginning, end, and chord definition of chord - segments. - - References - ---------- - .. [1] Christopher Harte, "Towards Automatic Extraction of Harmony - Information from Music Signals." Dissertation, - Department for Electronic Engineering, Queen Mary University of - London, 2010. - - """ - start, end, chord_labels = [], [], [] - with open(filename, 'r') as f: - for line in f: - if line: - - splits = line.split() - if len(splits) == 3: - - s = splits[0] - e = splits[1] - l = splits[2] - - start.append(float(s)) - end.append(float(e)) - chord_labels.append(l) - - crds = np.zeros(len(start), dtype=CHORD_ANN_DTYPE) - crds['start'] = start - crds['end'] = end - crds['chord'] = self.chords(chord_labels) - - return crds - - def reduce_to_triads(self, chords, keep_bass=False): - """ - Reduce chords to triads. - - The function follows the reduction rules implemented in [1]_. If a chord - chord does not contain a third, major second or fourth, it is reduced to - a power chord. If it does not contain neither a third nor a fifth, it is - reduced to a single note "chord". - - Parameters - ---------- - chords : numpy structured array - Chords to be reduced. - keep_bass : bool - Indicates whether to keep the bass note or set it to 0. - - Returns - ------- - reduced_chords : numpy structured array - Chords reduced to triads. - - References - ---------- - .. [1] Johan Pauwels and Geoffroy Peeters. - "Evaluating Automatically Estimated Chord Sequences." - In Proceedings of ICASSP 2013, Vancouver, Canada, 2013. 
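A brief example of what the parsing methods of this class return (values follow from the interval tables and pitch arithmetic defined above):

c = Chords()
root, bass, ivs, is_major = c.chord('C:maj')
# root == 0 (C), bass == 0, is_major == True
# ivs -> array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]), i.e. C, E, G
print(c.pitch('Bb'))     # 10
print(c.interval('b7'))  # 10 semitones above the root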
- - """ - unison = chords['intervals'][:, 0].astype(bool) - maj_sec = chords['intervals'][:, 2].astype(bool) - min_third = chords['intervals'][:, 3].astype(bool) - maj_third = chords['intervals'][:, 4].astype(bool) - perf_fourth = chords['intervals'][:, 5].astype(bool) - dim_fifth = chords['intervals'][:, 6].astype(bool) - perf_fifth = chords['intervals'][:, 7].astype(bool) - aug_fifth = chords['intervals'][:, 8].astype(bool) - no_chord = (chords['intervals'] == NO_CHORD[-1]).all(axis=1) - - reduced_chords = chords.copy() - ivs = reduced_chords['intervals'] - - ivs[~no_chord] = self.interval_list('(1)') - ivs[unison & perf_fifth] = self.interval_list('(1,5)') - ivs[~perf_fourth & maj_sec] = self._shorthands['sus2'] - ivs[perf_fourth & ~maj_sec] = self._shorthands['sus4'] - - ivs[min_third] = self._shorthands['min'] - ivs[min_third & aug_fifth & ~perf_fifth] = self.interval_list('(1,b3,#5)') - ivs[min_third & dim_fifth & ~perf_fifth] = self._shorthands['dim'] - - ivs[maj_third] = self._shorthands['maj'] - ivs[maj_third & dim_fifth & ~perf_fifth] = self.interval_list('(1,3,b5)') - ivs[maj_third & aug_fifth & ~perf_fifth] = self._shorthands['aug'] - - if not keep_bass: - reduced_chords['bass'] = 0 - else: - # remove bass notes if they are not part of the intervals anymore - reduced_chords['bass'] *= ivs[range(len(reduced_chords)), - reduced_chords['bass']] - # keep -1 in bass for no chords - reduced_chords['bass'][no_chord] = -1 - - return reduced_chords - - def convert_to_id(self, root, is_major): - if root == -1: - return 24 - else: - if is_major: - return root * 2 - else: - return root * 2 + 1 - - def get_converted_chord(self, filename): - loaded_chord = self.load_chords(filename) - triads = self.reduce_to_triads(loaded_chord['chord']) - - df = self.assign_chord_id(triads) - df['start'] = loaded_chord['start'] - df['end'] = loaded_chord['end'] - - return df - - def assign_chord_id(self, entry): - # maj, min chord only - # if you want to add other chord, change this part and get_converted_chord(reduce_to_triads) - df = pd.DataFrame(data=entry[['root', 'is_major']]) - df['chord_id'] = df.apply(lambda row: self.convert_to_id(row['root'], row['is_major']), axis=1) - return df - - def convert_to_id_voca(self, root, quality): - if root == -1: - return 169 - else: - if quality == 'min': - return root * 14 - elif quality == 'maj': - return root * 14 + 1 - elif quality == 'dim': - return root * 14 + 2 - elif quality == 'aug': - return root * 14 + 3 - elif quality == 'min6': - return root * 14 + 4 - elif quality == 'maj6': - return root * 14 + 5 - elif quality == 'min7': - return root * 14 + 6 - elif quality == 'minmaj7': - return root * 14 + 7 - elif quality == 'maj7': - return root * 14 + 8 - elif quality == '7': - return root * 14 + 9 - elif quality == 'dim7': - return root * 14 + 10 - elif quality == 'hdim7': - return root * 14 + 11 - elif quality == 'sus2': - return root * 14 + 12 - elif quality == 'sus4': - return root * 14 + 13 - else: - return 168 - - - def lab_file_error_modify(self, ref_labels): - for i in range(len(ref_labels)): - if ref_labels[i][-2:] == ':4': - ref_labels[i] = ref_labels[i].replace(':4', ':sus4') - elif ref_labels[i][-2:] == ':6': - ref_labels[i] = ref_labels[i].replace(':6', ':maj6') - elif ref_labels[i][-4:] == ':6/2': - ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2') - elif ref_labels[i] == 'Emin/4': - ref_labels[i] = 'E:min/4' - elif ref_labels[i] == 'A7/3': - ref_labels[i] = 'A:7/3' - elif ref_labels[i] == 'Bb7/3': - ref_labels[i] = 'Bb:7/3' - elif 
ref_labels[i] == 'Bb7/5': - ref_labels[i] = 'Bb:7/5' - elif ref_labels[i].find(':') == -1: - if ref_labels[i].find('min') != -1: - ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):] - return ref_labels - diff --git a/audiocraft/audiocraft/data/chords.py b/audiocraft/audiocraft/data/chords.py deleted file mode 100644 index 1208be9a2d22bb470550c3129fc930eece99ca87..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/data/chords.py +++ /dev/null @@ -1,524 +0,0 @@ -# encoding: utf-8 -""" -This module contains chord evaluation functionality. - -It provides the evaluation measures used for the MIREX ACE task, and -tries to follow [1]_ and [2]_ as closely as possible. - -Notes ------ -This implementation tries to follow the references and their implementation -(e.g., https://github.com/jpauwels/MusOOEvaluator for [2]_). However, there -are some known (and possibly some unknown) differences. If you find one not -listed in the following, please file an issue: - - - Detected chord segments are adjusted to fit the length of the annotations. - In particular, this means that, if necessary, filler segments of 'no chord' - are added at beginnings and ends. This can result in different segmentation - scores compared to the original implementation. - -References ----------- -.. [1] Christopher Harte, "Towards Automatic Extraction of Harmony Information - from Music Signals." Dissertation, - Department for Electronic Engineering, Queen Mary University of London, - 2010. -.. [2] Johan Pauwels and Geoffroy Peeters. - "Evaluating Automatically Estimated Chord Sequences." - In Proceedings of ICASSP 2013, Vancouver, Canada, 2013. - -""" - -import numpy as np -import pandas as pd - - -CHORD_DTYPE = [('root', np.int_), - ('bass', np.int_), - ('intervals', np.int_, (12,)), - ('is_major',np.bool_)] - -CHORD_ANN_DTYPE = [('start', np.float32), - ('end', np.float32), - ('chord', CHORD_DTYPE)] - -NO_CHORD = (-1, -1, np.zeros(12, dtype=np.int_), False) -UNKNOWN_CHORD = (-1, -1, np.ones(12, dtype=np.int_) * -1, False) - -PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] - - -def idx_to_chord(idx): - if idx == 24: - return "-" - elif idx == 25: - return u"\u03B5" - - minmaj = idx % 2 - root = idx // 2 - - return PITCH_CLASS[root] + ("M" if minmaj == 0 else "m") - -class Chords: - - def __init__(self): - self._shorthands = { - 'maj': self.interval_list('(1,3,5)'), - 'min': self.interval_list('(1,b3,5)'), - 'dim': self.interval_list('(1,b3,b5)'), - 'aug': self.interval_list('(1,3,#5)'), - 'maj7': self.interval_list('(1,3,5,7)'), - 'min7': self.interval_list('(1,b3,5,b7)'), - '7': self.interval_list('(1,3,5,b7)'), - '6': self.interval_list('(1,6)'), # custom - '5': self.interval_list('(1,5)'), - '4': self.interval_list('(1,4)'), # custom - '1': self.interval_list('(1)'), - 'dim7': self.interval_list('(1,b3,b5,bb7)'), - 'hdim7': self.interval_list('(1,b3,b5,b7)'), - 'minmaj7': self.interval_list('(1,b3,5,7)'), - 'maj6': self.interval_list('(1,3,5,6)'), - 'min6': self.interval_list('(1,b3,5,6)'), - '9': self.interval_list('(1,3,5,b7,9)'), - 'maj9': self.interval_list('(1,3,5,7,9)'), - 'min9': self.interval_list('(1,b3,5,b7,9)'), - 'add9': self.interval_list('(1,3,5,9)'), # custom - 'sus2': self.interval_list('(1,2,5)'), - 'sus4': self.interval_list('(1,4,5)'), - '7sus2': self.interval_list('(1,2,5,b7)'), # custom - '7sus4': self.interval_list('(1,4,5,b7)'), # custom - '11': self.interval_list('(1,3,5,b7,9,11)'), - 'min11': 
self.interval_list('(1,b3,5,b7,9,11)'), - '13': self.interval_list('(1,3,5,b7,13)'), - 'maj13': self.interval_list('(1,3,5,7,13)'), - 'min13': self.interval_list('(1,b3,5,b7,13)') - } - - def chords(self, labels): - - """ - Transform a list of chord labels into an array of internal numeric - representations. - - Parameters - ---------- - labels : list - List of chord labels (str). - - Returns - ------- - chords : numpy.array - Structured array with columns 'root', 'bass', and 'intervals', - containing a numeric representation of chords. - - """ - crds = np.zeros(len(labels), dtype=CHORD_DTYPE) - cache = {} - for i, lbl in enumerate(labels): - cv = cache.get(lbl, None) - if cv is None: - cv = self.chord(lbl) - cache[lbl] = cv - crds[i] = cv - - return crds - - def label_error_modify(self, label): - if label == 'Emin/4': label = 'E:min/4' - elif label == 'A7/3': label = 'A:7/3' - elif label == 'Bb7/3': label = 'Bb:7/3' - elif label == 'Bb7/5': label = 'Bb:7/5' - elif label.find(':') == -1: - if label.find('min') != -1: - label = label[:label.find('min')] + ':' + label[label.find('min'):] - return label - - def chord(self, label): - """ - Transform a chord label into the internal numeric represenation of - (root, bass, intervals array). - - Parameters - ---------- - label : str - Chord label. - - Returns - ------- - chord : tuple - Numeric representation of the chord: (root, bass, intervals array). - - """ - - - is_major = False - - if label == 'N': - return NO_CHORD - if label == 'X': - return UNKNOWN_CHORD - - label = self.label_error_modify(label) - - c_idx = label.find(':') - s_idx = label.find('/') - - if c_idx == -1: - quality_str = 'maj' - if s_idx == -1: - root_str = label - bass_str = '' - else: - root_str = label[:s_idx] - bass_str = label[s_idx + 1:] - else: - root_str = label[:c_idx] - if s_idx == -1: - quality_str = label[c_idx + 1:] - bass_str = '' - else: - quality_str = label[c_idx + 1:s_idx] - bass_str = label[s_idx + 1:] - - root = self.pitch(root_str) - bass = self.interval(bass_str) if bass_str else 0 - ivs = self.chord_intervals(quality_str) - ivs[bass] = 1 - - if 'min' in quality_str: - is_major = False - else: - is_major = True - - - return root, bass, ivs, is_major - - _l = [0, 1, 1, 0, 1, 1, 1] - _chroma_id = (np.arange(len(_l) * 2) + 1) + np.array(_l + _l).cumsum() - 1 - - def modify(self, base_pitch, modifier): - """ - Modify a pitch class in integer representation by a given modifier string. - - A modifier string can be any sequence of 'b' (one semitone down) - and '#' (one semitone up). - - Parameters - ---------- - base_pitch : int - Pitch class as integer. - modifier : str - String of modifiers ('b' or '#'). - - Returns - ------- - modified_pitch : int - Modified root note. - - """ - for m in modifier: - if m == 'b': - base_pitch -= 1 - elif m == '#': - base_pitch += 1 - else: - raise ValueError('Unknown modifier: {}'.format(m)) - return base_pitch - - def pitch(self, pitch_str): - """ - Convert a string representation of a pitch class (consisting of root - note and modifiers) to an integer representation. - - Parameters - ---------- - pitch_str : str - String representation of a pitch class. - - Returns - ------- - pitch : int - Integer representation of a pitch class. - - """ - return self.modify(self._chroma_id[(ord(pitch_str[0]) - ord('C')) % 7], - pitch_str[1:]) % 12 - - def interval(self, interval_str): - """ - Convert a string representation of a musical interval into a pitch class - (e.g. 
a minor seventh 'b7' into 10, because it is 10 semitones above its - base note). - - Parameters - ---------- - interval_str : str - Musical interval. - - Returns - ------- - pitch_class : int - Number of semitones to base note of interval. - - """ - for i, c in enumerate(interval_str): - if c.isdigit(): - return self.modify(self._chroma_id[int(interval_str[i:]) - 1], - interval_str[:i]) % 12 - - def interval_list(self, intervals_str, given_pitch_classes=None): - """ - Convert a list of intervals given as string to a binary pitch class - representation. For example, 'b3, 5' would become - [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]. - - Parameters - ---------- - intervals_str : str - List of intervals as comma-separated string (e.g. 'b3, 5'). - given_pitch_classes : None or numpy array - If None, start with empty pitch class array, if numpy array of length - 12, this array will be modified. - - Returns - ------- - pitch_classes : numpy array - Binary pitch class representation of intervals. - - """ - if given_pitch_classes is None: - given_pitch_classes = np.zeros(12, dtype=np.int_) - for int_def in intervals_str[1:-1].split(','): - int_def = int_def.strip() - if int_def[0] == '*': - given_pitch_classes[self.interval(int_def[1:])] = 0 - else: - given_pitch_classes[self.interval(int_def)] = 1 - return given_pitch_classes - - # mapping of shorthand interval notations to the actual interval representation - - def chord_intervals(self, quality_str): - """ - Convert a chord quality string to a pitch class representation. For - example, 'maj' becomes [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]. - - Parameters - ---------- - quality_str : str - String defining the chord quality. - - Returns - ------- - pitch_classes : numpy array - Binary pitch class representation of chord quality. - - """ - list_idx = quality_str.find('(') - if list_idx == -1: - return self._shorthands[quality_str].copy() - if list_idx != 0: - ivs = self._shorthands[quality_str[:list_idx]].copy() - else: - ivs = np.zeros(12, dtype=np.int_) - - - return self.interval_list(quality_str[list_idx:], ivs) - - def load_chords(self, filename): - """ - Load chords from a text file. - - The chord must follow the syntax defined in [1]_. - - Parameters - ---------- - filename : str - File containing chord segments. - - Returns - ------- - crds : numpy structured array - Structured array with columns "start", "end", and "chord", - containing the beginning, end, and chord definition of chord - segments. - - References - ---------- - .. [1] Christopher Harte, "Towards Automatic Extraction of Harmony - Information from Music Signals." Dissertation, - Department for Electronic Engineering, Queen Mary University of - London, 2010. - - """ - start, end, chord_labels = [], [], [] - with open(filename, 'r') as f: - for line in f: - if line: - - splits = line.split() - if len(splits) == 3: - - s = splits[0] - e = splits[1] - l = splits[2] - - start.append(float(s)) - end.append(float(e)) - chord_labels.append(l) - - crds = np.zeros(len(start), dtype=CHORD_ANN_DTYPE) - crds['start'] = start - crds['end'] = end - crds['chord'] = self.chords(chord_labels) - - return crds - - def reduce_to_triads(self, chords, keep_bass=False): - """ - Reduce chords to triads. - - The function follows the reduction rules implemented in [1]_. If a chord - chord does not contain a third, major second or fourth, it is reduced to - a power chord. If it does not contain neither a third nor a fifth, it is - reduced to a single note "chord". 
- - Parameters - ---------- - chords : numpy structured array - Chords to be reduced. - keep_bass : bool - Indicates whether to keep the bass note or set it to 0. - - Returns - ------- - reduced_chords : numpy structured array - Chords reduced to triads. - - References - ---------- - .. [1] Johan Pauwels and Geoffroy Peeters. - "Evaluating Automatically Estimated Chord Sequences." - In Proceedings of ICASSP 2013, Vancouver, Canada, 2013. - - """ - unison = chords['intervals'][:, 0].astype(bool) - maj_sec = chords['intervals'][:, 2].astype(bool) - min_third = chords['intervals'][:, 3].astype(bool) - maj_third = chords['intervals'][:, 4].astype(bool) - perf_fourth = chords['intervals'][:, 5].astype(bool) - dim_fifth = chords['intervals'][:, 6].astype(bool) - perf_fifth = chords['intervals'][:, 7].astype(bool) - aug_fifth = chords['intervals'][:, 8].astype(bool) - no_chord = (chords['intervals'] == NO_CHORD[-1]).all(axis=1) - - reduced_chords = chords.copy() - ivs = reduced_chords['intervals'] - - ivs[~no_chord] = self.interval_list('(1)') - ivs[unison & perf_fifth] = self.interval_list('(1,5)') - ivs[~perf_fourth & maj_sec] = self._shorthands['sus2'] - ivs[perf_fourth & ~maj_sec] = self._shorthands['sus4'] - - ivs[min_third] = self._shorthands['min'] - ivs[min_third & aug_fifth & ~perf_fifth] = self.interval_list('(1,b3,#5)') - ivs[min_third & dim_fifth & ~perf_fifth] = self._shorthands['dim'] - - ivs[maj_third] = self._shorthands['maj'] - ivs[maj_third & dim_fifth & ~perf_fifth] = self.interval_list('(1,3,b5)') - ivs[maj_third & aug_fifth & ~perf_fifth] = self._shorthands['aug'] - - if not keep_bass: - reduced_chords['bass'] = 0 - else: - # remove bass notes if they are not part of the intervals anymore - reduced_chords['bass'] *= ivs[range(len(reduced_chords)), - reduced_chords['bass']] - # keep -1 in bass for no chords - reduced_chords['bass'][no_chord] = -1 - - return reduced_chords - - def convert_to_id(self, root, is_major): - if root == -1: - return 24 - else: - if is_major: - return root * 2 - else: - return root * 2 + 1 - - def get_converted_chord(self, filename): - loaded_chord = self.load_chords(filename) - triads = self.reduce_to_triads(loaded_chord['chord']) - - df = self.assign_chord_id(triads) - df['start'] = loaded_chord['start'] - df['end'] = loaded_chord['end'] - - return df - - def assign_chord_id(self, entry): - # maj, min chord only - # if you want to add other chord, change this part and get_converted_chord(reduce_to_triads) - df = pd.DataFrame(data=entry[['root', 'is_major']]) - df['chord_id'] = df.apply(lambda row: self.convert_to_id(row['root'], row['is_major']), axis=1) - return df - - def convert_to_id_voca(self, root, quality): - if root == -1: - return 169 - else: - if quality == 'min': - return root * 14 - elif quality == 'maj': - return root * 14 + 1 - elif quality == 'dim': - return root * 14 + 2 - elif quality == 'aug': - return root * 14 + 3 - elif quality == 'min6': - return root * 14 + 4 - elif quality == 'maj6': - return root * 14 + 5 - elif quality == 'min7': - return root * 14 + 6 - elif quality == 'minmaj7': - return root * 14 + 7 - elif quality == 'maj7': - return root * 14 + 8 - elif quality == '7': - return root * 14 + 9 - elif quality == 'dim7': - return root * 14 + 10 - elif quality == 'hdim7': - return root * 14 + 11 - elif quality == 'sus2': - return root * 14 + 12 - elif quality == 'sus4': - return root * 14 + 13 - else: - return 168 - - - def lab_file_error_modify(self, ref_labels): - for i in range(len(ref_labels)): - if ref_labels[i][-2:] == 
':4': - ref_labels[i] = ref_labels[i].replace(':4', ':sus4') - elif ref_labels[i][-2:] == ':6': - ref_labels[i] = ref_labels[i].replace(':6', ':maj6') - elif ref_labels[i][-4:] == ':6/2': - ref_labels[i] = ref_labels[i].replace(':6/2', ':maj6/2') - elif ref_labels[i] == 'Emin/4': - ref_labels[i] = 'E:min/4' - elif ref_labels[i] == 'A7/3': - ref_labels[i] = 'A:7/3' - elif ref_labels[i] == 'Bb7/3': - ref_labels[i] = 'Bb:7/3' - elif ref_labels[i] == 'Bb7/5': - ref_labels[i] = 'Bb:7/5' - elif ref_labels[i].find(':') == -1: - if ref_labels[i].find('min') != -1: - ref_labels[i] = ref_labels[i][:ref_labels[i].find('min')] + ':' + ref_labels[i][ref_labels[i].find('min'):] - return ref_labels - diff --git a/audiocraft/audiocraft/data/info_audio_dataset.py b/audiocraft/audiocraft/data/info_audio_dataset.py deleted file mode 100644 index 47ab4b1594faf1e9f1ce962fb980d80295b1f079..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/data/info_audio_dataset.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Base classes for the datasets that also provide non-audio metadata, -e.g. description, text transcription etc. -""" -from dataclasses import dataclass -import logging -import math -import re -import typing as tp - -import torch - -from .audio_dataset import AudioDataset, AudioMeta -from ..environment import AudioCraftEnvironment -from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes - - -logger = logging.getLogger(__name__) - - -def _clusterify_meta(meta: AudioMeta) -> AudioMeta: - """Monkey-patch meta to match cluster specificities.""" - meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path) - if meta.info_path is not None: - meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path) - return meta - - -def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]: - """Monkey-patch all meta to match cluster specificities.""" - return [_clusterify_meta(m) for m in meta] - - -@dataclass -class AudioInfo(SegmentWithAttributes): - """Dummy SegmentInfo with empty attributes. - - The InfoAudioDataset is expected to return metadata that inherits - from SegmentWithAttributes class and can return conditioning attributes. - - This basically guarantees all datasets will be compatible with current - solver that contain conditioners requiring this. - """ - audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM. - - def to_condition_attributes(self) -> ConditioningAttributes: - return ConditioningAttributes() - - -class InfoAudioDataset(AudioDataset): - """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform. - - See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments. 
- """ - def __init__(self, meta: tp.List[AudioMeta], **kwargs): - super().__init__(clusterify_all_meta(meta), **kwargs) - - def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]: - if not self.return_info: - wav = super().__getitem__(index) - assert isinstance(wav, torch.Tensor) - return wav - wav, meta = super().__getitem__(index) - return wav, AudioInfo(**meta.to_dict()) - - -def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]: - """Preprocess a single keyword or possible a list of keywords.""" - if isinstance(value, list): - return get_keyword_list(value) - else: - return get_keyword(value) - - -def get_string(value: tp.Optional[str]) -> tp.Optional[str]: - """Preprocess a single keyword.""" - if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None': - return None - else: - return value.strip() - - -def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]: - """Preprocess a single keyword.""" - if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None': - return None - else: - return value.strip().lower() - - -def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]: - """Preprocess a list of keywords.""" - if isinstance(values, str): - values = [v.strip() for v in re.split(r'[,\s]', values)] - elif isinstance(values, float) and math.isnan(values): - values = [] - if not isinstance(values, list): - logger.debug(f"Unexpected keyword list {values}") - values = [str(values)] - - kws = [get_keyword(v) for v in values] - kw_list = [k for k in kws if k is not None] - if len(kw_list) == 0: - return None - else: - return kw_list diff --git a/audiocraft/audiocraft/data/music_dataset.py b/audiocraft/audiocraft/data/music_dataset.py deleted file mode 100644 index 0d31516ddc3efa7669a946500932991be892a6e2..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/data/music_dataset.py +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Dataset of music tracks with rich metadata. -""" -from dataclasses import dataclass, field, fields, replace -import gzip -import json -import logging -from pathlib import Path -import random -import typing as tp -import pretty_midi -import numpy as np - -import torch -import torch.nn.functional as F -from .btc_chords import Chords - -from .info_audio_dataset import ( - InfoAudioDataset, - AudioInfo, - get_keyword_list, - get_keyword, - get_string -) -from ..modules.conditioners import ( - ConditioningAttributes, - JointEmbedCondition, - WavCondition, - ChordCondition, - BeatCondition -) -from ..utils.utils import warn_once - - -logger = logging.getLogger(__name__) - -CHORDS = Chords() - - -@dataclass -class MusicInfo(AudioInfo): - """Segment info augmented with music metadata. 
- """ - # music-specific metadata - title: tp.Optional[str] = None - artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits - key: tp.Optional[str] = None - bpm: tp.Optional[float] = None - genre: tp.Optional[str] = None - moods: tp.Optional[list] = None - keywords: tp.Optional[list] = None - description: tp.Optional[str] = None - name: tp.Optional[str] = None - instrument: tp.Optional[str] = None - chord: tp.Optional[ChordCondition] = None - beat: tp.Optional[BeatCondition] = None - # original wav accompanying the metadata - self_wav: tp.Optional[WavCondition] = None - # dict mapping attributes names to tuple of wav, text and metadata - joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict) - - @property - def has_music_meta(self) -> bool: - return self.name is not None - - def to_condition_attributes(self) -> ConditioningAttributes: - out = ConditioningAttributes() - for _field in fields(self): - key, value = _field.name, getattr(self, _field.name) - if key == 'self_wav': - out.wav[key] = value - elif key == 'chord': - out.chord[key] = value - elif key == 'beat': - out.beat[key] = value - elif key == 'joint_embed': - for embed_attribute, embed_cond in value.items(): - out.joint_embed[embed_attribute] = embed_cond - else: - if isinstance(value, list): - value = ' '.join(value) - out.text[key] = value - return out - - @staticmethod - def attribute_getter(attribute): - if attribute == 'bpm': - preprocess_func = get_bpm - elif attribute == 'key': - preprocess_func = get_musical_key - elif attribute in ['moods', 'keywords']: - preprocess_func = get_keyword_list - elif attribute in ['genre', 'name', 'instrument']: - preprocess_func = get_keyword - elif attribute in ['title', 'artist', 'description']: - preprocess_func = get_string - else: - preprocess_func = None - return preprocess_func - - @classmethod - def from_dict(cls, dictionary: dict, fields_required: bool = False): - _dictionary: tp.Dict[str, tp.Any] = {} - - # allow a subset of attributes to not be loaded from the dictionary - # these attributes may be populated later - post_init_attributes = ['self_wav', 'chord', 'beat', 'joint_embed'] - optional_fields = ['keywords'] - - for _field in fields(cls): - if _field.name in post_init_attributes: - continue - elif _field.name not in dictionary: - if fields_required and _field.name not in optional_fields: - raise KeyError(f"Unexpected missing key: {_field.name}") - else: - preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name) - value = dictionary[_field.name] - if preprocess_func: - value = preprocess_func(value) - _dictionary[_field.name] = value - return cls(**_dictionary) - - -def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0., - drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo: - """Augment MusicInfo description with additional metadata fields and potential dropout. - Additional textual attributes are added given probability 'merge_text_conditions_p' and - the original textual description is dropped from the augmented description given probability drop_desc_p. - - Args: - music_info (MusicInfo): The music metadata to augment. - merge_text_p (float): Probability of merging additional metadata to the description. - If provided value is 0, then no merging is performed. - drop_desc_p (float): Probability of dropping the original description on text merge. - if provided value is 0, then no drop out is performed. 
- drop_other_p (float): Probability of dropping the other fields used for text augmentation. - Returns: - MusicInfo: The MusicInfo with augmented textual description. - """ - def is_valid_field(field_name: str, field_value: tp.Any) -> bool: - valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords'] - valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list)) - keep_field = random.uniform(0, 1) < drop_other_p - return valid_field_name and valid_field_value and keep_field - - def process_value(v: tp.Any) -> str: - if isinstance(v, (int, float, str)): - return str(v) - if isinstance(v, list): - return ", ".join(v) - else: - raise ValueError(f"Unknown type for text value! ({type(v), v})") - - description = music_info.description - - metadata_text = "" - # metadata_text = "rock style music, consistent rhythm, catchy song." - if random.uniform(0, 1) < merge_text_p: - meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}' - for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))] - random.shuffle(meta_pairs) - metadata_text = ". ".join(meta_pairs) - description = description if not random.uniform(0, 1) < drop_desc_p else None - logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}") - - if description is None: - description = metadata_text if len(metadata_text) > 1 else None - else: - description = ". ".join([description.rstrip('.'), metadata_text]) - description = description.strip() if description else None - - music_info = replace(music_info) - music_info.description = description - return music_info - - -class Paraphraser: - def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.): - self.paraphrase_p = paraphrase_p - open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open - with open_fn(paraphrase_source, 'rb') as f: # type: ignore - self.paraphrase_source = json.loads(f.read()) - logger.info(f"loaded paraphrasing source from: {paraphrase_source}") - - def sample_paraphrase(self, audio_path: str, description: str): - if random.random() >= self.paraphrase_p: - return description - info_path = Path(audio_path).with_suffix('.json') - if info_path not in self.paraphrase_source: - warn_once(logger, f"{info_path} not in paraphrase source!") - return description - new_desc = random.choice(self.paraphrase_source[info_path]) - logger.debug(f"{description} -> {new_desc}") - return new_desc - - -class MusicDataset(InfoAudioDataset): - """Music dataset is an AudioDataset with music-related metadata. - - Args: - info_fields_required (bool): Whether to enforce having required fields. - merge_text_p (float): Probability of merging additional metadata to the description. - drop_desc_p (float): Probability of dropping the original description on text merge. - drop_other_p (float): Probability of dropping the other fields used for text augmentation. - joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned. - paraphrase_source (str, optional): Path to the .json or .json.gz file containing the - paraphrases for the description. The json should be a dict with keys are the - original info path (e.g. track_path.json) and each value is a list of possible - paraphrased. - paraphrase_p (float): probability of taking a paraphrase. - - See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments. 
- """ - def __init__(self, *args, info_fields_required: bool = True, - merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0., - joint_embed_attributes: tp.List[str] = [], - paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0, - **kwargs): - kwargs['return_info'] = True # We require the info for each song of the dataset. - super().__init__(*args, **kwargs) - self.info_fields_required = info_fields_required - self.merge_text_p = merge_text_p - self.drop_desc_p = drop_desc_p - self.drop_other_p = drop_other_p - self.joint_embed_attributes = joint_embed_attributes - self.paraphraser = None - self.downsample_rate = 640 - self.sr = 32000 - if paraphrase_source is not None: - self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p) - - def __getitem__(self, index): - wav, info = super().__getitem__(index) # wav_seg and seg_info - info_data = info.to_dict() - - # unpack info - target_sr = self.sr - n_frames_wave = info.n_frames - n_frames_feat = int(info.n_frames // self.downsample_rate) - - music_info_path = str(info.meta.path).replace('no_vocal.wav', 'tags.json') - chord_path = str(info.meta.path).replace('no_vocal.wav', 'chord.lab') - beats_path = str(info.meta.path).replace('no_vocal.wav', 'beats.npy') - - if all([ - not Path(music_info_path).exists(), - not Path(beats_path).exists(), - not Path(chord_path).exists(), - ]): - raise FileNotFoundError - - ### music info - with open(music_info_path, 'r') as json_file: - music_data = json.load(json_file) - music_data.update(info_data) - music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required) - if self.paraphraser is not None: - music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description) - if self.merge_text_p: - music_info = augment_music_info_description( - music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p) - - - ### load features to tensors ### - feat_hz = target_sr/self.downsample_rate - ## beat&bar: 2 x T - feat_beats = np.zeros((2, n_frames_feat)) - - beats_np = np.load(beats_path) - beat_time = beats_np[:, 0] - bar_time = beats_np[np.where(beats_np[:, 1] == 1)[0], 0] - beat_frame = [ - int((t-info.seek_time)*feat_hz) for t in beat_time - if (t >= info.seek_time and t < info.seek_time + self.segment_duration)] - bar_frame =[ - int((t-info.seek_time)*feat_hz) for t in bar_time - if (t >= info.seek_time and t < info.seek_time + self.segment_duration)] - feat_beats[0, beat_frame] = 1 - feat_beats[1, bar_frame] = 1 - kernel = np.array([0.05, 0.1, 0.3, 0.9, 0.3, 0.1, 0.05]) - feat_beats[0] = np.convolve(feat_beats[0] , kernel, 'same') # apply soft kernel - beat_events = feat_beats[0] + feat_beats[1] - beat_events = torch.tensor(beat_events).unsqueeze(0) # [T] -> [1, T] - - music_info.beat = BeatCondition(beat=beat_events[None], length=torch.tensor([n_frames_feat]), - bpm=[music_data["bpm"]], path=[music_info_path], seek_frame=[info.seek_time*target_sr//self.downsample_rate]) - - ## chord: 12 x T - feat_chord = np.zeros((12, n_frames_feat)) # root| ivs - with open(chord_path, 'r') as f: - for line in f.readlines(): - splits = line.split() - if len(splits) == 3: - st_sec, ed_sec, ctag = splits - st_sec = float(st_sec) - info.seek_time - ed_sec = float(ed_sec) - info.seek_time - st_frame = int(st_sec*feat_hz) - ed_frame = int(ed_sec*feat_hz) - - # 12 chorma - mhot = CHORDS.chord(ctag) - final_vec = np.roll(mhot[2], mhot[0]) - - final_vec = final_vec[..., None] - feat_chord[:, st_frame:ed_frame] = final_vec - feat_chord = 
torch.from_numpy(feat_chord) - - music_info.chord = ChordCondition( - chord=feat_chord[None], length=torch.tensor([n_frames_feat]), - bpm=[music_data["bpm"]], path=[chord_path], seek_frame=[info.seek_time*self.sr//self.downsample_rate]) - - music_info.self_wav = WavCondition( - wav=wav[None], length=torch.tensor([info.n_frames]), - sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time]) - - for att in self.joint_embed_attributes: - att_value = getattr(music_info, att) - joint_embed_cond = JointEmbedCondition( - wav[None], [att_value], torch.tensor([info.n_frames]), - sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time]) - music_info.joint_embed[att] = joint_embed_cond - - return wav, music_info - - -def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]: - """Preprocess key keywords, discarding them if there are multiple key defined.""" - if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None': - return None - elif ',' in value: - # For now, we discard when multiple keys are defined separated with comas - return None - else: - return value.strip().lower() - - -def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]: - """Preprocess to a float.""" - if value is None: - return None - try: - return float(value) - except ValueError: - return None diff --git a/audiocraft/audiocraft/data/sound_dataset.py b/audiocraft/audiocraft/data/sound_dataset.py deleted file mode 100644 index 8b88cbe8016b4bd28c2de749177c9af29f7755fc..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/data/sound_dataset.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Dataset of audio with a simple description. -""" - -from dataclasses import dataclass, fields, replace -import json -from pathlib import Path -import random -import typing as tp - -import numpy as np -import torch - -from .info_audio_dataset import ( - InfoAudioDataset, - get_keyword_or_keyword_list -) -from ..modules.conditioners import ( - ConditioningAttributes, - SegmentWithAttributes, - WavCondition, -) - - -EPS = torch.finfo(torch.float32).eps -TARGET_LEVEL_LOWER = -35 -TARGET_LEVEL_UPPER = -15 - - -@dataclass -class SoundInfo(SegmentWithAttributes): - """Segment info augmented with Sound metadata. 
- """ - description: tp.Optional[str] = None - self_wav: tp.Optional[torch.Tensor] = None - - @property - def has_sound_meta(self) -> bool: - return self.description is not None - - def to_condition_attributes(self) -> ConditioningAttributes: - out = ConditioningAttributes() - - for _field in fields(self): - key, value = _field.name, getattr(self, _field.name) - if key == 'self_wav': - out.wav[key] = value - else: - out.text[key] = value - return out - - @staticmethod - def attribute_getter(attribute): - if attribute == 'description': - preprocess_func = get_keyword_or_keyword_list - else: - preprocess_func = None - return preprocess_func - - @classmethod - def from_dict(cls, dictionary: dict, fields_required: bool = False): - _dictionary: tp.Dict[str, tp.Any] = {} - - # allow a subset of attributes to not be loaded from the dictionary - # these attributes may be populated later - post_init_attributes = ['self_wav'] - - for _field in fields(cls): - if _field.name in post_init_attributes: - continue - elif _field.name not in dictionary: - if fields_required: - raise KeyError(f"Unexpected missing key: {_field.name}") - else: - preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name) - value = dictionary[_field.name] - if preprocess_func: - value = preprocess_func(value) - _dictionary[_field.name] = value - return cls(**_dictionary) - - -class SoundDataset(InfoAudioDataset): - """Sound audio dataset: Audio dataset with environmental sound-specific metadata. - - Args: - info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata. - external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset. - The metadata files contained in this folder are expected to match the stem of the audio file with - a json extension. - aug_p (float): Probability of performing audio mixing augmentation on the batch. - mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation. - mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation. - mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation. - mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation. - kwargs: Additional arguments for AudioDataset. - - See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments. - """ - def __init__( - self, - *args, - info_fields_required: bool = True, - external_metadata_source: tp.Optional[str] = None, - aug_p: float = 0., - mix_p: float = 0., - mix_snr_low: int = -5, - mix_snr_high: int = 5, - mix_min_overlap: float = 0.5, - **kwargs - ): - kwargs['return_info'] = True # We require the info for each song of the dataset. - super().__init__(*args, **kwargs) - self.info_fields_required = info_fields_required - self.external_metadata_source = external_metadata_source - self.aug_p = aug_p - self.mix_p = mix_p - if self.aug_p > 0: - assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0" - assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio" - self.mix_snr_low = mix_snr_low - self.mix_snr_high = mix_snr_high - self.mix_min_overlap = mix_min_overlap - - def _get_info_path(self, path: tp.Union[str, Path]) -> Path: - """Get path of JSON with metadata (description, etc.). - If there exists a JSON with the same name as 'path.name', then it will be used. 
- Else, such JSON will be searched for in an external json source folder if it exists. - """ - info_path = Path(path).with_suffix('.json') - if Path(info_path).exists(): - return info_path - elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists(): - return Path(self.external_metadata_source) / info_path.name - else: - raise Exception(f"Unable to find a metadata JSON for path: {path}") - - def __getitem__(self, index): - wav, info = super().__getitem__(index) - info_data = info.to_dict() - info_path = self._get_info_path(info.meta.path) - if Path(info_path).exists(): - with open(info_path, 'r') as json_file: - sound_data = json.load(json_file) - sound_data.update(info_data) - sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required) - # if there are multiple descriptions, sample one randomly - if isinstance(sound_info.description, list): - sound_info.description = random.choice(sound_info.description) - else: - sound_info = SoundInfo.from_dict(info_data, fields_required=False) - - sound_info.self_wav = WavCondition( - wav=wav[None], length=torch.tensor([info.n_frames]), - sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time]) - - return wav, sound_info - - def collater(self, samples): - # when training, audio mixing is performed in the collate function - wav, sound_info = super().collater(samples) # SoundDataset always returns infos - if self.aug_p > 0: - wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p, - snr_low=self.mix_snr_low, snr_high=self.mix_snr_high, - min_overlap=self.mix_min_overlap) - return wav, sound_info - - -def rms_f(x: torch.Tensor) -> torch.Tensor: - return (x ** 2).mean(1).pow(0.5) - - -def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor: - """Normalize the signal to the target level.""" - rms = rms_f(audio) - scalar = 10 ** (target_level / 20) / (rms + EPS) - audio = audio * scalar.unsqueeze(1) - return audio - - -def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor: - return (abs(audio) > clipping_threshold).any(1) - - -def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor: - start = random.randint(0, int(src.shape[1] * (1 - min_overlap))) - remainder = src.shape[1] - start - if dst.shape[1] > remainder: - src[:, start:] = src[:, start:] + dst[:, :remainder] - else: - src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst - return src - - -def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float, - target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor: - """Function to mix clean speech and noise at various SNR levels. - - Args: - clean (torch.Tensor): Clean audio source to mix, of shape [B, T]. - noise (torch.Tensor): Noise audio source to mix, of shape [B, T]. - snr (int): SNR level when mixing. - min_overlap (float): Minimum overlap between the two mixed sources. - target_level (int): Gain level in dB. - clipping_threshold (float): Threshold for clipping the audio. - Returns: - torch.Tensor: The mixed audio, of shape [B, T]. 
- """ - if clean.shape[1] > noise.shape[1]: - noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1])) - else: - noise = noise[:, :clean.shape[1]] - - # normalizing to -25 dB FS - clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS) - clean = normalize(clean, target_level) - rmsclean = rms_f(clean) - - noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS) - noise = normalize(noise, target_level) - rmsnoise = rms_f(noise) - - # set the noise level for a given SNR - noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1) - noisenewlevel = noise * noisescalar - - # mix noise and clean speech - noisyspeech = mix_pair(clean, noisenewlevel, min_overlap) - - # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value - # there is a chance of clipping that might happen with very less probability, which is not a major issue. - noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER) - rmsnoisy = rms_f(noisyspeech) - scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1) - noisyspeech = noisyspeech * scalarnoisy - clean = clean * scalarnoisy - noisenewlevel = noisenewlevel * scalarnoisy - - # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly - clipped = is_clipped(noisyspeech) - if clipped.any(): - noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS) - noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel - - return noisyspeech - - -def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float): - if snr_low == snr_high: - snr = snr_low - else: - snr = np.random.randint(snr_low, snr_high) - mix = snr_mixer(src, dst, snr, min_overlap) - return mix - - -def mix_text(src_text: str, dst_text: str): - """Mix text from different sources by concatenating them.""" - if src_text == dst_text: - return src_text - return src_text + " " + dst_text - - -def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float, - snr_low: int, snr_high: int, min_overlap: float): - """Mix samples within a batch, summing the waveforms and concatenating the text infos. - - Args: - wavs (torch.Tensor): Audio tensors of shape [B, C, T]. - infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio. - aug_p (float): Augmentation probability. - mix_p (float): Proportion of items in the batch to mix (and merge) together. - snr_low (int): Lowerbound for sampling SNR. - snr_high (int): Upperbound for sampling SNR. - min_overlap (float): Minimum overlap between mixed samples. - Returns: - tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs - and mixed SoundInfo for the given batch. 
- """ - # no mixing to perform within the batch - if mix_p == 0: - return wavs, infos - - if random.uniform(0, 1) < aug_p: - # perform all augmentations on waveforms as [B, T] - # randomly picking pairs of audio to mix - assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}" - wavs = wavs.mean(dim=1, keepdim=False) - B, T = wavs.shape - k = int(mix_p * B) - mixed_sources_idx = torch.randperm(B)[:k] - mixed_targets_idx = torch.randperm(B)[:k] - aug_wavs = snr_mix( - wavs[mixed_sources_idx], - wavs[mixed_targets_idx], - snr_low, - snr_high, - min_overlap, - ) - # mixing textual descriptions in metadata - descriptions = [info.description for info in infos] - aug_infos = [] - for i, j in zip(mixed_sources_idx, mixed_targets_idx): - text = mix_text(descriptions[i], descriptions[j]) - m = replace(infos[i]) - m.description = text - aug_infos.append(m) - - # back to [B, C, T] - aug_wavs = aug_wavs.unsqueeze(1) - assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch." - assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}" - assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch" - - return aug_wavs, aug_infos # [B, C, T] - else: - # randomly pick samples in the batch to match - # the batch size when performing audio mixing - B, C, T = wavs.shape - k = int(mix_p * B) - wav_idx = torch.randperm(B)[:k] - wavs = wavs[wav_idx] - infos = [infos[i] for i in wav_idx] - assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch" - - return wavs, infos # [B, C, T] diff --git a/audiocraft/audiocraft/data/zip.py b/audiocraft/audiocraft/data/zip.py deleted file mode 100644 index f0b17849d36991e7def35a14d3d518b9d867ce36..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/data/zip.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Utility for reading some info from inside a zip file. -""" - -import typing -import zipfile - -from dataclasses import dataclass -from functools import lru_cache -from typing_extensions import Literal - - -DEFAULT_SIZE = 32 -MODE = Literal['r', 'w', 'x', 'a'] - - -@dataclass(order=True) -class PathInZip: - """Hold a path of file within a zip file. - - Args: - path (str): The convention is :. - Let's assume there is a zip file /some/location/foo.zip - and inside of it is a json file located at /data/file1.json, - Then we expect path = "/some/location/foo.zip:/data/file1.json". - """ - - INFO_PATH_SEP = ':' - zip_path: str - file_path: str - - def __init__(self, path: str) -> None: - split_path = path.split(self.INFO_PATH_SEP) - assert len(split_path) == 2 - self.zip_path, self.file_path = split_path - - @classmethod - def from_paths(cls, zip_path: str, file_path: str): - return cls(zip_path + cls.INFO_PATH_SEP + file_path) - - def __str__(self) -> str: - return self.zip_path + self.INFO_PATH_SEP + self.file_path - - -def _open_zip(path: str, mode: MODE = 'r'): - return zipfile.ZipFile(path, mode) - - -_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip) - - -def set_zip_cache_size(max_size: int): - """Sets the maximal LRU caching for zip file opening. - - Args: - max_size (int): the maximal LRU cache. 
- """ - global _cached_open_zip - _cached_open_zip = lru_cache(max_size)(_open_zip) - - -def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO: - """Opens a file stored inside a zip and returns a file-like object. - - Args: - path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of. - mode (str): The mode in which to open the file with. - Returns: - A file-like object for PathInZip. - """ - zf = _cached_open_zip(path_in_zip.zip_path) - return zf.open(path_in_zip.file_path) diff --git a/audiocraft/audiocraft/environment.py b/audiocraft/audiocraft/environment.py deleted file mode 100644 index adc7819305758bb50a9984928bfa7f13eabef5f5..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/environment.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Provides cluster and tools configuration across clusters (slurm, dora, utilities). -""" - -import logging -import os -from pathlib import Path -import re -import typing as tp - -import omegaconf - -from .utils.cluster import _guess_cluster_type - - -logger = logging.getLogger(__name__) - - -class AudioCraftEnvironment: - """Environment configuration for teams and clusters. - - AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment - or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment - provides pointers to a reference folder resolved automatically across clusters that is shared across team members, - allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically - map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters. - - The cluster type is identified automatically and base configuration file is read from config/teams.yaml. - Use the following environment variables to specify the cluster, team or configuration: - - AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type - cannot be inferred automatically. - AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration. - If not set, configuration is read from config/teams.yaml. - AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team. - Cluster configuration are shared across teams to match compute allocation, - specify your cluster configuration in the configuration file under a key mapping - your team name. 
- """ - _instance = None - DEFAULT_TEAM = "default" - - def __init__(self) -> None: - """Loads configuration.""" - self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM) - cluster_type = _guess_cluster_type() - cluster = os.getenv( - "AUDIOCRAFT_CLUSTER", cluster_type.value - ) - logger.info("Detecting cluster type %s", cluster_type) - - self.cluster: str = cluster - - config_path = os.getenv( - "AUDIOCRAFT_CONFIG", - Path(__file__) - .parent.parent.joinpath("config/teams", self.team) - .with_suffix(".yaml"), - ) - self.config = omegaconf.OmegaConf.load(config_path) - self._dataset_mappers = [] - cluster_config = self._get_cluster_config() - if "dataset_mappers" in cluster_config: - for pattern, repl in cluster_config["dataset_mappers"].items(): - regex = re.compile(pattern) - self._dataset_mappers.append((regex, repl)) - - def _get_cluster_config(self) -> omegaconf.DictConfig: - assert isinstance(self.config, omegaconf.DictConfig) - return self.config[self.cluster] - - @classmethod - def instance(cls): - if cls._instance is None: - cls._instance = cls() - return cls._instance - - @classmethod - def reset(cls): - """Clears the environment and forces a reload on next invocation.""" - cls._instance = None - - @classmethod - def get_team(cls) -> str: - """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var. - If not defined, defaults to "labs". - """ - return cls.instance().team - - @classmethod - def get_cluster(cls) -> str: - """Gets the detected cluster. - This value can be overridden by the AUDIOCRAFT_CLUSTER env var. - """ - return cls.instance().cluster - - @classmethod - def get_dora_dir(cls) -> Path: - """Gets the path to the dora directory for the current team and cluster. - Value is overridden by the AUDIOCRAFT_DORA_DIR env var. - """ - cluster_config = cls.instance()._get_cluster_config() - dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"]) - logger.warning(f"Dora directory: {dora_dir}") - return Path(dora_dir) - - @classmethod - def get_reference_dir(cls) -> Path: - """Gets the path to the reference directory for the current team and cluster. - Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var. - """ - cluster_config = cls.instance()._get_cluster_config() - return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"])) - - @classmethod - def get_slurm_exclude(cls) -> tp.Optional[str]: - """Get the list of nodes to exclude for that cluster.""" - cluster_config = cls.instance()._get_cluster_config() - return cluster_config.get("slurm_exclude") - - @classmethod - def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str: - """Gets the requested partitions for the current team and cluster as a comma-separated string. - - Args: - partition_types (list[str], optional): partition types to retrieve. Values must be - from ['global', 'team']. If not provided, the global partition is returned. - """ - if not partition_types: - partition_types = ["global"] - - cluster_config = cls.instance()._get_cluster_config() - partitions = [ - cluster_config["partitions"][partition_type] - for partition_type in partition_types - ] - return ",".join(partitions) - - @classmethod - def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path: - """Converts reference placeholder in path with configured reference dir to resolve paths. - - Args: - path (str or Path): Path to resolve. - Returns: - Path: Resolved path. 
- """ - path = str(path) - - if path.startswith("//reference"): - reference_dir = cls.get_reference_dir() - logger.warn(f"Reference directory: {reference_dir}") - assert ( - reference_dir.exists() and reference_dir.is_dir() - ), f"Reference directory does not exist: {reference_dir}." - path = re.sub("^//reference", str(reference_dir), path) - - return Path(path) - - @classmethod - def apply_dataset_mappers(cls, path: str) -> str: - """Applies dataset mapping regex rules as defined in the configuration. - If no rules are defined, the path is returned as-is. - """ - instance = cls.instance() - - for pattern, repl in instance._dataset_mappers: - path = pattern.sub(repl, path) - - return path diff --git a/audiocraft/audiocraft/grids/__init__.py b/audiocraft/audiocraft/grids/__init__.py deleted file mode 100644 index 70643517cd1a8b4e712eca90e23411ae89937795..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Dora Grids.""" diff --git a/audiocraft/audiocraft/grids/_base_explorers.py b/audiocraft/audiocraft/grids/_base_explorers.py deleted file mode 100644 index d3f26666aa596f7bd2e8695c4f00e7963e978ceb..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/_base_explorers.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC, abstractmethod -import time -import typing as tp -from dora import Explorer -import treetable as tt - - -def get_sheep_ping(sheep) -> tp.Optional[str]: - """Return the amount of time since the Sheep made some update - to its log. Returns a str using the relevant time unit.""" - ping = None - if sheep.log is not None and sheep.log.exists(): - delta = time.time() - sheep.log.stat().st_mtime - if delta > 3600 * 24: - ping = f'{delta / (3600 * 24):.1f}d' - elif delta > 3600: - ping = f'{delta / (3600):.1f}h' - elif delta > 60: - ping = f'{delta / 60:.1f}m' - else: - ping = f'{delta:.1f}s' - return ping - - -class BaseExplorer(ABC, Explorer): - """Base explorer for AudioCraft grids. - - All task specific solvers are expected to implement the `get_grid_metrics` - method to specify logic about metrics to display for a given task. - - If additional stages are used, the child explorer must define how to handle - these new stages in the `process_history` and `process_sheep` methods. - """ - def stages(self): - return ["train", "valid", "evaluate"] - - def get_grid_meta(self): - """Returns the list of Meta information to display for each XP/job. - """ - return [ - tt.leaf("index", align=">"), - tt.leaf("name", wrap=140), - tt.leaf("state"), - tt.leaf("sig", align=">"), - tt.leaf("sid", align="<"), - ] - - @abstractmethod - def get_grid_metrics(self): - """Return the metrics that should be displayed in the tracking table. - """ - ... - - def process_sheep(self, sheep, history): - train = { - "epoch": len(history), - } - parts = {"train": train} - for metrics in history: - for key, sub in metrics.items(): - part = parts.get(key, {}) - if 'duration' in sub: - # Convert to minutes for readability. - sub['duration'] = sub['duration'] / 60. 
- part.update(sub) - parts[key] = part - ping = get_sheep_ping(sheep) - if ping is not None: - for name in self.stages(): - if name not in parts: - parts[name] = {} - # Add the ping to each part for convenience. - parts[name]['ping'] = ping - return parts diff --git a/audiocraft/audiocraft/grids/audiogen/__init__.py b/audiocraft/audiocraft/grids/audiogen/__init__.py deleted file mode 100644 index 8a0a2688450ce120088b79c3314a2f267394dc11..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/audiogen/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""AudioGen grids.""" diff --git a/audiocraft/audiocraft/grids/audiogen/audiogen_base_16khz.py b/audiocraft/audiocraft/grids/audiogen/audiogen_base_16khz.py deleted file mode 100644 index 190cc1d0a1e316347e8ebbdfc8de7e2942c1b3d7..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/audiogen/audiogen_base_16khz.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from ..musicgen._explorers import LMExplorer -from ...environment import AudioCraftEnvironment - - -@LMExplorer -def explorer(launcher): - partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) - launcher.slurm_(gpus=64, partition=partitions) - launcher.bind_(solver='audiogen/audiogen_base_16khz') - # replace this by the desired environmental sound dataset - launcher.bind_(dset='internal/sounds_16khz') - - fsdp = {'autocast': False, 'fsdp.use': True} - medium = {'model/lm/model_scale': 'medium'} - - launcher.bind_(fsdp) - launcher(medium) diff --git a/audiocraft/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py b/audiocraft/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py deleted file mode 100644 index 12f6d402a3c4a113d4c37be062790fa435b72104..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Evaluation with objective metrics for the pretrained AudioGen models. -This grid takes signature from the training grid and runs evaluation-only stage. - -When running the grid for the first time, please use: -REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval -and re-use the REGEN=1 option when the grid is changed to force regenerating it. - -Note that you need the proper metrics external libraries setup to use all -the objective metrics activated in this grid. Refer to the README for more information. -""" - -import os - -from ..musicgen._explorers import GenerationEvalExplorer -from ...environment import AudioCraftEnvironment -from ... 
import train - - -def eval(launcher, batch_size: int = 32): - opts = { - 'dset': 'audio/audiocaps_16khz', - 'solver/audiogen/evaluation': 'objective_eval', - 'execute_only': 'evaluate', - '+dataset.evaluate.batch_size': batch_size, - '+metrics.fad.tf.batch_size': 32, - } - # binary for FAD computation: replace this path with your own path - metrics_opts = { - 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research' - } - opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.} - opt2 = {'transformer_lm.two_step_cfg': True} - - sub = launcher.bind(opts) - sub.bind_(metrics_opts) - - # base objective metrics - sub(opt1, opt2) - - -@GenerationEvalExplorer -def explorer(launcher): - partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) - launcher.slurm_(gpus=4, partition=partitions) - - if 'REGEN' not in os.environ: - folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1] - with launcher.job_array(): - for sig in folder.iterdir(): - if not sig.is_symlink(): - continue - xp = train.main.get_xp_from_sig(sig.name) - launcher(xp.argv) - return - - audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz") - audiogen_base.bind_({'autocast': False, 'fsdp.use': True}) - - audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'}) - audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'}) - eval(audiogen_base_medium, batch_size=128) diff --git a/audiocraft/audiocraft/grids/compression/__init__.py b/audiocraft/audiocraft/grids/compression/__init__.py deleted file mode 100644 index 5b688528f1f3e4efc0c2a1e9d490f33c4158b3f0..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/compression/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""EnCodec grids.""" diff --git a/audiocraft/audiocraft/grids/compression/_explorers.py b/audiocraft/audiocraft/grids/compression/_explorers.py deleted file mode 100644 index eed30d5b8a1c14676503148ddf133c79ed2e33bf..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/compression/_explorers.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import treetable as tt - -from .._base_explorers import BaseExplorer - - -class CompressionExplorer(BaseExplorer): - eval_metrics = ["sisnr", "visqol"] - - def stages(self): - return ["train", "valid", "evaluate"] - - def get_grid_meta(self): - """Returns the list of Meta information to display for each XP/job. - """ - return [ - tt.leaf("index", align=">"), - tt.leaf("name", wrap=140), - tt.leaf("state"), - tt.leaf("sig", align=">"), - ] - - def get_grid_metrics(self): - """Return the metrics that should be displayed in the tracking table. 
- """ - return [ - tt.group( - "train", - [ - tt.leaf("epoch"), - tt.leaf("bandwidth", ".2f"), - tt.leaf("adv", ".4f"), - tt.leaf("d_loss", ".4f"), - ], - align=">", - ), - tt.group( - "valid", - [ - tt.leaf("bandwidth", ".2f"), - tt.leaf("adv", ".4f"), - tt.leaf("msspec", ".4f"), - tt.leaf("sisnr", ".2f"), - ], - align=">", - ), - tt.group( - "evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">" - ), - ] diff --git a/audiocraft/audiocraft/grids/compression/debug.py b/audiocraft/audiocraft/grids/compression/debug.py deleted file mode 100644 index 5612ff5688d85fede0e605b244919e8081cb1da9..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/compression/debug.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Grid search file, simply list all the exp you want in `explorer`. -Any new exp added there will be scheduled. -You can cancel and experiment by commenting its line. - -This grid is a minimal example for debugging compression task -and how to override parameters directly in a grid. -Learn more about dora grids: https://github.com/facebookresearch/dora -""" - -from ._explorers import CompressionExplorer -from ...environment import AudioCraftEnvironment - - -@CompressionExplorer -def explorer(launcher): - partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) - launcher.slurm_(gpus=2, partition=partitions) - launcher.bind_(solver='compression/debug') - - with launcher.job_array(): - # base debug task using config from solver=compression/debug - launcher() - # we can override parameters in the grid to launch additional xps - launcher({'rvq.bins': 2048, 'rvq.n_q': 4}) diff --git a/audiocraft/audiocraft/grids/compression/encodec_audiogen_16khz.py b/audiocraft/audiocraft/grids/compression/encodec_audiogen_16khz.py deleted file mode 100644 index c9b41f684045594bb264cfb7f4f15d1da439382c..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/compression/encodec_audiogen_16khz.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Grid search file, simply list all the exp you want in `explorer`. -Any new exp added there will be scheduled. -You can cancel and experiment by commenting its line. - -This grid shows how to train the new AudioGen EnCodec model at 16 kHz. 
-""" - -from ._explorers import CompressionExplorer -from ...environment import AudioCraftEnvironment - - -@CompressionExplorer -def explorer(launcher): - partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) - launcher.slurm_(gpus=8, partition=partitions) - # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz - # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz - launcher.bind_(solver='compression/encodec_audiogen_16khz') - # replace this by the desired sound dataset - launcher.bind_(dset='internal/sounds_16khz') - # launch xp - launcher() diff --git a/audiocraft/audiocraft/grids/compression/encodec_base_24khz.py b/audiocraft/audiocraft/grids/compression/encodec_base_24khz.py deleted file mode 100644 index 117b2b1e496ca31b3d614672b472c9213cedb4ad..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/compression/encodec_base_24khz.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Grid search file, simply list all the exp you want in `explorer`. -Any new exp added there will be scheduled. -You can cancel and experiment by commenting its line. - -This grid shows how to train a base causal EnCodec model at 24 kHz. -""" - -from ._explorers import CompressionExplorer -from ...environment import AudioCraftEnvironment - - -@CompressionExplorer -def explorer(launcher): - partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) - launcher.slurm_(gpus=8, partition=partitions) - # base causal EnCodec trained on monophonic audio sampled at 24 kHz - launcher.bind_(solver='compression/encodec_base_24khz') - # replace this by the desired dataset - launcher.bind_(dset='audio/example') - # launch xp - launcher() diff --git a/audiocraft/audiocraft/grids/compression/encodec_musicgen_32khz.py b/audiocraft/audiocraft/grids/compression/encodec_musicgen_32khz.py deleted file mode 100644 index 9da31daa5f009f46e753601a51a06391594b8f9b..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/compression/encodec_musicgen_32khz.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Grid search file, simply list all the exp you want in `explorer`. -Any new exp added there will be scheduled. -You can cancel and experiment by commenting its line. - -This grid shows how to train a MusicGen EnCodec model at 32 kHz. 
-""" - -from ._explorers import CompressionExplorer -from ...environment import AudioCraftEnvironment - - -@CompressionExplorer -def explorer(launcher): - partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) - launcher.slurm_(gpus=8, partition=partitions) - # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz - # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz - launcher.bind_(solver='compression/encodec_musicgen_32khz') - # replace this by the desired music dataset - launcher.bind_(dset='internal/music_400k_32khz') - # launch xp - launcher() - launcher({ - 'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol', - 'label': 'visqol', - 'evaluate.metrics.visqol': True - }) diff --git a/audiocraft/audiocraft/grids/diffusion/4_bands_base_32khz.py b/audiocraft/audiocraft/grids/diffusion/4_bands_base_32khz.py deleted file mode 100644 index f7e67bcc89dd0c8e50d770e600b55f179fe19588..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/diffusion/4_bands_base_32khz.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Training of the 4 diffusion models described in -"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" -(paper link). -""" - -from ._explorers import DiffusionExplorer - - -@DiffusionExplorer -def explorer(launcher): - launcher.slurm_(gpus=4, partition='learnfair') - - launcher.bind_({'solver': 'diffusion/default', - 'dset': 'internal/music_10k_32khz'}) - - with launcher.job_array(): - launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4}) - launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4}) - launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4}) - launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75}) diff --git a/audiocraft/audiocraft/grids/diffusion/__init__.py b/audiocraft/audiocraft/grids/diffusion/__init__.py deleted file mode 100644 index e5737294ae16c0de52085b8dcf6825c348f617e4..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/diffusion/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Diffusion grids.""" diff --git a/audiocraft/audiocraft/grids/diffusion/_explorers.py b/audiocraft/audiocraft/grids/diffusion/_explorers.py deleted file mode 100644 index 0bf4ca57b63f5f9308bd1178ddbde5d8f06748e5..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/diffusion/_explorers.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import treetable as tt - -from .._base_explorers import BaseExplorer - - -class DiffusionExplorer(BaseExplorer): - eval_metrics = ["sisnr", "visqol"] - - def stages(self): - return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"] - - def get_grid_meta(self): - """Returns the list of Meta information to display for each XP/job. - """ - return [ - tt.leaf("index", align=">"), - tt.leaf("name", wrap=140), - tt.leaf("state"), - tt.leaf("sig", align=">"), - ] - - def get_grid_metrics(self): - """Return the metrics that should be displayed in the tracking table. - """ - return [ - tt.group( - "train", - [ - tt.leaf("epoch"), - tt.leaf("loss", ".3%"), - ], - align=">", - ), - tt.group( - "valid", - [ - tt.leaf("loss", ".3%"), - # tt.leaf("loss_0", ".3%"), - ], - align=">", - ), - tt.group( - "valid_ema", - [ - tt.leaf("loss", ".3%"), - # tt.leaf("loss_0", ".3%"), - ], - align=">", - ), - tt.group( - "evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"), - tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"), - tt.leaf("rvm_3", ".4f"), ], align=">" - ), - tt.group( - "evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"), - tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"), - tt.leaf("rvm_3", ".4f")], align=">" - ), - ] diff --git a/audiocraft/audiocraft/grids/musicgen/__init__.py b/audiocraft/audiocraft/grids/musicgen/__init__.py deleted file mode 100644 index d3f101f5a29ff85271e44e4f27545168a8f27baa..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/musicgen/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""MusicGen grids.""" diff --git a/audiocraft/audiocraft/grids/musicgen/_explorers.py b/audiocraft/audiocraft/grids/musicgen/_explorers.py deleted file mode 100644 index 334836b72559a120feb8a15eef3fe96ce88a4edb..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/musicgen/_explorers.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
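The leaf format strings used by DiffusionExplorer above ('.3%' for the losses, '.4f' for the RVM metrics) read as standard Python format specs. A tiny illustration of what they render, assuming treetable applies them with the built-in format() (that assumption, and the sample value, are not taken from the source):

loss = 0.0123  # sample value chosen only for illustration
print(format(loss, '.3%'))  # '1.230%' -> how train/valid/valid_ema losses are displayed
print(format(loss, '.4f'))  # '0.0123' -> how the rvm / rvm_0..rvm_3 metrics are displayed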
- -import typing as tp - -import treetable as tt - -from .._base_explorers import BaseExplorer - - -class LMExplorer(BaseExplorer): - eval_metrics: tp.List[str] = [] - - def stages(self) -> tp.List[str]: - return ['train', 'valid'] - - def get_grid_metrics(self): - """Return the metrics that should be displayed in the tracking table.""" - return [ - tt.group( - 'train', - [ - tt.leaf('epoch'), - tt.leaf('duration', '.1f'), # duration in minutes - tt.leaf('ping'), - tt.leaf('ce', '.4f'), # cross entropy - tt.leaf("ppl", '.3f'), # perplexity - ], - align='>', - ), - tt.group( - 'valid', - [ - tt.leaf('ce', '.4f'), - tt.leaf('ppl', '.3f'), - tt.leaf('best_ppl', '.3f'), - ], - align='>', - ), - ] - - def process_sheep(self, sheep, history): - parts = super().process_sheep(sheep, history) - - track_by = {'ppl': 'lower'} # values should be in ['lower', 'higher'] - best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()} - - def comparator(mode, a, b): - return a < b if mode == 'lower' else a > b - - for metrics in history: - for key, sub in metrics.items(): - for metric in track_by: - # for the validation set, keep track of best metrics (ppl in this example) - # this is so we can conveniently compare metrics between runs in the grid - if key == 'valid' and metric in sub and comparator( - track_by[metric], sub[metric], best_metrics[metric] - ): - best_metrics[metric] = sub[metric] - - if 'valid' in parts: - parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()}) - return parts - - -class GenerationEvalExplorer(BaseExplorer): - eval_metrics: tp.List[str] = [] - - def stages(self) -> tp.List[str]: - return ['evaluate'] - - def get_grid_metrics(self): - """Return the metrics that should be displayed in the tracking table.""" - return [ - tt.group( - 'evaluate', - [ - tt.leaf('epoch', '.3f'), - tt.leaf('duration', '.1f'), - tt.leaf('ping'), - tt.leaf('ce', '.4f'), - tt.leaf('ppl', '.3f'), - tt.leaf('fad', '.3f'), - tt.leaf('kld', '.3f'), - tt.leaf('text_consistency', '.3f'), - tt.leaf('chroma_cosine', '.3f'), - ], - align='>', - ), - ] diff --git a/audiocraft/audiocraft/grids/musicgen/musicgen_base_32khz.py b/audiocraft/audiocraft/grids/musicgen/musicgen_base_32khz.py deleted file mode 100644 index 4e364614537e426f21c18a2c2a9d94b3babce051..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/musicgen/musicgen_base_32khz.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
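The best_ppl column produced by LMExplorer above comes from a running-best scan over the job history. A self-contained sketch of that pattern on a toy history; the names follow the code above, while the history shape (a list of {stage: {metric: value}} dicts) is inferred from how process_sheep iterates it.

# Toy history: one entry per epoch, mapping stage name -> metrics dict.
history = [
    {'train': {'ppl': 32.0}, 'valid': {'ppl': 30.5}},
    {'train': {'ppl': 25.1}, 'valid': {'ppl': 24.9}},
    {'train': {'ppl': 21.7}, 'valid': {'ppl': 26.0}},  # valid got worse; the best value is kept
]

track_by = {'ppl': 'lower'}  # 'lower' means smaller is better
best = {k: float('inf') if v == 'lower' else float('-inf') for k, v in track_by.items()}

for metrics in history:
    for stage, sub in metrics.items():
        if stage != 'valid':
            continue  # only the validation stage feeds the best_* columns
        for name, mode in track_by.items():
            if name in sub:
                better = sub[name] < best[name] if mode == 'lower' else sub[name] > best[name]
                if better:
                    best[name] = sub[name]

print({f'best_{k}': v for k, v in best.items()})  # {'best_ppl': 24.9}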
- -from ._explorers import LMExplorer -from ...environment import AudioCraftEnvironment - - -@LMExplorer -def explorer(launcher): - partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) - launcher.slurm_(gpus=32, partition=partitions) - launcher.bind_(solver='musicgen/musicgen_base_32khz') - # replace this by the desired music dataset - launcher.bind_(dset='internal/music_400k_32khz') - - fsdp = {'autocast': False, 'fsdp.use': True} - medium = {'model/lm/model_scale': 'medium'} - large = {'model/lm/model_scale': 'large'} - - cfg_low = {'classifier_free_guidance.training_dropout': 0.2} - wd_low = {'conditioners.description.t5.word_dropout': 0.2} - - adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} - - launcher.bind_(fsdp) - - launcher.slurm_(gpus=32).bind_(label='32gpus') - with launcher.job_array(): - sub = launcher.bind() - sub() - - launcher.slurm_(gpus=64).bind_(label='64gpus') - with launcher.job_array(): - sub = launcher.bind() - sub(medium, adam) - - launcher.slurm_(gpus=96).bind_(label='96gpus') - with launcher.job_array(): - sub = launcher.bind() - sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3}) diff --git a/audiocraft/audiocraft/grids/musicgen/musicgen_base_cached_32khz.py b/audiocraft/audiocraft/grids/musicgen/musicgen_base_cached_32khz.py deleted file mode 100644 index d9a43f37d7369b5de4542fba87c4c8739d58b1e8..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/musicgen/musicgen_base_cached_32khz.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from ._explorers import LMExplorer -from ...environment import AudioCraftEnvironment - - -@LMExplorer -def explorer(launcher): - partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) - launcher.slurm_(gpus=32, partition=partitions) - launcher.bind_(solver='musicgen/musicgen_base_32khz') - # replace this by the desired music dataset - launcher.bind_(dset='internal/music_400k_32khz') - - fsdp = {'autocast': False, 'fsdp.use': True} - medium = {'model/lm/model_scale': 'medium'} - large = {'model/lm/model_scale': 'large'} - - cfg_low = {'classifier_free_guidance.training_dropout': 0.2} - wd_low = {'conditioners.description.t5.word_dropout': 0.2} - - adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} - - # BEGINNING OF CACHE WRITING JOBS. - cache_write = { - 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k', - 'cache.write': True, - 'generate.every': 500, - 'evaluate.every': 500, - 'logging.log_updates': 50, - } - - cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'}) - cache_sub.bind_({'deadlock.use': True}) - cache_sub.slurm_(gpus=8) - with launcher.job_array(): - num_shards = 10 # total number of jobs running in parallel. - for shard in range(0, num_shards): - launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard}) - - # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE, - # OR SUFFICIENTLY AHEAD. 
- return - - cache = { - 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k', - } - launcher.bind_(fsdp, cache) - - launcher.slurm_(gpus=32).bind_(label='32gpus') - with launcher.job_array(): - sub = launcher.bind() - sub() - - launcher.slurm_(gpus=64).bind_(label='64gpus') - with launcher.job_array(): - sub = launcher.bind() - sub(medium, adam) - - launcher.slurm_(gpus=96).bind_(label='96gpus') - with launcher.job_array(): - sub = launcher.bind() - sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3}) diff --git a/audiocraft/audiocraft/grids/musicgen/musicgen_clapemb_32khz.py b/audiocraft/audiocraft/grids/musicgen/musicgen_clapemb_32khz.py deleted file mode 100644 index 64ad3f8c77afe1ab5908e407ad14d4879e1b1ad1..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/musicgen/musicgen_clapemb_32khz.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from ._explorers import LMExplorer -from ...environment import AudioCraftEnvironment - - -@LMExplorer -def explorer(launcher): - partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) - launcher.slurm_(gpus=32, partition=partitions) - launcher.bind_(solver='musicgen/musicgen_base_32khz') - # replace this by the desired music dataset - launcher.bind_(dset='internal/music_400k_32khz') - launcher.bind_(conditioner='clapemb2music') - - fsdp = {'autocast': False, 'fsdp.use': True} - cache_path = {'conditioners.description.clap.cache_path': - '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music'} - text_wav_training_opt = {'conditioners.description.clap.text_p': 0.5} - - launcher.bind_(fsdp) - - launcher.slurm_(gpus=32).bind_(label='32gpus') - with launcher.job_array(): - launcher() - launcher(text_wav_training_opt) - launcher(cache_path) - launcher(cache_path, text_wav_training_opt) diff --git a/audiocraft/audiocraft/grids/musicgen/musicgen_melody_32khz.py b/audiocraft/audiocraft/grids/musicgen/musicgen_melody_32khz.py deleted file mode 100644 index b0d6710a23c117406e9724057a62eccab88ce907..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/musicgen/musicgen_melody_32khz.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
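The cached, CLAP-embedding and melody grids in this directory all rely on the same dora launcher idiom: `slurm_`/`bind_` update the launcher in place, `slurm`/`bind` return a derived copy, calling the launcher schedules one experiment, and `job_array()` groups the calls made inside the `with` block. A hedged sketch of a new grid file built only from those calls (dataset name and sweep values are placeholders):

```
from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment


@LMExplorer
def explorer(launcher):
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=16, partition=partitions)         # in-place SLURM defaults
    launcher.bind_(solver='musicgen/musicgen_base_32khz')  # in-place config override
    launcher.bind_(dset='internal/music_400k_32khz')       # replace with the desired dataset

    medium = {'model/lm/model_scale': 'medium'}            # example sweep value
    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    with launcher.job_array():
        sub = launcher.bind()   # derived launcher, leaves `launcher` untouched
        sub()                   # baseline run
        sub(medium, adam)       # one variant per call
```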
- -from ._explorers import LMExplorer -from ...environment import AudioCraftEnvironment - - -@LMExplorer -def explorer(launcher): - partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) - launcher.slurm_(gpus=32, partition=partitions) - launcher.bind_(solver='musicgen/musicgen_melody_32khz') - # replace this by the desired music dataset - launcher.bind_(dset='internal/music_400k_32khz') - - fsdp = {'autocast': False, 'fsdp.use': True} - medium = {'model/lm/model_scale': 'medium'} - large = {'model/lm/model_scale': 'large'} - - cfg_low = {'classifier_free_guidance.training_dropout': 0.2} - wd_low = {'conditioners.description.t5.word_dropout': 0.2} - - adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} - - cache_path = {'conditioners.self_wav.chroma_stem.cache_path': - '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem'} - - # CACHE GENERATION JOBS - n_cache_gen_jobs = 4 - gen_sub = launcher.slurm(gpus=1) - gen_sub.bind_( - cache_path, { - # the cache is always computed over the whole file, so duration doesn't matter here. - 'dataset.segment_duration': 2., - 'dataset.batch_size': 8, - 'dataset.train.permutation_on_files': True, # try to not repeat files. - 'optim.epochs': 10, - 'model/lm/model_scale': 'xsmall', - - }) - with gen_sub.job_array(): - for gen_job in range(n_cache_gen_jobs): - gen_sub({'dataset.train.shuffle_seed': gen_job}) - - # ACTUAL TRAINING JOBS. - launcher.bind_(fsdp) - - launcher.slurm_(gpus=32).bind_(label='32gpus') - with launcher.job_array(): - sub = launcher.bind() - sub() - sub(cache_path) - - launcher.slurm_(gpus=64).bind_(label='64gpus') - with launcher.job_array(): - sub = launcher.bind() - sub(medium, adam) - - launcher.slurm_(gpus=96).bind_(label='96gpus') - with launcher.job_array(): - sub = launcher.bind() - sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3}) diff --git a/audiocraft/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py b/audiocraft/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py deleted file mode 100644 index 39ceaf7dab15ec3f0f669cfe57ca9e932a9ab40d..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Evaluation with objective metrics for the pretrained MusicGen models. -This grid takes signature from the training grid and runs evaluation-only stage. - -When running the grid for the first time, please use: -REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval -and re-use the REGEN=1 option when the grid is changed to force regenerating it. - -Note that you need the proper metrics external libraries setup to use all -the objective metrics activated in this grid. Refer to the README for more information. -""" - -import os - -from ._explorers import GenerationEvalExplorer -from ...environment import AudioCraftEnvironment -from ... 
import train - - -def eval(launcher, batch_size: int = 32, eval_melody: bool = False): - opts = { - 'dset': 'audio/musiccaps_32khz', - 'solver/musicgen/evaluation': 'objective_eval', - 'execute_only': 'evaluate', - '+dataset.evaluate.batch_size': batch_size, - '+metrics.fad.tf.batch_size': 16, - } - # chroma-specific evaluation - chroma_opts = { - 'dset': 'internal/music_400k_32khz', - 'dataset.evaluate.segment_duration': 30, - 'dataset.evaluate.num_samples': 1000, - 'evaluate.metrics.chroma_cosine': True, - 'evaluate.metrics.fad': False, - 'evaluate.metrics.kld': False, - 'evaluate.metrics.text_consistency': False, - } - # binary for FAD computation: replace this path with your own path - metrics_opts = { - 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research' - } - opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.} - opt2 = {'transformer_lm.two_step_cfg': True} - - sub = launcher.bind(opts) - sub.bind_(metrics_opts) - - # base objective metrics - sub(opt1, opt2) - - if eval_melody: - # chroma-specific metrics - sub(opt1, opt2, chroma_opts) - - -@GenerationEvalExplorer -def explorer(launcher): - partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) - launcher.slurm_(gpus=4, partition=partitions) - - if 'REGEN' not in os.environ: - folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1] - with launcher.job_array(): - for sig in folder.iterdir(): - if not sig.is_symlink(): - continue - xp = train.main.get_xp_from_sig(sig.name) - launcher(xp.argv) - return - - with launcher.job_array(): - musicgen_base = launcher.bind(solver="musicgen/musicgen_base_32khz") - musicgen_base.bind_({'autocast': False, 'fsdp.use': True}) - - # base musicgen models - musicgen_base_small = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-small'}) - eval(musicgen_base_small, batch_size=128) - - musicgen_base_medium = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-medium'}) - musicgen_base_medium.bind_({'model/lm/model_scale': 'medium'}) - eval(musicgen_base_medium, batch_size=128) - - musicgen_base_large = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-large'}) - musicgen_base_large.bind_({'model/lm/model_scale': 'large'}) - eval(musicgen_base_large, batch_size=128) - - # melody musicgen model - musicgen_melody = launcher.bind(solver="musicgen/musicgen_melody_32khz") - musicgen_melody.bind_({'autocast': False, 'fsdp.use': True}) - - musicgen_melody_medium = musicgen_melody.bind({'continue_from': '//pretrained/facebook/musicgen-melody'}) - musicgen_melody_medium.bind_({'model/lm/model_scale': 'medium'}) - eval(musicgen_melody_medium, batch_size=128, eval_melody=True) diff --git a/audiocraft/audiocraft/losses/__init__.py b/audiocraft/audiocraft/losses/__init__.py deleted file mode 100644 index d55107b2c11822cab749ed3683cf19020802898a..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/losses/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Loss related classes and functions. 
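The `eval(...)` helper above works purely by layering override dictionaries onto a bound launcher, with later layers presumably taking precedence over earlier ones (hydra-style, last override wins). A tiny plain-Python sketch of how those layers combine, using dict merging as a stand-in for dora's binding:

```
# Stand-in for the dora bindings in eval(): later dicts override earlier ones.
opts = {'execute_only': 'evaluate', '+dataset.evaluate.batch_size': 32}
metrics_opts = {'metrics.fad.tf.bin': '/path/to/google-research'}   # placeholder path
opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
opt2 = {'transformer_lm.two_step_cfg': True}

effective = {**opts, **metrics_opts, **opt1, **opt2}
print(effective['generate.lm.top_k'])   # 250
```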
In particular the loss balancer from -EnCodec, and the usual spectral losses.""" - -# flake8: noqa -from .balancer import Balancer -from .sisnr import SISNR -from .stftloss import ( - LogSTFTMagnitudeLoss, - MRSTFTLoss, - SpectralConvergenceLoss, - STFTLoss -) -from .specloss import ( - MelSpectrogramL1Loss, - MultiScaleMelSpectrogramLoss, -) diff --git a/audiocraft/audiocraft/losses/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/losses/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 291f80307bdaf60691dccdb14f940574abe9b8d7..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/losses/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/losses/__pycache__/balancer.cpython-311.pyc b/audiocraft/audiocraft/losses/__pycache__/balancer.cpython-311.pyc deleted file mode 100644 index 010d41c0c3a20d6b230bc6fb75ae25aa2e2e7eeb..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/losses/__pycache__/balancer.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/losses/__pycache__/sisnr.cpython-311.pyc b/audiocraft/audiocraft/losses/__pycache__/sisnr.cpython-311.pyc deleted file mode 100644 index a8e3f0383c9fde685859b1a0e2948a6fa6dd4bfc..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/losses/__pycache__/sisnr.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/losses/__pycache__/specloss.cpython-311.pyc b/audiocraft/audiocraft/losses/__pycache__/specloss.cpython-311.pyc deleted file mode 100644 index 246667d116ba5e49f48ee9a508f47eced417c5aa..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/losses/__pycache__/specloss.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/losses/__pycache__/stftloss.cpython-311.pyc b/audiocraft/audiocraft/losses/__pycache__/stftloss.cpython-311.pyc deleted file mode 100644 index a99f58760fea3330a18e08e0ebdef59e5a41456b..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/losses/__pycache__/stftloss.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/losses/balancer.py b/audiocraft/audiocraft/losses/balancer.py deleted file mode 100644 index 8a0ac8adebab8cdee8f82351965195dc02800d18..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/losses/balancer.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import typing as tp - -import flashy -import torch -from torch import autograd - - -class Balancer: - """Loss balancer. - - The loss balancer combines losses together to compute gradients for the backward. - Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...` - not having any dependence on `f`, the balancer can efficiently normalize the partial gradients - `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between - the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient - going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy - interpration of the weights even if the intrisic scale of `l1`, `l2` ... is unknown. 
- - Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be - (with `avg` an exponential moving average over the updates), - - G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i) - - If `balance_grads` is False, this is deactivated, and instead the gradient will just be the - standard sum of the partial gradients with the given weights. - - A call to the backward method of the balancer will compute the partial gradients, - combining all the losses and potentially rescaling the gradients, - which can help stabilize the training and reason about multiple losses with varying scales. - The obtained gradient with respect to `y` is then back-propagated to `f(...)`. - - Expected usage: - - weights = {'loss_a': 1, 'loss_b': 4} - balancer = Balancer(weights, ...) - losses: dict = {} - losses['loss_a'] = compute_loss_a(x, y) - losses['loss_b'] = compute_loss_b(x, y) - if model.training: - effective_loss = balancer.backward(losses, x) - - Args: - weights (dict[str, float]): Weight coefficient for each loss. The balancer expects the losses keys - from the backward method to match the weights keys to assign a weight to each of the provided losses. - balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the - overall gradient, rather than a constant multiplier. - total_norm (float): Reference norm when rescaling gradients, ignored otherwise. - ema_decay (float): EMA decay for averaging the norms. - per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds - when rescaling the gradients. - epsilon (float): Epsilon value for numerical stability. - monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients - coming from each loss, when calling `backward()`. - """ - def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1., - ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12, - monitor: bool = False): - self.weights = weights - self.per_batch_item = per_batch_item - self.total_norm = total_norm or 1. - self.averager = flashy.averager(ema_decay or 1.) - self.epsilon = epsilon - self.monitor = monitor - self.balance_grads = balance_grads - self._metrics: tp.Dict[str, tp.Any] = {} - - @property - def metrics(self): - return self._metrics - - def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor: - """Compute the backward and return the effective train loss, i.e. the loss obtained from - computing the effective weights. If `balance_grads` is True, the effective weights - are the ones that need to be applied to each gradient to respect the desired relative - scale of gradients coming from each loss. - - Args: - losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`. - input (torch.Tensor): the input of the losses, typically the output of the model. - This should be the single point of dependence between the losses - and the model being trained. - """ - norms = {} - grads = {} - for name, loss in losses.items(): - # Compute partial derivative of the loss with respect to the input. - grad, = autograd.grad(loss, [input], retain_graph=True) - if self.per_batch_item: - # We do not average the gradient over the batch dimension.
- dims = tuple(range(1, grad.dim())) - norm = grad.norm(dim=dims, p=2).mean() - else: - norm = grad.norm(p=2) - norms[name] = norm - grads[name] = grad - - count = 1 - if self.per_batch_item: - count = len(grad) - # Average norms across workers. Theoretically we should average the - # squared norm, then take the sqrt, but it worked fine like that. - avg_norms = flashy.distrib.average_metrics(self.averager(norms), count) - # We approximate the total norm of the gradient as the sums of the norms. - # Obviously this can be very incorrect if all gradients are aligned, but it works fine. - total = sum(avg_norms.values()) - - self._metrics = {} - if self.monitor: - # Store the ratio of the total gradient represented by each loss. - for k, v in avg_norms.items(): - self._metrics[f'ratio_{k}'] = v / total - - total_weights = sum([self.weights[k] for k in avg_norms]) - assert total_weights > 0. - desired_ratios = {k: w / total_weights for k, w in self.weights.items()} - - out_grad = torch.zeros_like(input) - effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype) - for name, avg_norm in avg_norms.items(): - if self.balance_grads: - # g_balanced = g / avg(||g||) * total_norm * desired_ratio - scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm) - else: - # We just do regular weighted sum of the gradients. - scale = self.weights[name] - out_grad.add_(grads[name], alpha=scale) - effective_loss += scale * losses[name].detach() - # Send the computed partial derivative with respect to the output of the model to the model. - input.backward(out_grad) - return effective_loss diff --git a/audiocraft/audiocraft/losses/sisnr.py b/audiocraft/audiocraft/losses/sisnr.py deleted file mode 100644 index 30f1fa1de9aca22758b6665609a1eacc0bd992ca..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/losses/sisnr.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math -import typing as tp - -import torch -from torch import nn -from torch.nn import functional as F - - -def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor: - """Given input of size [*OT, T], output Tensor of size [*OT, F, K] - with K the kernel size, by extracting frames with the given stride. - This will pad the input so that `F = ceil(T / K)`. - see https://github.com/pytorch/pytorch/issues/60466 - """ - *shape, length = a.shape - n_frames = math.ceil(length / stride) - tgt_length = (n_frames - 1) * stride + kernel_size - a = F.pad(a, (0, tgt_length - length)) - strides = list(a.stride()) - assert strides[-1] == 1, "data should be contiguous" - strides = strides[:-1] + [stride, 1] - return a.as_strided([*shape, n_frames, kernel_size], strides) - - -def _center(x: torch.Tensor) -> torch.Tensor: - return x - x.mean(-1, True) - - -def _norm2(x: torch.Tensor) -> torch.Tensor: - return x.pow(2).sum(-1, True) - - -class SISNR(nn.Module): - """SISNR loss. - - Input should be [B, C, T], output is scalar. - - Args: - sample_rate (int): Sample rate. - segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on - entire audio only. - overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap. - epsilon (float): Epsilon value for numerical stability. 
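Going back to the `Balancer` defined above (the SISNR loss that follows is unrelated): a hedged, runnable sketch of its intended use, assuming a single non-distributed process so that flashy's metric averaging behaves as a pass-through:

```
import torch
from torch import nn

# from audiocraft.losses import Balancer   # module shown above

model = nn.Linear(8, 8)
x = torch.randn(4, 8)
y = model(x)                                 # single point both losses depend on
target = torch.zeros_like(y)

losses = {
    'l1': nn.functional.l1_loss(y, target),
    'l2': nn.functional.mse_loss(y, target),
}
# 2:1 desired ratio between the gradients of the two losses.
balancer = Balancer(weights={'l1': 2., 'l2': 1.}, monitor=True)
effective_loss = balancer.backward(losses, y)   # back-propagates the balanced gradient into `model`
print(effective_loss.item(), balancer.metrics)  # metrics holds the per-loss gradient ratios
```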
- """ - def __init__( - self, - sample_rate: int = 16000, - segment: tp.Optional[float] = 20, - overlap: float = 0.5, - epsilon: float = torch.finfo(torch.float32).eps, - ): - super().__init__() - self.sample_rate = sample_rate - self.segment = segment - self.overlap = overlap - self.epsilon = epsilon - - def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor: - B, C, T = ref_sig.shape - assert ref_sig.shape == out_sig.shape - - if self.segment is None: - frame = T - stride = T - else: - frame = int(self.segment * self.sample_rate) - stride = int(frame * (1 - self.overlap)) - - epsilon = self.epsilon * frame # make epsilon prop to frame size. - - gt = _unfold(ref_sig, frame, stride) - est = _unfold(out_sig, frame, stride) - if self.segment is None: - assert gt.shape[-1] == 1 - - gt = _center(gt) - est = _center(est) - dot = torch.einsum("bcft,bcft->bcf", gt, est) - - proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt)) - noise = est - proj - - sisnr = 10 * ( - torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise)) - ) - return -1 * sisnr[..., 0].mean() diff --git a/audiocraft/audiocraft/losses/specloss.py b/audiocraft/audiocraft/losses/specloss.py deleted file mode 100644 index 11f2eb3e5c44b542a02f13db64bfb22fa0d3d212..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/losses/specloss.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import typing as tp - -import numpy as np -from torchaudio.transforms import MelSpectrogram -import torch -from torch import nn -from torch.nn import functional as F - -from ..modules import pad_for_conv1d - - -class MelSpectrogramWrapper(nn.Module): - """Wrapper around MelSpectrogram torchaudio transform providing proper padding - and additional post-processing including log scaling. - - Args: - n_mels (int): Number of mel bins. - n_fft (int): Number of fft. - hop_length (int): Hop size. - win_length (int): Window length. - n_mels (int): Number of mel bins. - sample_rate (int): Sample rate. - f_min (float or None): Minimum frequency. - f_max (float or None): Maximum frequency. - log (bool): Whether to scale with log. - normalized (bool): Whether to normalize the melspectrogram. - floor_level (float): Floor level based on human perception (default=1e-5). - """ - def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None, - n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None, - log: bool = True, normalized: bool = False, floor_level: float = 1e-5): - super().__init__() - self.n_fft = n_fft - hop_length = int(hop_length) - self.hop_length = hop_length - self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, - win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized, - window_fn=torch.hann_window, center=False) - self.floor_level = floor_level - self.log = log - - def forward(self, x): - p = int((self.n_fft - self.hop_length) // 2) - if len(x.shape) == 2: - x = x.unsqueeze(1) - x = F.pad(x, (p, p), "reflect") - # Make sure that all the frames are full. - # The combination of `pad_for_conv1d` and the above padding - # will make the output of size ceil(T / hop). 
- x = pad_for_conv1d(x, self.n_fft, self.hop_length) - self.mel_transform.to(x.device) - mel_spec = self.mel_transform(x) - B, C, freqs, frame = mel_spec.shape - if self.log: - mel_spec = torch.log10(self.floor_level + mel_spec) - return mel_spec.reshape(B, C * freqs, frame) - - -class MelSpectrogramL1Loss(torch.nn.Module): - """L1 Loss on MelSpectrogram. - - Args: - sample_rate (int): Sample rate. - n_fft (int): Number of fft. - hop_length (int): Hop size. - win_length (int): Window length. - n_mels (int): Number of mel bins. - f_min (float or None): Minimum frequency. - f_max (float or None): Maximum frequency. - log (bool): Whether to scale with log. - normalized (bool): Whether to normalize the melspectrogram. - floor_level (float): Floor level value based on human perception (default=1e-5). - """ - def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, - n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None, - log: bool = True, normalized: bool = False, floor_level: float = 1e-5): - super().__init__() - self.l1 = torch.nn.L1Loss() - self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length, - n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max, - log=log, normalized=normalized, floor_level=floor_level) - - def forward(self, x, y): - self.melspec.to(x.device) - s_x = self.melspec(x) - s_y = self.melspec(y) - return self.l1(s_x, s_y) - - -class MultiScaleMelSpectrogramLoss(nn.Module): - """Multi-Scale spectrogram loss (msspec). - - Args: - sample_rate (int): Sample rate. - range_start (int): Power of 2 to use for the first scale. - range_stop (int): Power of 2 to use for the last scale. - n_mels (int): Number of mel bins. - f_min (float): Minimum frequency. - f_max (float or None): Maximum frequency. - normalized (bool): Whether to normalize the melspectrogram. - alphas (bool): Whether to use alphas as coefficients or not. - floor_level (float): Floor level value based on human perception (default=1e-5). 
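A hedged usage sketch for `MelSpectrogramL1Loss` as defined above, with random tensors standing in for real audio (the commented import assumes the package layout shown in this diff):

```
import torch

# from audiocraft.losses import MelSpectrogramL1Loss

mel_l1 = MelSpectrogramL1Loss(sample_rate=32000)   # other arguments keep the defaults above
x = torch.randn(2, 1, 32000, requires_grad=True)   # generated audio, shape (B, C, T)
y = torch.randn(2, 1, 32000)                       # reference audio
loss = mel_l1(x, y)                                # scalar L1 between log-mel spectrograms
loss.backward()
```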
- """ - def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11, - n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None, - normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5): - super().__init__() - l1s = list() - l2s = list() - self.alphas = list() - self.total = 0 - self.normalized = normalized - for i in range(range_start, range_end): - l1s.append( - MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i, - n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max, - log=False, normalized=normalized, floor_level=floor_level)) - l2s.append( - MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i, - n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max, - log=True, normalized=normalized, floor_level=floor_level)) - if alphas: - self.alphas.append(np.sqrt(2 ** i - 1)) - else: - self.alphas.append(1) - self.total += self.alphas[-1] + 1 - - self.l1s = nn.ModuleList(l1s) - self.l2s = nn.ModuleList(l2s) - - def forward(self, x, y): - loss = 0.0 - self.l1s.to(x.device) - self.l2s.to(x.device) - for i in range(len(self.alphas)): - s_x_1 = self.l1s[i](x) - s_y_1 = self.l1s[i](y) - s_x_2 = self.l2s[i](x) - s_y_2 = self.l2s[i](y) - loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2) - if self.normalized: - loss = loss / self.total - return loss diff --git a/audiocraft/audiocraft/losses/stftloss.py b/audiocraft/audiocraft/losses/stftloss.py deleted file mode 100644 index 5ad4b7d3324ee5b0e6064b6f71cf8caf0fdc3be7..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/losses/stftloss.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# Adapted from MIT code under the original license -# Copyright 2019 Tomoki Hayashi -# MIT License (https://opensource.org/licenses/MIT) -import typing as tp - -import torch -from torch import nn -from torch.nn import functional as F - - -# TODO: Replace with torchaudio.STFT? -def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int, - window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor: - """Perform STFT and convert to magnitude spectrogram. - - Args: - x: Input signal tensor (B, C, T). - fft_size (int): FFT size. - hop_length (int): Hop size. - win_length (int): Window length. - window (torch.Tensor or None): Window function type. - normalized (bool): Whether to normalize the STFT or not. - - Returns: - torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1). - """ - B, C, T = x.shape - x_stft = torch.stft( - x.view(-1, T), fft_size, hop_length, win_length, window, - normalized=normalized, return_complex=True, - ) - x_stft = x_stft.view(B, C, *x_stft.shape[1:]) - real = x_stft.real - imag = x_stft.imag - - # NOTE(kan-bayashi): clamp is needed to avoid nan or inf - return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) - - -class SpectralConvergenceLoss(nn.Module): - """Spectral convergence loss. - """ - def __init__(self, epsilon: float = torch.finfo(torch.float32).eps): - super().__init__() - self.epsilon = epsilon - - def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor): - """Calculate forward propagation. - - Args: - x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 
- y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). - Returns: - torch.Tensor: Spectral convergence loss value. - """ - return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon) - - -class LogSTFTMagnitudeLoss(nn.Module): - """Log STFT magnitude loss. - - Args: - epsilon (float): Epsilon value for numerical stability. - """ - def __init__(self, epsilon: float = torch.finfo(torch.float32).eps): - super().__init__() - self.epsilon = epsilon - - def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor): - """Calculate forward propagation. - - Args: - x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). - y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). - Returns: - torch.Tensor: Log STFT magnitude loss value. - """ - return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag)) - - -class STFTLosses(nn.Module): - """STFT losses. - - Args: - n_fft (int): Size of FFT. - hop_length (int): Hop length. - win_length (int): Window length. - window (str): Window function type. - normalized (bool): Whether to use normalized STFT or not. - epsilon (float): Epsilon for numerical stability. - """ - def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600, - window: str = "hann_window", normalized: bool = False, - epsilon: float = torch.finfo(torch.float32).eps): - super().__init__() - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.normalized = normalized - self.register_buffer("window", getattr(torch, window)(win_length)) - self.spectral_convergenge_loss = SpectralConvergenceLoss(epsilon) - self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon) - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: - """Calculate forward propagation. - - Args: - x (torch.Tensor): Predicted signal (B, T). - y (torch.Tensor): Groundtruth signal (B, T). - Returns: - torch.Tensor: Spectral convergence loss value. - torch.Tensor: Log STFT magnitude loss value. - """ - x_mag = _stft(x, self.n_fft, self.hop_length, - self.win_length, self.window, self.normalized) # type: ignore - y_mag = _stft(y, self.n_fft, self.hop_length, - self.win_length, self.window, self.normalized) # type: ignore - sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) - mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) - - return sc_loss, mag_loss - - -class STFTLoss(nn.Module): - """Single Resolution STFT loss. - - Args: - n_fft (int): Nb of FFT. - hop_length (int): Hop length. - win_length (int): Window length. - window (str): Window function type. - normalized (bool): Whether to use normalized STFT or not. - epsilon (float): Epsilon for numerical stability. - factor_sc (float): Coefficient for the spectral loss. - factor_mag (float): Coefficient for the magnitude loss. - """ - def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600, - window: str = "hann_window", normalized: bool = False, - factor_sc: float = 0.1, factor_mag: float = 0.1, - epsilon: float = torch.finfo(torch.float32).eps): - super().__init__() - self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon) - self.factor_sc = factor_sc - self.factor_mag = factor_mag - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: - """Calculate forward propagation. - - Args: - x (torch.Tensor): Predicted signal (B, T). 
- y (torch.Tensor): Groundtruth signal (B, T). - Returns: - torch.Tensor: Single resolution STFT loss. - """ - sc_loss, mag_loss = self.loss(x, y) - return self.factor_sc * sc_loss + self.factor_mag * mag_loss - - -class MRSTFTLoss(nn.Module): - """Multi resolution STFT loss. - - Args: - n_ffts (Sequence[int]): Sequence of FFT sizes. - hop_lengths (Sequence[int]): Sequence of hop sizes. - win_lengths (Sequence[int]): Sequence of window lengths. - window (str): Window function type. - factor_sc (float): Coefficient for the spectral loss. - factor_mag (float): Coefficient for the magnitude loss. - normalized (bool): Whether to use normalized STFT or not. - epsilon (float): Epsilon for numerical stability. - """ - def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_lengths: tp.Sequence[int] = [120, 240, 50], - win_lengths: tp.Sequence[int] = [600, 1200, 240], window: str = "hann_window", - factor_sc: float = 0.1, factor_mag: float = 0.1, - normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps): - super().__init__() - assert len(n_ffts) == len(hop_lengths) == len(win_lengths) - self.stft_losses = torch.nn.ModuleList() - for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths): - self.stft_losses += [STFTLosses(fs, ss, wl, window, normalized, epsilon)] - self.factor_sc = factor_sc - self.factor_mag = factor_mag - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Calculate forward propagation. - - Args: - x (torch.Tensor): Predicted signal (B, T). - y (torch.Tensor): Groundtruth signal (B, T). - Returns: - torch.Tensor: Multi resolution STFT loss. - """ - sc_loss = torch.Tensor([0.0]) - mag_loss = torch.Tensor([0.0]) - for f in self.stft_losses: - sc_l, mag_l = f(x, y) - sc_loss += sc_l - mag_loss += mag_l - sc_loss /= len(self.stft_losses) - mag_loss /= len(self.stft_losses) - - return self.factor_sc * sc_loss + self.factor_mag * mag_loss diff --git a/audiocraft/audiocraft/metrics/__init__.py b/audiocraft/audiocraft/metrics/__init__.py deleted file mode 100644 index 3474bdc4f1c88b21904d2a21ba077c93a8a70c8b..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/metrics/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc. 
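A hedged usage sketch for `MRSTFTLoss` as defined above. Note that although the forward docstrings say `(B, T)`, `_stft` unpacks `B, C, T`, so a 3-D `(batch, channels, time)` tensor is what actually goes in:

```
import torch

# from audiocraft.losses import MRSTFTLoss

mrstft = MRSTFTLoss()                               # default FFT sizes 1024/2048/512
x = torch.randn(4, 1, 16000, requires_grad=True)    # predicted audio, (B, C, T)
y = torch.randn(4, 1, 16000)                        # reference audio
loss = mrstft(x, y)                                 # 1-element tensor
loss.backward()                                     # gradients flow back to x
print(float(loss))
```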
-""" -# flake8: noqa -from .clap_consistency import CLAPTextConsistencyMetric, TextConsistencyMetric -from .chroma_cosinesim import ChromaCosineSimilarityMetric -from .fad import FrechetAudioDistanceMetric -from .kld import KLDivergenceMetric, PasstKLDivergenceMetric -from .rvm import RelativeVolumeMel -from .visqol import ViSQOL diff --git a/audiocraft/audiocraft/metrics/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/metrics/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 37df2fed20227e2b2b56aab3c7ac89ea25a35264..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/metrics/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/metrics/__pycache__/chroma_cosinesim.cpython-311.pyc b/audiocraft/audiocraft/metrics/__pycache__/chroma_cosinesim.cpython-311.pyc deleted file mode 100644 index 51a9ac4e76cbaf6b0702f181f280ba7241b8f1a5..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/metrics/__pycache__/chroma_cosinesim.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/metrics/__pycache__/clap_consistency.cpython-311.pyc b/audiocraft/audiocraft/metrics/__pycache__/clap_consistency.cpython-311.pyc deleted file mode 100644 index cfaa4d19047a3cd0807f1c02c5fd688634139f9a..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/metrics/__pycache__/clap_consistency.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/metrics/__pycache__/fad.cpython-311.pyc b/audiocraft/audiocraft/metrics/__pycache__/fad.cpython-311.pyc deleted file mode 100644 index 8f4a9e8b4ee3bdaea2ee0529cbd7dc98ff27047f..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/metrics/__pycache__/fad.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/metrics/__pycache__/kld.cpython-311.pyc b/audiocraft/audiocraft/metrics/__pycache__/kld.cpython-311.pyc deleted file mode 100644 index 513c56207cbf5af77c3491cee7749fcac8170af9..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/metrics/__pycache__/kld.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/metrics/__pycache__/rvm.cpython-311.pyc b/audiocraft/audiocraft/metrics/__pycache__/rvm.cpython-311.pyc deleted file mode 100644 index c8b2757666530770f23dd12368717da0e9c97a1a..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/metrics/__pycache__/rvm.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/metrics/__pycache__/visqol.cpython-311.pyc b/audiocraft/audiocraft/metrics/__pycache__/visqol.cpython-311.pyc deleted file mode 100644 index 885016eda4c3b75b47457a03bb767bf97a5d8be9..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/metrics/__pycache__/visqol.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/metrics/chroma_cosinesim.py b/audiocraft/audiocraft/metrics/chroma_cosinesim.py deleted file mode 100644 index 40c26081b803c2017fae1b6d7d086f0b0e074cef..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/metrics/chroma_cosinesim.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import torchmetrics - -from ..data.audio_utils import convert_audio -from ..modules.chroma import ChromaExtractor - - -class ChromaCosineSimilarityMetric(torchmetrics.Metric): - """Chroma cosine similarity metric. - - This metric extracts a chromagram for a reference waveform and - a generated waveform and compares each frame using the cosine similarity - function. The output is the mean cosine similarity. - - Args: - sample_rate (int): Sample rate used by the chroma extractor. - n_chroma (int): Number of chroma used by the chroma extractor. - radix2_exp (int): Exponent for the chroma extractor. - argmax (bool): Whether the chroma extractor uses argmax. - eps (float): Epsilon for cosine similarity computation. - """ - def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8): - super().__init__() - self.chroma_sample_rate = sample_rate - self.n_chroma = n_chroma - self.eps = eps - self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma, - radix2_exp=radix2_exp, argmax=argmax) - self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum") - self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum") - - def update(self, preds: torch.Tensor, targets: torch.Tensor, - sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: - """Compute cosine similarity between chromagrams and accumulate scores over the dataset.""" - if preds.size(0) == 0: - return - - assert preds.shape == targets.shape, ( - f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}") - assert preds.size(0) == sizes.size(0), ( - f"Number of items in preds ({preds.shape}) mismatch ", - f"with sizes ({sizes.shape})") - assert preds.size(0) == sample_rates.size(0), ( - f"Number of items in preds ({preds.shape}) mismatch ", - f"with sample_rates ({sample_rates.shape})") - assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch" - - device = self.weight.device - preds, targets = preds.to(device), targets.to(device) # type: ignore - sample_rate = sample_rates[0].item() - preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1) - targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1) - gt_chroma = self.chroma_extractor(targets) - gen_chroma = self.chroma_extractor(preds) - chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int() - for i in range(len(gt_chroma)): - t = int(chroma_lens[i].item()) - cosine_sim = torch.nn.functional.cosine_similarity( - gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps) - self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore - self.weight += torch.tensor(t) # type: ignore - - def compute(self) -> float: - """Computes the average cosine similarty across all generated/target chromagrams pairs.""" - assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore - return (self.cosine_sum / self.weight).item() # type: ignore diff --git a/audiocraft/audiocraft/metrics/clap_consistency.py b/audiocraft/audiocraft/metrics/clap_consistency.py deleted file mode 100644 index d2a6c61ae177533ca2fb17e25bc77d2acbbe3791..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/metrics/clap_consistency.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from pathlib import Path -import typing as tp - -import torch -import torchmetrics -from transformers import RobertaTokenizer # type: ignore - -from ..data.audio_utils import convert_audio -from ..environment import AudioCraftEnvironment -from ..utils.utils import load_clap_state_dict - -try: - import laion_clap # type: ignore -except ImportError: - laion_clap = None - - -class TextConsistencyMetric(torchmetrics.Metric): - """Text consistency metric measuring consistency between audio and text pairs.""" - - def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: - raise NotImplementedError("implement how to update the metric from the audio and text pairs.") - - def compute(self): - raise NotImplementedError("implement how to compute the final metric score.") - - -class CLAPTextConsistencyMetric(TextConsistencyMetric): - """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP). - - This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf) - or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf). - - As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the - similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as - well as the generated audio based on them, and define the MCC metric as the average cosine similarity - between these embeddings. - - Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP - """ - def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False): - super().__init__() - if laion_clap is None: - raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'") - self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum") - self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum") - self._initialize_model(model_path, model_arch, enable_fusion) - - def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool): - model_path = AudioCraftEnvironment.resolve_reference_path(model_path) - self.tokenize = RobertaTokenizer.from_pretrained('roberta-base') - self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch) - self.model_sample_rate = 48_000 - load_clap_state_dict(self.model, model_path) - self.model.eval() - - def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: - # we use the default params from CLAP module here as well - return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt") - - def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: - """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset.""" - assert audio.size(0) == len(text), "Number of audio and text samples should match" - assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate" - sample_rate = int(sample_rates[0].item()) - # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T] - audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1) - audio_embeddings 
= self.model.get_audio_embedding_from_data(audio, use_tensor=True) - text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True) - # cosine similarity between the text and the audio embedding - cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8) - self.cosine_sum += cosine_sim.sum(dim=0) - self.weight += torch.tensor(cosine_sim.size(0)) - - def compute(self): - """Computes the average cosine similarty across all audio/text pairs.""" - assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore - return (self.cosine_sum / self.weight).item() # type: ignore diff --git a/audiocraft/audiocraft/metrics/fad.py b/audiocraft/audiocraft/metrics/fad.py deleted file mode 100644 index de66138dbb14fd4246bbfe590bddfd5beaf1ed8c..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/metrics/fad.py +++ /dev/null @@ -1,329 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from pathlib import Path -import os -import subprocess -import tempfile -import typing as tp - -from audiocraft.data.audio import audio_write -from audiocraft.data.audio_utils import convert_audio -import flashy -import torch -import torchmetrics - -from ..environment import AudioCraftEnvironment - - -logger = logging.getLogger(__name__) - -VGGISH_SAMPLE_RATE = 16_000 -VGGISH_CHANNELS = 1 - - -class FrechetAudioDistanceMetric(torchmetrics.Metric): - """Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research. - - From: D.C. Dowson & B.V. Landau The Fréchet distance between - multivariate normal distributions - https://doi.org/10.1016/0047-259X(82)90077-X - The Fréchet distance between two multivariate gaussians, - `X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`. - d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y)) - = (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y) - - 2 * Tr(sqrt(sigma_x*sigma_y))) - - To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup - from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance - We provide the below instructions as reference but we do not guarantee for further support - in frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0. - - We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda). - - 1. Get the code and models following the repository instructions. We used the steps below: - git clone git@github.com:google-research/google-research.git - git clone git@github.com:tensorflow/models.git - mkdir google-research/tensorflow_models - touch google-research/tensorflow_models/__init__.py - cp -r models/research/audioset google-research/tensorflow_models/ - touch google-research/tensorflow_models/audioset/__init__.py - echo "from .vggish import mel_features, vggish_params, vggish_slim" > \ - google-research/tensorflow_models/audioset/__init__.py - # we can now remove the tensorflow models repository - # rm -r models - cd google-research - Follow the instructions to download the vggish checkpoint. AudioCraft base configuration - assumes it is placed in the AudioCraft reference dir. 
- - Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3: - - Update xrange for range in: - https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py - - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to - `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in - https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py - - Update `import vggish_params as params` to `from . import vggish_params as params` in: - https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py - - Add flag to provide a given batch size for running the AudioSet model in: - https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py - ``` - flags.DEFINE_integer('batch_size', 64, - 'Number of samples in the batch for AudioSet model.') - ``` - Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding: - `batch_size=FLAGS.batch_size` to the provided parameters. - - 2. Follow instructions for the library installation and a valid TensorFlow installation - ``` - # e.g. instructions from: https://www.tensorflow.org/install/pip - conda install -c conda-forge cudatoolkit=11.8.0 - python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.* - mkdir -p $CONDA_PREFIX/etc/conda/activate.d - echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \ - >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh - echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \ - >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh - source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh - # Verify install: on a machine with GPU device - python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))" - ``` - - Now install frechet_audio_distance required dependencies: - ``` - # We assume we already have TensorFlow installed from the above steps - pip install apache-beam numpy scipy tf_slim - ``` - - Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup - (you may want to specify --model_ckpt flag pointing to the model's path). - - 3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable - and Tensorflow library path from the above installation steps: - export TF_PYTHON_EXE="" - export TF_LIBRARY_PATH="" - - e.g. assuming we have installed everything in a dedicated conda env - with python 3.10 that is currently active: - export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python" - export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib" - - Finally you may want to export the following variable: - export TF_FORCE_GPU_ALLOW_GROWTH=true - See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth - - You can save those environment variables in your training conda env, when currently active: - `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh` - e.g. 
assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval, - and the training conda env is named audiocraft: - ``` - # activate training env - conda activate audiocraft - # get path to all envs - CONDA_ENV_DIR=$(dirname $CONDA_PREFIX) - # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric - touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh - echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \ - $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh - echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \ - $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh - # optionally: - echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh - # you may need to reactivate the audiocraft env for this to take effect - ``` - - Args: - bin (Path or str): Path to installed frechet audio distance code. - model_path (Path or str): Path to Tensorflow checkpoint for the model - used to compute statistics over the embedding beams. - format (str): Audio format used to save files. - log_folder (Path or str, optional): Path where to write process logs. - """ - def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str], - format: str = "wav", batch_size: tp.Optional[int] = None, - log_folder: tp.Optional[tp.Union[Path, str]] = None): - super().__init__() - self.model_sample_rate = VGGISH_SAMPLE_RATE - self.model_channels = VGGISH_CHANNELS - self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path) - assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}" - self.format = format - self.batch_size = batch_size - self.bin = bin - self.tf_env = {"PYTHONPATH": str(self.bin)} - self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python' - logger.info("Python exe for TF is %s", self.python_path) - if 'TF_LIBRARY_PATH' in os.environ: - self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH'] - if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ: - self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] - logger.info("Env for TF is %r", self.tf_env) - self.reset(log_folder) - self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum") - - def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None): - """Reset torchmetrics.Metrics state.""" - log_folder = Path(log_folder or tempfile.mkdtemp()) - self.tmp_dir = log_folder / 'fad' - self.tmp_dir.mkdir(exist_ok=True) - self.samples_tests_dir = self.tmp_dir / 'tests' - self.samples_tests_dir.mkdir(exist_ok=True) - self.samples_background_dir = self.tmp_dir / 'background' - self.samples_background_dir.mkdir(exist_ok=True) - self.manifest_tests = self.tmp_dir / 'files_tests.cvs' - self.manifest_background = self.tmp_dir / 'files_background.cvs' - self.stats_tests_dir = self.tmp_dir / 'stats_tests' - self.stats_background_dir = self.tmp_dir / 'stats_background' - self.counter = 0 - - def update(self, preds: torch.Tensor, targets: torch.Tensor, - sizes: torch.Tensor, sample_rates: torch.Tensor, - stems: tp.Optional[tp.List[str]] = None): - """Update torchmetrics.Metrics by saving the audio and updating the manifest file.""" - assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}" - num_samples = preds.shape[0] - assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0) - assert stems is None or num_samples == len(set(stems)) - for i in 
range(num_samples): - self.total_files += 1 # type: ignore - self.counter += 1 - wav_len = int(sizes[i].item()) - sample_rate = int(sample_rates[i].item()) - pred_wav = preds[i] - target_wav = targets[i] - pred_wav = pred_wav[..., :wav_len] - target_wav = target_wav[..., :wav_len] - stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}' - # dump audio files - try: - pred_wav = convert_audio( - pred_wav.unsqueeze(0), from_rate=sample_rate, - to_rate=self.model_sample_rate, to_channels=1).squeeze(0) - audio_write( - self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate, - format=self.format, strategy="peak") - except Exception as e: - logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}") - try: - # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying - # the original audio when writing it - target_wav = convert_audio( - target_wav.unsqueeze(0), from_rate=sample_rate, - to_rate=self.model_sample_rate, to_channels=1).squeeze(0) - audio_write( - self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate, - format=self.format, strategy="peak") - except Exception as e: - logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}") - - def _get_samples_name(self, is_background: bool): - return 'background' if is_background else 'tests' - - def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None): - if is_background: - input_samples_dir = self.samples_background_dir - input_filename = self.manifest_background - stats_name = self.stats_background_dir - else: - input_samples_dir = self.samples_tests_dir - input_filename = self.manifest_tests - stats_name = self.stats_tests_dir - beams_name = self._get_samples_name(is_background) - log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log' - - logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}") - with open(input_filename, "w") as fout: - for path in Path(input_samples_dir).glob(f"*.{self.format}"): - fout.write(f"{str(path)}\n") - - cmd = [ - self.python_path, "-m", - "frechet_audio_distance.create_embeddings_main", - "--model_ckpt", f"{self.model_path}", - "--input_files", f"{str(input_filename)}", - "--stats", f"{str(stats_name)}", - ] - if self.batch_size is not None: - cmd += ["--batch_size", str(self.batch_size)] - logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}") - env = os.environ - if gpu_index is not None: - env["CUDA_VISIBLE_DEVICES"] = str(gpu_index) - process = subprocess.Popen( - cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT) - return process, log_file - - def _compute_fad_score(self, gpu_index: tp.Optional[int] = None): - cmd = [ - self.python_path, "-m", "frechet_audio_distance.compute_fad", - "--test_stats", f"{str(self.stats_tests_dir)}", - "--background_stats", f"{str(self.stats_background_dir)}", - ] - logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}") - env = os.environ - if gpu_index is not None: - env["CUDA_VISIBLE_DEVICES"] = str(gpu_index) - result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True) - if result.returncode: - logger.error( - "Error with FAD computation from stats: \n %s \n %s", - result.stdout.decode(), result.stderr.decode() - ) - raise RuntimeError("Error while executing FAD computation from 
stats") - try: - # result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more - fad_score = float(result.stdout[4:]) - return fad_score - except Exception as e: - raise RuntimeError(f"Error parsing FAD score from command stdout: {e}") - - def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None: - beams_name = self._get_samples_name(is_background) - if returncode: - with open(log_file, "r") as f: - error_log = f.read() - logger.error(error_log) - os._exit(1) - else: - logger.info(f"Successfully computed embedding beams on {beams_name} samples.") - - def _parallel_create_embedding_beams(self, num_of_gpus: int): - assert num_of_gpus > 0 - logger.info("Creating embeddings beams in a parallel manner on different GPUs") - tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0) - bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1) - tests_beams_code = tests_beams_process.wait() - bg_beams_code = bg_beams_process.wait() - self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False) - self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True) - - def _sequential_create_embedding_beams(self): - logger.info("Creating embeddings beams in a sequential manner") - tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False) - tests_beams_code = tests_beams_process.wait() - self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False) - bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True) - bg_beams_code = bg_beams_process.wait() - self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True) - - @flashy.distrib.rank_zero_only - def _local_compute_frechet_audio_distance(self): - """Compute Frechet Audio Distance score calling TensorFlow API.""" - num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 - if num_of_gpus > 1: - self._parallel_create_embedding_beams(num_of_gpus) - else: - self._sequential_create_embedding_beams() - fad_score = self._compute_fad_score(gpu_index=0) - return fad_score - - def compute(self) -> float: - """Compute metrics.""" - assert self.total_files.item() > 0, "No files dumped for FAD computation!" # type: ignore - fad_score = self._local_compute_frechet_audio_distance() - logger.warning(f"FAD score = {fad_score}") - fad_score = flashy.distrib.broadcast_object(fad_score, src=0) - return fad_score diff --git a/audiocraft/audiocraft/metrics/kld.py b/audiocraft/audiocraft/metrics/kld.py deleted file mode 100644 index ebbbcda09b0419be4d51ae6698292ff7221e47e6..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/metrics/kld.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import contextlib -from functools import partial -import logging -import os -import typing as tp - -import torch -import torchmetrics - -from ..data.audio_utils import convert_audio - - -logger = logging.getLogger(__name__) - - -class _patch_passt_stft: - """Decorator to patch torch.stft in PaSST.""" - def __init__(self): - self.old_stft = torch.stft - - def __enter__(self): - # return_complex is a mandatory parameter in latest torch versions - # torch is throwing RuntimeErrors when not set - torch.stft = partial(torch.stft, return_complex=False) - - def __exit__(self, *exc): - torch.stft = self.old_stft - - -def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor: - """Computes the elementwise KL-Divergence loss between probability distributions - from generated samples and target samples. - - Args: - pred_probs (torch.Tensor): Probabilities for each label obtained - from a classifier on generated audio. Expected shape is [B, num_classes]. - target_probs (torch.Tensor): Probabilities for each label obtained - from a classifier on target audio. Expected shape is [B, num_classes]. - epsilon (float): Epsilon value. - Returns: - kld (torch.Tensor): KLD loss between each generated sample and target pair. - """ - kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none") - return kl_div.sum(-1) - - -class KLDivergenceMetric(torchmetrics.Metric): - """Base implementation for KL Divergence metric. - - The KL divergence is measured between probability distributions - of class predictions returned by a pre-trained audio classification model. - When the KL-divergence is low, the generated audio is expected to - have similar acoustic characteristics as the reference audio, - according to the classifier. - """ - def __init__(self): - super().__init__() - self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum") - self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum") - self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum") - self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum") - - def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor, - sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]: - """Get model output given provided input tensor. - - Args: - x (torch.Tensor): Input audio tensor of shape [B, C, T]. - sizes (torch.Tensor): Actual audio sample length, of shape [B]. - sample_rates (torch.Tensor): Actual audio sample rate, of shape [B]. - Returns: - probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes]. - """ - raise NotImplementedError("implement method to extract label distributions from the model.") - - def update(self, preds: torch.Tensor, targets: torch.Tensor, - sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: - """Calculates running KL-Divergence loss between batches of audio - preds (generated) and target (ground-truth) - Args: - preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T]. - targets (torch.Tensor): Target samples to compare against, of shape [B, C, T]. - sizes (torch.Tensor): Actual audio sample length, of shape [B]. - sample_rates (torch.Tensor): Actual audio sample rate, of shape [B]. 
- """ - assert preds.shape == targets.shape - assert preds.size(0) > 0, "Cannot update the loss with empty tensors" - preds_probs = self._get_label_distribution(preds, sizes, sample_rates) - targets_probs = self._get_label_distribution(targets, sizes, sample_rates) - if preds_probs is not None and targets_probs is not None: - assert preds_probs.shape == targets_probs.shape - kld_scores = kl_divergence(preds_probs, targets_probs) - assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!" - self.kld_pq_sum += torch.sum(kld_scores) - kld_qp_scores = kl_divergence(targets_probs, preds_probs) - self.kld_qp_sum += torch.sum(kld_qp_scores) - self.weight += torch.tensor(kld_scores.size(0)) - - def compute(self) -> dict: - """Computes KL-Divergence across all evaluated pred/target pairs.""" - weight: float = float(self.weight.item()) # type: ignore - assert weight > 0, "Unable to compute with total number of comparisons <= 0" - logger.info(f"Computing KL divergence on a total of {weight} samples") - kld_pq = self.kld_pq_sum.item() / weight # type: ignore - kld_qp = self.kld_qp_sum.item() / weight # type: ignore - kld_both = kld_pq + kld_qp - return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both} - - -class PasstKLDivergenceMetric(KLDivergenceMetric): - """KL-Divergence metric based on pre-trained PASST classifier on AudioSet. - - From: PaSST: Efficient Training of Audio Transformers with Patchout - Paper: https://arxiv.org/abs/2110.05069 - Implementation: https://github.com/kkoutini/PaSST - - Follow instructions from the github repo: - ``` - pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt' - ``` - - Args: - pretrained_length (float, optional): Audio duration used for the pretrained model. 
- """ - def __init__(self, pretrained_length: tp.Optional[float] = None): - super().__init__() - self._initialize_model(pretrained_length) - - def _initialize_model(self, pretrained_length: tp.Optional[float] = None): - """Initialize underlying PaSST audio classifier.""" - model, sr, max_frames, min_frames = self._load_base_model(pretrained_length) - self.min_input_frames = min_frames - self.max_input_frames = max_frames - self.model_sample_rate = sr - self.model = model - self.model.eval() - self.model.to(self.device) - - def _load_base_model(self, pretrained_length: tp.Optional[float]): - """Load pretrained model from PaSST.""" - try: - if pretrained_length == 30: - from hear21passt.base30sec import get_basic_model # type: ignore - max_duration = 30 - elif pretrained_length == 20: - from hear21passt.base20sec import get_basic_model # type: ignore - max_duration = 20 - else: - from hear21passt.base import get_basic_model # type: ignore - # Original PASST was trained on AudioSet with 10s-long audio samples - max_duration = 10 - min_duration = 0.15 - min_duration = 0.15 - except ModuleNotFoundError: - raise ModuleNotFoundError( - "Please install hear21passt to compute KL divergence: ", - "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'" - ) - model_sample_rate = 32_000 - max_input_frames = int(max_duration * model_sample_rate) - min_input_frames = int(min_duration * model_sample_rate) - with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f): - model = get_basic_model(mode='logits') - return model, model_sample_rate, max_input_frames, min_input_frames - - def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]: - """Process audio to feed to the pretrained model.""" - wav = wav.unsqueeze(0) - wav = wav[..., :wav_len] - wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1) - wav = wav.squeeze(0) - # we don't pad but return a list of audio segments as this otherwise affects the KLD computation - segments = torch.split(wav, self.max_input_frames, dim=-1) - valid_segments = [] - for s in segments: - # ignoring too small segments that are breaking the model inference - if s.size(-1) > self.min_input_frames: - valid_segments.append(s) - return [s[None] for s in valid_segments] - - def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor: - """Run the pretrained model and get the predictions.""" - assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}" - wav = wav.mean(dim=1) - # PaSST is printing a lot of garbage that we are not interested in - with open(os.devnull, "w") as f, contextlib.redirect_stdout(f): - with torch.no_grad(), _patch_passt_stft(): - logits = self.model(wav.to(self.device)) - probs = torch.softmax(logits, dim=-1) - return probs - - def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor, - sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]: - """Get model output given provided input tensor. - - Args: - x (torch.Tensor): Input audio tensor of shape [B, C, T]. - sizes (torch.Tensor): Actual audio sample length, of shape [B]. - sample_rates (torch.Tensor): Actual audio sample rate, of shape [B]. - Returns: - probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes]. 
- """ - all_probs: tp.List[torch.Tensor] = [] - for i, wav in enumerate(x): - sample_rate = int(sample_rates[i].item()) - wav_len = int(sizes[i].item()) - wav_segments = self._process_audio(wav, sample_rate, wav_len) - for segment in wav_segments: - probs = self._get_model_preds(segment).mean(dim=0) - all_probs.append(probs) - if len(all_probs) > 0: - return torch.stack(all_probs, dim=0) - else: - return None diff --git a/audiocraft/audiocraft/metrics/rvm.py b/audiocraft/audiocraft/metrics/rvm.py deleted file mode 100644 index 2047b6c8d5b1d58a67090b947e7e2666c3104eca..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/metrics/rvm.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import typing as tp -import torch -from torch import nn -import torchaudio - - -def db_to_scale(volume: tp.Union[float, torch.Tensor]): - return 10 ** (volume / 20) - - -def scale_to_db(scale: torch.Tensor, min_volume: float = -120): - min_scale = db_to_scale(min_volume) - return 20 * torch.log10(scale.clamp(min=min_scale)) - - -class RelativeVolumeMel(nn.Module): - """Relative volume melspectrogram measure. - - Computes a measure of distance over two mel spectrogram that is interpretable in terms - of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will - first renormalize both by the ground truth of `x_ref`. - - ..Warning:: This class returns the volume of the distortion at the spectrogram level, - e.g. low negative values reflects lower distortion levels. For a SNR (like reported - in the MultiBandDiffusion paper), just take `-rvm`. - - Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference - relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g. - clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`) - with the goal of avoiding the loss being dominated by parts where the reference is almost silent. - Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final - average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely - good (for a neural network output, although sound engineers typically aim for much lower attenuations). - Similarly, anything above +30 dB would just be completely missing the target, and there is no point - in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more - in line with what neural nets currently can achieve. - - For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between - the target and reference mel-spec is 10 dB lower than the reference mel-spec value. - - The metric can be aggregated over a given frequency band in order have different insights for - different region of the spectrum. `num_aggregated_bands` controls the number of bands. - - ..Warning:: While this function is optimized for interpretability, nothing was done to ensure it - is numerically stable when computing its gradient. We thus advise against using it as a training loss. - - Args: - sample_rate (int): Sample rate of the input audio. - n_mels (int): Number of mel bands to use. - n_fft (int): Number of frequency bins for the STFT. 
- hop_length (int): Hop length of the STFT and the mel-spectrogram. - min_relative_volume (float): The error `z_ref - z_est` volume is given relative to - the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped. - max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that. - max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain - to that amount, to avoid rescaling near silence. Given in dB. - min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume - bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram, - and anything below that will be considered equally. - num_aggregated_bands (int): Number of bands to keep when computing the average RVM value. - For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs. - """ - def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512, - hop_length: int = 128, min_relative_volume: float = -25, - max_relative_volume: float = 25, max_initial_gain: float = 25, - min_activity_volume: float = -25, - num_aggregated_bands: int = 4) -> None: - super().__init__() - self.melspec = torchaudio.transforms.MelSpectrogram( - n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, - normalized=True, sample_rate=sample_rate, power=2) - self.min_relative_volume = min_relative_volume - self.max_relative_volume = max_relative_volume - self.max_initial_gain = max_initial_gain - self.min_activity_volume = min_activity_volume - self.num_aggregated_bands = num_aggregated_bands - - def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]: - """Compute RVM metric between estimate and reference samples. - - Args: - estimate (torch.Tensor): Estimate sample. - ground_truth (torch.Tensor): Reference sample. - - Returns: - dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}` - for the RVM over the k-th band (k=0..num_aggregated_bands - 1). - """ - min_scale = db_to_scale(-self.max_initial_gain) - std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale) - z_gt = self.melspec(ground_truth / std).sqrt() - z_est = self.melspec(estimate / std).sqrt() - - delta = z_gt - z_est - ref_db = scale_to_db(z_gt, self.min_activity_volume) - delta_db = scale_to_db(delta.abs(), min_volume=-120) - relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume) - dims = list(range(relative_db.dim())) - dims.remove(dims[-2]) - losses_per_band = relative_db.mean(dim=dims) - aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)] - metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)} - metrics['rvm'] = losses_per_band.mean() - return metrics diff --git a/audiocraft/audiocraft/metrics/visqol.py b/audiocraft/audiocraft/metrics/visqol.py deleted file mode 100644 index 44f4b0a2c3c6c726857db8386491823dd85dde51..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/metrics/visqol.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
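A short usage sketch of the RelativeVolumeMel module deleted above (the import path is taken from the file path in this diff; the random signals are only illustrative):
```
import torch
from audiocraft.metrics.rvm import RelativeVolumeMel

rvm = RelativeVolumeMel(sample_rate=24000, num_aggregated_bands=4)

ground_truth = torch.randn(1, 1, 24000)                          # 1 second of reference audio
estimate = ground_truth + 0.01 * torch.randn_like(ground_truth)  # lightly distorted copy

metrics = rvm(estimate, ground_truth)
# 'rvm' is the average relative error volume in dB (more negative = less distortion);
# 'rvm_0'..'rvm_3' aggregate it over 4 mel-frequency bands, from low to high frequencies.
print({name: round(float(value), 2) for name, value in metrics.items()})
```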
- -import csv -import json -import logging -from pathlib import Path -import tempfile -import typing as tp -import subprocess -import shutil - -import torch -import torchaudio - -logger = logging.getLogger(__name__) - - -class ViSQOL: - """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary. - - To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the - instructions available in the open source repository: https://github.com/google/visqol - - ViSQOL is capable of running in two modes: - - Audio Mode: - When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz. - Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison. - Audio mode uses support vector regression, with the maximum range at ~4.75. - - Speech Mode: - When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz. - Input should be resampled to 16kHz. - As part of the speech mode processing, a root mean square implementation for voice activity detection - is performed on the reference signal to determine what parts of the signal have voice activity and - should therefore be included in the comparison. The signal is normalized before performing the voice - activity detection. - Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison. - Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior. - - For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input - - Args: - visqol_bin (str): Path to the ViSQOL binary. - mode (str): ViSQOL computation mode, expecting "audio" or "speech". - model (str): Name of the model to use for similarity to quality model. - debug (bool): Whether to also get debug metrics from ViSQOL or not. - """ - SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000} - ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values()) - - def __init__(self, bin: tp.Union[Path, str], mode: str = "audio", - model: str = "libsvm_nu_svr_model.txt", debug: bool = False): - assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}" - self.visqol_bin = str(bin) - self.visqol_mode = mode - self.target_sr = self._get_target_sr(self.visqol_mode) - self.model = model - self.debug = debug - assert Path(self.visqol_model).exists(), \ - f"Could not find the specified model in ViSQOL install: {self.visqol_model}" - - def _get_target_sr(self, mode: str) -> int: - # returns target sampling rate for the corresponding ViSQOL mode. - if mode not in ViSQOL.SAMPLE_RATES_MODES: - raise ValueError( - f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}" - ) - return ViSQOL.SAMPLE_RATES_MODES[mode] - - def _prepare_files( - self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False - ): - # prepare files for ViSQOL evaluation. 
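A hedged usage sketch of the ViSQOL wrapper above; the binary path is a placeholder and the constructor asserts that both the bazel-built `bazel-bin/visqol` binary and the bundled model file exist:
```
import torch
from audiocraft.metrics.visqol import ViSQOL

# Placeholder path to a built ViSQOL checkout.
visqol = ViSQOL(bin="/path/to/visqol", mode="audio", model="libsvm_nu_svr_model.txt")

sr = 32000
ref = torch.randn(2, 1, sr)                # [B, C, T] reference signals
deg = ref + 0.05 * torch.randn_like(ref)   # degraded versions of the same signals
score = visqol(ref, deg, sr=sr, pad_with_silence=True)  # resampled internally to 48 kHz in audio mode
print(score)  # mean MOS-LQO over the batch (audio mode tops out around ~4.75)
```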
- assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES - assert len(ref_sig) == len(deg_sig), ( - "Expects same number of ref and degraded inputs", - f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}" - ) - # resample audio if needed - if sr != target_sr: - transform = torchaudio.transforms.Resample(sr, target_sr) - pad = int(0.5 * target_sr) - rs_ref = [] - rs_deg = [] - for i in range(len(ref_sig)): - rs_ref_i = transform(ref_sig[i]) - rs_deg_i = transform(deg_sig[i]) - if pad_with_silence: - rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0) - rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0) - rs_ref.append(rs_ref_i) - rs_deg.append(rs_deg_i) - ref_sig = torch.stack(rs_ref) - deg_sig = torch.stack(rs_deg) - # save audio chunks to tmp dir and create csv - tmp_dir = Path(tempfile.mkdtemp()) - try: - tmp_input_csv_path = tmp_dir / "input.csv" - tmp_results_csv_path = tmp_dir / "results.csv" - tmp_debug_json_path = tmp_dir / "debug.json" - with open(tmp_input_csv_path, "w") as csv_file: - csv_writer = csv.writer(csv_file) - csv_writer.writerow(["reference", "degraded"]) - for i in range(len(ref_sig)): - tmp_ref_filename = tmp_dir / f"ref_{i}.wav" - tmp_deg_filename = tmp_dir / f"deg_{i}.wav" - torchaudio.save( - tmp_ref_filename, - torch.clamp(ref_sig[i], min=-0.99, max=0.99), - sample_rate=target_sr, - bits_per_sample=16, - encoding="PCM_S" - ) - torchaudio.save( - tmp_deg_filename, - torch.clamp(deg_sig[i], min=-0.99, max=0.99), - sample_rate=target_sr, - bits_per_sample=16, - encoding="PCM_S" - ) - csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)]) - return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path - except Exception as e: - logger.error("Exception occurred when preparing files for ViSQOL: %s", e) - return tmp_dir, None, None, None - - def _flush_files(self, tmp_dir: tp.Union[Path, str]): - # flush tmp files used to compute ViSQOL. - shutil.rmtree(str(tmp_dir)) - - def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float: - # collect results for each evaluated pair and return averaged moslqo score. - with open(results_csv_path, "r") as csv_file: - reader = csv.DictReader(csv_file) - moslqo_scores = [float(row["moslqo"]) for row in reader] - if len(moslqo_scores) > 0: - return sum(moslqo_scores) / len(moslqo_scores) - else: - return 0.0 - - def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict: - # collect debug data for the visqol inference. 
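For clarity, `_collect_moslqo_score` below simply averages the `moslqo` column of the results CSV written by the ViSQOL binary; a toy illustration on hypothetical output (the extra columns are assumptions, only `moslqo` is read):
```
import csv
import io

results_csv = io.StringIO(
    "reference,degraded,moslqo\n"
    "/tmp/visqol/ref_0.wav,/tmp/visqol/deg_0.wav,4.21\n"
    "/tmp/visqol/ref_1.wav,/tmp/visqol/deg_1.wav,3.87\n"
)
scores = [float(row["moslqo"]) for row in csv.DictReader(results_csv)]
print(sum(scores) / len(scores))  # 4.04
```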
- with open(debug_json_path, "r") as f: - data = json.load(f) - return data - - @property - def visqol_model(self): - return f'{self.visqol_bin}/model/{self.model}' - - def _run_visqol( - self, - input_csv_path: tp.Union[Path, str], - results_csv_path: tp.Union[Path, str], - debug_csv_path: tp.Optional[tp.Union[Path, str]], - ): - input_csv_path = str(input_csv_path) - results_csv_path = str(results_csv_path) - debug_csv_path = str(debug_csv_path) - cmd = [ - f'{self.visqol_bin}/bazel-bin/visqol', - '--batch_input_csv', f'{input_csv_path}', - '--results_csv', f'{results_csv_path}' - ] - if debug_csv_path is not None: - cmd += ['--output_debug', f'{debug_csv_path}'] - if self.visqol_mode == "speech": - cmd += ['--use_speech_mode'] - cmd += ['--similarity_to_quality_model', f'{self.visqol_model}'] - result = subprocess.run(cmd, capture_output=True) - if result.returncode: - logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode()) - raise RuntimeError("Error while executing visqol") - result.check_returncode() - - def __call__( - self, - ref_sig: torch.Tensor, - deg_sig: torch.Tensor, - sr: int, - pad_with_silence: bool = False, - ): - """Calculate the ViSQOL metric for a pair of audio signals at a given sample rate. - Args: - ref_sig (torch.Tensor): Reference signals as [B, C, T]. - deg_sig (torch.Tensor): Degraded signals as [B, C, T]. - sr (int): Sample rate of the two audio signals. - pad_with_silence (bool): Whether to pad the file with silences as recommended - in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input). - Returns: - float: The ViSQOL score or mean score for the batch. - """ - logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples") - tmp_dir, input_csv, results_csv, debug_json = self._prepare_files( - ref_sig, deg_sig, sr, self.target_sr, pad_with_silence - ) - try: - if input_csv and results_csv: - self._run_visqol( - input_csv, - results_csv, - debug_json if self.debug else None, - ) - mosqol = self._collect_moslqo_score(results_csv) - return mosqol - else: - raise RuntimeError("Something unexpected happened when running VISQOL!") - except Exception as e: - logger.error("Exception occurred when running ViSQOL: %s", e) - finally: - self._flush_files(tmp_dir) diff --git a/audiocraft/audiocraft/models/__init__.py b/audiocraft/audiocraft/models/__init__.py deleted file mode 100644 index be6bfe4b787a132aeaabaed1c3437c9ecd5c656c..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/models/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -""" -Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel. -""" -# flake8: noqa -from . 
import builders, loaders -from .encodec import ( - CompressionModel, EncodecModel, DAC, - HFEncodecModel, HFEncodecCompressionModel) -from .audiogen import AudioGen -from .lm import LMModel -from .multibanddiffusion import MultiBandDiffusion -from .musicgen import MusicGen -from .unet import DiffusionUnet diff --git a/audiocraft/audiocraft/models/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/models/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index da36453ed3573a6e1d198357130c356270ee5206..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/models/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/models/__pycache__/audiogen.cpython-311.pyc b/audiocraft/audiocraft/models/__pycache__/audiogen.cpython-311.pyc deleted file mode 100644 index 75d6a4610579722296550c8a17c602ffb1d7ddfc..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/models/__pycache__/audiogen.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/models/__pycache__/builders.cpython-311.pyc b/audiocraft/audiocraft/models/__pycache__/builders.cpython-311.pyc deleted file mode 100644 index d9a2220766f6586ceb09135edeb28733d9772274..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/models/__pycache__/builders.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/models/__pycache__/encodec.cpython-311.pyc b/audiocraft/audiocraft/models/__pycache__/encodec.cpython-311.pyc deleted file mode 100644 index 5b9ed907232b76840c1b299ea5fbcbb7da100d3b..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/models/__pycache__/encodec.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/models/__pycache__/lm.cpython-311.pyc b/audiocraft/audiocraft/models/__pycache__/lm.cpython-311.pyc deleted file mode 100644 index 571c3468eb4c452f0cb4cf4d7fd6bee703743b11..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/models/__pycache__/lm.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/models/__pycache__/loaders.cpython-311.pyc b/audiocraft/audiocraft/models/__pycache__/loaders.cpython-311.pyc deleted file mode 100644 index 002244067cbc34f0467a773d0c6888fd657902bc..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/models/__pycache__/loaders.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/models/__pycache__/multibanddiffusion.cpython-311.pyc b/audiocraft/audiocraft/models/__pycache__/multibanddiffusion.cpython-311.pyc deleted file mode 100644 index f7f19ed26afe9b48bdbc18db92e27295e5f9e4ed..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/models/__pycache__/multibanddiffusion.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/models/__pycache__/musicgen.cpython-311.pyc b/audiocraft/audiocraft/models/__pycache__/musicgen.cpython-311.pyc deleted file mode 100644 index 7647d227c35c714e64488c4f20ff473094112d4e..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/models/__pycache__/musicgen.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/models/__pycache__/unet.cpython-311.pyc b/audiocraft/audiocraft/models/__pycache__/unet.cpython-311.pyc deleted file mode 100644 index ee74b3c24f03f7c31cef168b98e97c0d56008c04..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/models/__pycache__/unet.cpython-311.pyc and 
/dev/null differ diff --git a/audiocraft/audiocraft/models/audiogen.py b/audiocraft/audiocraft/models/audiogen.py deleted file mode 100644 index 5cb889982ddc027e2588b7cfb8ef428b313ce88a..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/models/audiogen.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Main model for using AudioGen. This will combine all the required components -and provide easy access to the generation API. -""" - -import typing as tp - -import torch - -from .encodec import CompressionModel -from .lm import LMModel -from .builders import get_debug_compression_model, get_debug_lm_model -from .loaders import load_compression_model, load_lm_model -from ..data.audio_utils import convert_audio -from ..modules.conditioners import ConditioningAttributes -from ..utils.autocast import TorchAutocast - - -class AudioGen: - """AudioGen main model with convenient generation API. - - Args: - name (str): name of the model. - compression_model (CompressionModel): Compression model - used to map audio to invertible discrete representations. - lm (LMModel): Language model over discrete representations. - max_duration (float, optional): maximum duration the model can produce, - otherwise, inferred from the training params. - """ - def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, - max_duration: tp.Optional[float] = None): - self.name = name - self.compression_model = compression_model - self.lm = lm - if max_duration is None: - if hasattr(lm, 'cfg'): - max_duration = lm.cfg.dataset.segment_duration # type: ignore - else: - raise ValueError("You must provide max_duration when building directly AudioGen") - assert max_duration is not None - self.max_duration: float = max_duration - self.device = next(iter(lm.parameters())).device - self.generation_params: dict = {} - self.set_generation_params(duration=5) # 5 seconds by default - self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None - if self.device.type == 'cpu': - self.autocast = TorchAutocast(enabled=False) - else: - self.autocast = TorchAutocast( - enabled=True, device_type=self.device.type, dtype=torch.float16) - - @property - def frame_rate(self) -> float: - """Roughly the number of AR steps per seconds.""" - return self.compression_model.frame_rate - - @property - def sample_rate(self) -> int: - """Sample rate of the generated audio.""" - return self.compression_model.sample_rate - - @property - def audio_channels(self) -> int: - """Audio channels of the generated audio.""" - return self.compression_model.channels - - @staticmethod - def get_pretrained(name: str = 'facebook/audiogen-medium', device=None): - """Return pretrained model, we provide a single model for now: - - facebook/audiogen-medium (1.5B), text to sound, - # see: https://huggingface.co/facebook/audiogen-medium - """ - if device is None: - if torch.cuda.device_count(): - device = 'cuda' - else: - device = 'cpu' - - if name == 'debug': - # used only for unit tests - compression_model = get_debug_compression_model(device, sample_rate=16000) - lm = get_debug_lm_model(device) - return AudioGen(name, compression_model, lm, max_duration=10) - - compression_model = load_compression_model(name, device=device) - lm = load_lm_model(name, device=device) - assert 'self_wav' not in lm.condition_provider.conditioners, \ - 
"AudioGen do not support waveform conditioning for now" - return AudioGen(name, compression_model, lm) - - def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, - top_p: float = 0.0, temperature: float = 1.0, - duration: float = 10.0, cfg_coef: float = 3.0, - two_step_cfg: bool = False, extend_stride: float = 2): - """Set the generation parameters for AudioGen. - - Args: - use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True. - top_k (int, optional): top_k used for sampling. Defaults to 250. - top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0. - temperature (float, optional): Softmax temperature parameter. Defaults to 1.0. - duration (float, optional): Duration of the generated waveform. Defaults to 10.0. - cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0. - two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance, - instead of batching together the two. This has some impact on how things - are padded but seems to have little impact in practice. - extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much - should we extend the audio each time. Larger values will mean less context is - preserved, and shorter value will require extra computations. - """ - assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration." - self.extend_stride = extend_stride - self.duration = duration - self.generation_params = { - 'use_sampling': use_sampling, - 'temp': temperature, - 'top_k': top_k, - 'top_p': top_p, - 'cfg_coef': cfg_coef, - 'two_step_cfg': two_step_cfg, - } - - def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None): - """Override the default progress callback.""" - self._progress_callback = progress_callback - - def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor: - """Generate samples conditioned on text. - - Args: - descriptions (list of str): A list of strings used as text conditioning. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - """ - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) - assert prompt_tokens is None - return self._generate_tokens(attributes, prompt_tokens, progress) - - def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, - descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None, - progress: bool = False) -> torch.Tensor: - """Generate samples conditioned on audio prompts. - - Args: - prompt (torch.Tensor): A batch of waveforms used for continuation. - Prompt should be [B, C, T], or [C, T] if only one sample is generated. - prompt_sample_rate (int): Sampling rate of the given audio waveforms. - descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. 
- """ - if prompt.dim() == 2: - prompt = prompt[None] - if prompt.dim() != 3: - raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).") - prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels) - if descriptions is None: - descriptions = [None] * len(prompt) - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) - assert prompt_tokens is not None - return self._generate_tokens(attributes, prompt_tokens, progress) - - @torch.no_grad() - def _prepare_tokens_and_attributes( - self, - descriptions: tp.Sequence[tp.Optional[str]], - prompt: tp.Optional[torch.Tensor], - ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]: - """Prepare model inputs. - - Args: - descriptions (list of str): A list of strings used as text conditioning. - prompt (torch.Tensor): A batch of waveforms used for continuation. - """ - attributes = [ - ConditioningAttributes(text={'description': description}) - for description in descriptions] - - if prompt is not None: - if descriptions is not None: - assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match" - prompt = prompt.to(self.device) - prompt_tokens, scale = self.compression_model.encode(prompt) - assert scale is None - else: - prompt_tokens = None - return attributes, prompt_tokens - - def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], - prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor: - """Generate discrete audio tokens given audio prompt and/or conditions. - - Args: - attributes (list of ConditioningAttributes): Conditions used for generation (here text). - prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - Returns: - torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params. - """ - total_gen_len = int(self.duration * self.frame_rate) - max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate) - current_gen_offset: int = 0 - - def _progress_callback(generated_tokens: int, tokens_to_generate: int): - generated_tokens += current_gen_offset - if self._progress_callback is not None: - # Note that total_gen_len might be quite wrong depending on the - # codebook pattern used, but with delay it is almost accurate. - self._progress_callback(generated_tokens, total_gen_len) - else: - print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r') - - if prompt_tokens is not None: - assert max_prompt_len >= prompt_tokens.shape[-1], \ - "Prompt is longer than audio to generate" - - callback = None - if progress: - callback = _progress_callback - - if self.duration <= self.max_duration: - # generate by sampling from LM, simple case. 
- with self.autocast: - gen_tokens = self.lm.generate( - prompt_tokens, attributes, - callback=callback, max_gen_len=total_gen_len, **self.generation_params) - - else: - all_tokens = [] - if prompt_tokens is None: - prompt_length = 0 - else: - all_tokens.append(prompt_tokens) - prompt_length = prompt_tokens.shape[-1] - - stride_tokens = int(self.frame_rate * self.extend_stride) - while current_gen_offset + prompt_length < total_gen_len: - time_offset = current_gen_offset / self.frame_rate - chunk_duration = min(self.duration - time_offset, self.max_duration) - max_gen_len = int(chunk_duration * self.frame_rate) - with self.autocast: - gen_tokens = self.lm.generate( - prompt_tokens, attributes, - callback=callback, max_gen_len=max_gen_len, **self.generation_params) - if prompt_tokens is None: - all_tokens.append(gen_tokens) - else: - all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:]) - prompt_tokens = gen_tokens[:, :, stride_tokens:] - prompt_length = prompt_tokens.shape[-1] - current_gen_offset += stride_tokens - - gen_tokens = torch.cat(all_tokens, dim=-1) - - # generate audio - assert gen_tokens.dim() == 3 - with torch.no_grad(): - gen_audio = self.compression_model.decode(gen_tokens, None) - return gen_audio diff --git a/audiocraft/audiocraft/models/builders.py b/audiocraft/audiocraft/models/builders.py deleted file mode 100644 index 2a427bc4f4a1925501d9eee54429e3f72eedb7f9..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/models/builders.py +++ /dev/null @@ -1,267 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -All the functions to build the relevant models and modules -from the Hydra config. -""" - -import typing as tp - -import audiocraft -import omegaconf -import torch - -from .encodec import CompressionModel, EncodecModel -from .lm import LMModel -from ..modules.codebooks_patterns import ( - CodebooksPatternProvider, - DelayedPatternProvider, - MusicLMPattern, - ParallelPatternProvider, - UnrolledPatternProvider, - VALLEPattern, -) -from ..modules.conditioners import ( - BaseConditioner, - ChromaStemConditioner, - CLAPEmbeddingConditioner, - ConditionFuser, - ConditioningProvider, - LUTConditioner, - T5Conditioner, - ChordProgressionConditioner, - BeatConditioner -) -from .unet import DiffusionUnet -from .. 
import quantization as qt -from ..utils.utils import dict_from_config -from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor - - -def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer: - klass = { - 'no_quant': qt.DummyQuantizer, - 'rvq': qt.ResidualVectorQuantizer - }[quantizer] - kwargs = dict_from_config(getattr(cfg, quantizer)) - if quantizer != 'no_quant': - kwargs['dimension'] = dimension - return klass(**kwargs) - - -def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): - if encoder_name == 'seanet': - kwargs = dict_from_config(getattr(cfg, 'seanet')) - encoder_override_kwargs = kwargs.pop('encoder') - decoder_override_kwargs = kwargs.pop('decoder') - encoder_kwargs = {**kwargs, **encoder_override_kwargs} - decoder_kwargs = {**kwargs, **decoder_override_kwargs} - encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs) - decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs) - return encoder, decoder - else: - raise KeyError(f"Unexpected compression model {cfg.compression_model}") - - -def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: - """Instantiate a compression model.""" - if cfg.compression_model == 'encodec': - kwargs = dict_from_config(getattr(cfg, 'encodec')) - encoder_name = kwargs.pop('autoencoder') - quantizer_name = kwargs.pop('quantizer') - encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) - quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) - frame_rate = kwargs['sample_rate'] // encoder.hop_length - renormalize = kwargs.pop('renormalize', False) - # deprecated params - kwargs.pop('renorm', None) - return EncodecModel(encoder, decoder, quantizer, - frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) - else: - raise KeyError(f"Unexpected compression model {cfg.compression_model}") - - -def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: - """Instantiate a transformer LM.""" - if cfg.lm_model == 'transformer_lm': - kwargs = dict_from_config(getattr(cfg, 'transformer_lm')) - n_q = kwargs['n_q'] - q_modeling = kwargs.pop('q_modeling', None) - codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern') - attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout')) - cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance')) - cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef'] - fuser = get_condition_fuser(cfg) - condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device) - if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically - kwargs['cross_attention'] = True - if codebooks_pattern_cfg.modeling is None: - assert q_modeling is not None, \ - "LM model should either have a codebook pattern defined or transformer_lm.q_modeling" - codebooks_pattern_cfg = omegaconf.OmegaConf.create( - {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}} - ) - pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg) - return LMModel( - pattern_provider=pattern_provider, - condition_provider=condition_provider, - fuser=fuser, - cfg_dropout=cfg_prob, - cfg_coef=cfg_coef, - attribute_dropout=attribute_dropout, - dtype=getattr(torch, cfg.dtype), - device=cfg.device, - **kwargs - ).to(cfg.device) - else: - raise KeyError(f"Unexpected LM model {cfg.lm_model}") - - -def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider: - """Instantiate a 
conditioning model.""" - device = cfg.device - duration = cfg.dataset.segment_duration - cfg = getattr(cfg, 'conditioners') - dict_cfg = {} if cfg is None else dict_from_config(cfg) - conditioners: tp.Dict[str, BaseConditioner] = {} - condition_provider_args = dict_cfg.pop('args', {}) - condition_provider_args.pop('merge_text_conditions_p', None) - condition_provider_args.pop('drop_desc_p', None) - - for cond, cond_cfg in dict_cfg.items(): - model_type = cond_cfg['model'] - model_args = cond_cfg[model_type] - if model_type == 't5': - conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args) - elif model_type == 'lut': - conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args) - elif model_type == 'chroma_stem': - conditioners[str(cond)] = ChromaStemConditioner( - output_dim=output_dim, - duration=duration, - device=device, - **model_args - ) - elif model_type == 'beat': - conditioners[str(cond)] = BeatConditioner( - output_dim=output_dim, - device=device, - **model_args - ) - elif model_type == 'chord': - conditioners[str(cond)] = ChordProgressionConditioner( - output_dim=output_dim, - device=device, - **model_args - ) - elif model_type == 'clap': - conditioners[str(cond)] = CLAPEmbeddingConditioner( - output_dim=output_dim, - device=device, - **model_args - ) - else: - raise ValueError(f"Unrecognized conditioning model: {model_type}") - conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args) - return conditioner - - -def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: - """Instantiate a condition fuser object.""" - fuser_cfg = getattr(cfg, 'fuser') - fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate'] - fuse2cond = {k: fuser_cfg[k] for k in fuser_methods} - kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods} - print(f"==== use in-attention: {fuser_cfg['in_attn']} ====") - fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs) - return fuser - - -def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider: - """Instantiate a codebooks pattern provider object.""" - pattern_providers = { - 'parallel': ParallelPatternProvider, - 'delay': DelayedPatternProvider, - 'unroll': UnrolledPatternProvider, - 'valle': VALLEPattern, - 'musiclm': MusicLMPattern, - } - name = cfg.modeling - kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {} - klass = pattern_providers[name] - return klass(n_q, **kwargs) - - -def get_debug_compression_model(device='cpu', sample_rate: int = 32000): - """Instantiate a debug compression model to be used for unit tests.""" - assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model" - model_ratios = { - 16000: [10, 8, 8], # 25 Hz at 16kHz - 32000: [10, 8, 16] # 25 Hz at 32kHz - } - ratios: tp.List[int] = model_ratios[sample_rate] - frame_rate = 25 - seanet_kwargs: dict = { - 'n_filters': 4, - 'n_residual_layers': 1, - 'dimension': 32, - 'ratios': ratios, - } - print(seanet_kwargs) - encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs) - decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs) - quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4) - init_x = torch.randn(8, 32, 128) - quantizer(init_x, 1) # initialize kmeans etc. 
- compression_model = EncodecModel( - encoder, decoder, quantizer, - frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device) - return compression_model.eval() - - -def get_diffusion_model(cfg: omegaconf.DictConfig): - # TODO Find a way to infer the channels from dset - channels = cfg.channels - num_steps = cfg.schedule.num_steps - return DiffusionUnet( - chin=channels, num_steps=num_steps, **cfg.diffusion_unet) - - -def get_processor(cfg, sample_rate: int = 24000): - sample_processor = SampleProcessor() - if cfg.use: - kw = dict(cfg) - kw.pop('use') - kw.pop('name') - if cfg.name == "multi_band_processor": - sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw) - return sample_processor - - -def get_debug_lm_model(device='cpu'): - """Instantiate a debug LM to be used for unit tests.""" - pattern = DelayedPatternProvider(n_q=4) - dim = 16 - providers = { - 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"), - } - condition_provider = ConditioningProvider(providers) - fuser = ConditionFuser( - {'cross': ['description'], 'prepend': [], - 'sum': [], 'input_interpolate': []}) - lm = LMModel( - pattern, condition_provider, fuser, - n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2, - cross_attention=True, causal=True) - return lm.to(device).eval() - - -def get_wrapped_compression_model( - compression_model: CompressionModel, - cfg: omegaconf.DictConfig) -> CompressionModel: - # more to come. - return compression_model diff --git a/audiocraft/audiocraft/models/encodec.py b/audiocraft/audiocraft/models/encodec.py deleted file mode 100644 index 40d133017c0a0eddaafb07d291b3845789775bc3..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/models/encodec.py +++ /dev/null @@ -1,393 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Compression models or wrapper around existing models. -Also defines the main interface that a model must follow to be usable as an audio tokenizer. -""" - -from abc import ABC, abstractmethod -import logging -import math -from pathlib import Path -import typing as tp - -import numpy as np -import torch -from torch import nn -from transformers import EncodecModel as HFEncodecModel - -from .. import quantization as qt - - -logger = logging.getLogger() - - -class CompressionModel(ABC, nn.Module): - """Base API for all compression model that aim at being used as audio tokenizers - with a language model. - """ - - @abstractmethod - def forward(self, x: torch.Tensor) -> qt.QuantizedResult: - ... - - @abstractmethod - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: - """See `EncodecModel.encode`.""" - ... - - @abstractmethod - def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): - """See `EncodecModel.decode`.""" - ... - - @abstractmethod - def decode_latent(self, codes: torch.Tensor): - """Decode from the discrete codes to continuous latent space.""" - ... - - @property - @abstractmethod - def channels(self) -> int: - ... - - @property - @abstractmethod - def frame_rate(self) -> float: - ... - - @property - @abstractmethod - def sample_rate(self) -> int: - ... - - @property - @abstractmethod - def cardinality(self) -> int: - ... - - @property - @abstractmethod - def num_codebooks(self) -> int: - ... 
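The debug builders above are handy for a quick, download-free smoke test of the compression pipeline; a minimal sketch:
```
import torch
from audiocraft.models.builders import get_debug_compression_model

model = get_debug_compression_model(device='cpu', sample_rate=32000)  # tiny SEANet + RVQ at 25 Hz
wav = torch.randn(2, 1, 32000)              # [B, C, T]: one second of random mono audio
with torch.no_grad():
    codes, scale = model.encode(wav)        # codes: [2, 4, 25] (n_q=4 codebooks), scale: None
    recon = model.decode(codes, scale)      # [2, 1, T'] with T' >= 32000 (decode keeps padding)
print(codes.shape, recon.shape)
```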
- - @property - @abstractmethod - def total_codebooks(self) -> int: - ... - - @abstractmethod - def set_num_codebooks(self, n: int): - """Set the active number of codebooks used by the quantizer.""" - ... - - @staticmethod - def get_pretrained( - name: str, device: tp.Union[torch.device, str] = 'cpu' - ) -> 'CompressionModel': - """Instantiate a CompressionModel from a given pretrained model. - - Args: - name (Path or str): name of the pretrained model. See after. - device (torch.device or str): Device on which the model is loaded. - - Pretrained models: - - dac_44khz (https://github.com/descriptinc/descript-audio-codec) - - dac_24khz (same) - - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz) - - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz) - - your own model on HugginFace. Export instructions to come... - """ - - from . import builders, loaders - model: CompressionModel - if name in ['dac_44khz', 'dac_24khz']: - model_type = name.split('_')[1] - logger.info("Getting pretrained compression model from DAC %s", model_type) - model = DAC(model_type) - elif name in ['debug_compression_model']: - logger.info("Getting pretrained compression model for debug") - model = builders.get_debug_compression_model() - elif Path(name).exists(): - # We assume here if the paths exist that it is in fact an AC checkpoint - # that was exported using `audiocraft.utils.export` functions. - model = loaders.load_compression_model(name, device=device) - else: - logger.info("Getting pretrained compression model from HF %s", name) - hf_model = HFEncodecModel.from_pretrained(name) - model = HFEncodecCompressionModel(hf_model).to(device) - return model.to(device).eval() - - -class EncodecModel(CompressionModel): - """Encodec model operating on the raw waveform. - - Args: - encoder (nn.Module): Encoder network. - decoder (nn.Module): Decoder network. - quantizer (qt.BaseQuantizer): Quantizer network. - frame_rate (int): Frame rate for the latent representation. - sample_rate (int): Audio sample rate. - channels (int): Number of audio channels. - causal (bool): Whether to use a causal version of the model. - renormalize (bool): Whether to renormalize the audio before running the model. - """ - # we need assignment to override the property in the abstract class, - # I couldn't find a better way... - frame_rate: float = 0 - sample_rate: int = 0 - channels: int = 0 - - def __init__(self, - encoder: nn.Module, - decoder: nn.Module, - quantizer: qt.BaseQuantizer, - frame_rate: int, - sample_rate: int, - channels: int, - causal: bool = False, - renormalize: bool = False): - super().__init__() - self.encoder = encoder - self.decoder = decoder - self.quantizer = quantizer - self.frame_rate = frame_rate - self.sample_rate = sample_rate - self.channels = channels - self.renormalize = renormalize - self.causal = causal - if self.causal: - # we force disabling here to avoid handling linear overlap of segments - # as supported in original EnCodec codebase. 
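`CompressionModel.get_pretrained` above dispatches on the name: DAC checkpoints, the debug model, a local AudioCraft export, or a HuggingFace checkpoint. A hedged sketch with an EnCodec model from the Hub (weights are downloaded on first use; the printed shapes assume the 32 kHz, 4-codebook configuration):
```
import torch
from audiocraft.models.encodec import CompressionModel

model = CompressionModel.get_pretrained('facebook/encodec_32khz', device='cpu')
wav = torch.randn(1, model.channels, 2 * model.sample_rate)   # two seconds of audio
with torch.no_grad():
    codes, scale = model.encode(wav)
print(model.frame_rate, model.num_codebooks, codes.shape)     # 50.0, 4, torch.Size([1, 4, 100])
```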
- assert not self.renormalize, 'Causal model does not support renormalize' - - @property - def total_codebooks(self): - """Total number of quantizer codebooks available.""" - return self.quantizer.total_codebooks - - @property - def num_codebooks(self): - """Active number of codebooks used by the quantizer.""" - return self.quantizer.num_codebooks - - def set_num_codebooks(self, n: int): - """Set the active number of codebooks used by the quantizer.""" - self.quantizer.set_num_codebooks(n) - - @property - def cardinality(self): - """Cardinality of each codebook.""" - return self.quantizer.bins - - def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: - scale: tp.Optional[torch.Tensor] - if self.renormalize: - mono = x.mean(dim=1, keepdim=True) - volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() - scale = 1e-8 + volume - x = x / scale - scale = scale.view(-1, 1) - else: - scale = None - return x, scale - - def postprocess(self, - x: torch.Tensor, - scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: - if scale is not None: - assert self.renormalize - x = x * scale.view(-1, 1, 1) - return x - - def forward(self, x: torch.Tensor) -> qt.QuantizedResult: - assert x.dim() == 3 - length = x.shape[-1] - x, scale = self.preprocess(x) - - emb = self.encoder(x) - q_res = self.quantizer(emb, self.frame_rate) - out = self.decoder(q_res.x) - - # remove extra padding added by the encoder and decoder - assert out.shape[-1] >= length, (out.shape[-1], length) - out = out[..., :length] - - q_res.x = self.postprocess(out, scale) - - return q_res - - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: - """Encode the given input tensor to quantized representation along with scale parameter. - - Args: - x (torch.Tensor): Float tensor of shape [B, C, T] - - Returns: - codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of: - codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. - scale a float tensor containing the scale for audio renormalizealization. - """ - assert x.dim() == 3 - x, scale = self.preprocess(x) - emb = self.encoder(x) - codes = self.quantizer.encode(emb) - return codes, scale - - def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): - """Decode the given codes to a reconstructed representation, using the scale to perform - audio denormalization if needed. - - Args: - codes (torch.Tensor): Int tensor of shape [B, K, T] - scale (torch.Tensor, optional): Float tensor containing the scale value. - - Returns: - out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio. - """ - emb = self.decode_latent(codes) - out = self.decoder(emb) - out = self.postprocess(out, scale) - # out contains extra padding added by the encoder and decoder - return out - - def decode_latent(self, codes: torch.Tensor): - """Decode from the discrete codes to continuous latent space.""" - return self.quantizer.decode(codes) - - -class DAC(CompressionModel): - def __init__(self, model_type: str = "44khz"): - super().__init__() - try: - import dac.utils - except ImportError: - raise RuntimeError("Could not import dac, make sure it is installed, " - "please run `pip install descript-audio-codec`") - self.model = dac.utils.load_model(model_type=model_type) - self.n_quantizers = self.total_codebooks - self.model.eval() - - def forward(self, x: torch.Tensor) -> qt.QuantizedResult: - # We don't support training with this. 
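A hedged usage sketch of the `CompressionModel.get_pretrained` entry point and the `encode`/`decode` contract documented above. `'facebook/encodec_32khz'` is one of the pretrained names listed in the docstring; the waveform is random noise purely for illustration.

import torch
from audiocraft.models.encodec import CompressionModel

model = CompressionModel.get_pretrained('facebook/encodec_32khz', device='cpu')
print(model.sample_rate, model.frame_rate, model.cardinality, model.num_codebooks)

wav = torch.randn(1, model.channels, model.sample_rate)  # ~1 second of noise, [B, C, T]
codes, scale = model.encode(wav)        # codes: [B, K, T_frames]; scale is None unless the model renormalizes
decoded = model.decode(codes, scale)    # may include extra padding from the conv stacks
decoded = decoded[..., :wav.shape[-1]]  # trim back to the input length, as EncodecModel.forward does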
- raise NotImplementedError("Forward and training with DAC not supported.") - - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: - codes = self.model.encode(x, self.n_quantizers)[1] - return codes, None - - def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): - assert scale is None - z_q = self.decode_latent(codes) - return self.model.decode(z_q) - - def decode_latent(self, codes: torch.Tensor): - """Decode from the discrete codes to continuous latent space.""" - return self.model.quantizer.from_codes(codes)[0] - - @property - def channels(self) -> int: - return 1 - - @property - def frame_rate(self) -> float: - return self.model.sample_rate / self.model.hop_length - - @property - def sample_rate(self) -> int: - return self.model.sample_rate - - @property - def cardinality(self) -> int: - return self.model.codebook_size - - @property - def num_codebooks(self) -> int: - return self.n_quantizers - - @property - def total_codebooks(self) -> int: - return self.model.n_codebooks - - def set_num_codebooks(self, n: int): - """Set the active number of codebooks used by the quantizer. - """ - assert n >= 1 - assert n <= self.total_codebooks - self.n_quantizers = n - - -class HFEncodecCompressionModel(CompressionModel): - """Wrapper around HuggingFace Encodec. - """ - def __init__(self, model: HFEncodecModel): - super().__init__() - self.model = model - bws = self.model.config.target_bandwidths - num_codebooks = [ - bw * 1000 / (self.frame_rate * math.log2(self.cardinality)) - for bw in bws - ] - deltas = [nc - int(nc) for nc in num_codebooks] - # Checking we didn't do some bad maths and we indeed have integers! - assert all(deltas) <= 1e-3, deltas - self.possible_num_codebooks = [int(nc) for nc in num_codebooks] - self.set_num_codebooks(max(self.possible_num_codebooks)) - - def forward(self, x: torch.Tensor) -> qt.QuantizedResult: - # We don't support training with this. - raise NotImplementedError("Forward and training with HF EncodecModel not supported.") - - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: - bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks) - bandwidth = self.model.config.target_bandwidths[bandwidth_index] - res = self.model.encode(x, None, bandwidth) - assert len(res[0]) == 1 - assert len(res[1]) == 1 - return res[0][0], res[1][0] - - def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): - if scale is None: - scales = [None] # type: ignore - else: - scales = scale # type: ignore - res = self.model.decode(codes[None], scales) - return res[0] - - def decode_latent(self, codes: torch.Tensor): - """Decode from the discrete codes to continuous latent space.""" - return self.model.quantizer.decode(codes.transpose(0, 1)) - - @property - def channels(self) -> int: - return self.model.config.audio_channels - - @property - def frame_rate(self) -> float: - hop_length = int(np.prod(self.model.config.upsampling_ratios)) - return self.sample_rate / hop_length - - @property - def sample_rate(self) -> int: - return self.model.config.sampling_rate - - @property - def cardinality(self) -> int: - return self.model.config.codebook_size - - @property - def num_codebooks(self) -> int: - return self._num_codebooks - - @property - def total_codebooks(self) -> int: - return max(self.possible_num_codebooks) - - def set_num_codebooks(self, n: int): - """Set the active number of codebooks used by the quantizer. 
- """ - if n not in self.possible_num_codebooks: - raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}") - self._num_codebooks = n diff --git a/audiocraft/audiocraft/models/lm.py b/audiocraft/audiocraft/models/lm.py deleted file mode 100644 index f21d61f198baf4fb88c0e9ebd400948e2277fcd6..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/models/lm.py +++ /dev/null @@ -1,533 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass -from functools import partial -import logging -import math -import typing as tp - -import torch -from torch import nn - -from ..utils import utils -from ..modules.streaming import StreamingModule, State -from ..modules.transformer import StreamingTransformer, create_norm_fn -from ..modules.conditioners import ( - ConditionFuser, - ClassifierFreeGuidanceDropout, - AttributeDropout, - ConditioningProvider, - ConditioningAttributes, - ConditionType, -) -from ..modules.codebooks_patterns import CodebooksPatternProvider -from ..modules.activations import get_activation_fn - - -logger = logging.getLogger(__name__) -ConditionTensors = tp.Dict[str, ConditionType] -CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]] - - -def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None): - """LM layer initialization. - Inspired from xlformers: https://github.com/fairinternal/xlformers - - Args: - method (str): Method name for init function. Valid options are: - 'gaussian', 'uniform'. - input_dim (int): Input dimension of the initialized module. - init_depth (int, optional): Optional init depth value used to rescale - the standard deviation if defined. - """ - # Compute std - std = 1 / math.sqrt(input_dim) - # Rescale with depth - if init_depth is not None: - std = std / math.sqrt(2 * init_depth) - - if method == 'gaussian': - return partial( - torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std - ) - elif method == 'uniform': - bound = math.sqrt(3) * std # ensure the standard deviation is `std` - return partial(torch.nn.init.uniform_, a=-bound, b=bound) - else: - raise ValueError("Unsupported layer initialization method") - - -def init_layer(m: nn.Module, - method: str, - init_depth: tp.Optional[int] = None, - zero_bias_init: bool = False): - """Wrapper around ``get_init_fn`` for proper initialization of LM modules. - - Args: - m (nn.Module): Module to initialize. - method (str): Method name for the init function. - init_depth (int, optional): Optional init depth value used to rescale - the standard deviation if defined. - zero_bias_init (bool): Whether to initialize the bias to 0 or not. 
- """ - if isinstance(m, nn.Linear): - init_fn = get_init_fn(method, m.in_features, init_depth=init_depth) - if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: - weight = m.weight.float() - init_fn(weight) - m.weight.data[:] = weight.half() - else: - init_fn(m.weight) - if zero_bias_init and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Embedding): - init_fn = get_init_fn(method, m.embedding_dim, init_depth=None) - if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: - weight = m.weight.float() - init_fn(weight) - m.weight.data[:] = weight.half() - else: - init_fn(m.weight) - - -class ScaledEmbedding(nn.Embedding): - """Boost learning rate for embeddings (with `scale`). - """ - def __init__(self, *args, lr=None, **kwargs): - super().__init__(*args, **kwargs) - self.lr = lr - - def make_optim_group(self): - group = {"params": list(self.parameters())} - if self.lr is not None: - group["lr"] = self.lr - return group - - -@dataclass -class LMOutput: - # The logits are already re-aligned with the input codes - # hence no extra shift is required, e.g. when computing CE - logits: torch.Tensor # [B, K, T, card] - mask: torch.Tensor # [B, K, T] - - -class LMModel(StreamingModule): - """Transformer-based language model on multiple streams of codes. - - Args: - pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving. - condition_provider (MusicConditioningProvider): Conditioning provider from metadata. - fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input. - n_q (int): Number of parallel streams to model. - card (int): Cardinality, vocabulary size. - dim (int): Dimension of the transformer encoder. - num_heads (int): Number of heads for the transformer encoder. - hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder. - norm (str): Normalization method. - norm_first (bool): Use pre-norm instead of post-norm. - emb_lr (float, optional): Embedding-specific learning rate. - bias_proj (bool): Use bias for output projections. - weight_init (str, optional): Method for weight initialization. - depthwise_init (str, optional): Method for depthwise weight initialization. - zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros. - cfg_dropout (float): Classifier-free guidance dropout. - cfg_coef (float): Classifier-free guidance coefficient. - attribute_dropout (dict): Attribute dropout probabilities. - two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps. - **kwargs: Additional parameters for the transformer encoder. 
- """ - def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider, - fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8, - hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False, - emb_lr: tp.Optional[float] = None, bias_proj: bool = True, - weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None, - zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0, - attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False, - **kwargs): - super().__init__() - self.cfg_coef = cfg_coef - self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout) - self.att_dropout = AttributeDropout(p=attribute_dropout) - self.condition_provider = condition_provider - self.fuser = fuser - self.card = card - embed_dim = self.card + 1 - self.n_q = n_q - self.dim = dim - self.pattern_provider = pattern_provider - self.two_step_cfg = two_step_cfg - self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)]) - if 'activation' in kwargs: - kwargs['activation'] = get_activation_fn(kwargs['activation']) - self.transformer = StreamingTransformer( - d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim), - norm=norm, norm_first=norm_first, **kwargs) - self.out_norm: tp.Optional[nn.Module] = None - if norm_first: - self.out_norm = create_norm_fn(norm, dim) - self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)]) - self._init_weights(weight_init, depthwise_init, zero_bias_init) - self._fsdp: tp.Optional[nn.Module] - self.__dict__['_fsdp'] = None - - def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool): - """Initialization of the transformer module weights. - - Args: - weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options. - depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid: - 'current' where the depth corresponds to the current layer index or 'global' where the total number - of layer is used as depth. If not set, no depthwise initialization strategy is used. - zero_bias_init (bool): Whether to initialize bias to zero or not. - """ - assert depthwise_init is None or depthwise_init in ['current', 'global'] - assert depthwise_init is None or weight_init is not None, \ - "If 'depthwise_init' is defined, a 'weight_init' method should be provided." 
- assert not zero_bias_init or weight_init is not None, \ - "If 'zero_bias_init', a 'weight_init' method should be provided" - - if weight_init is None: - return - - for emb_layer in self.emb: - init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) - - for layer_idx, tr_layer in enumerate(self.transformer.layers): - depth = None - if depthwise_init == 'current': - depth = layer_idx + 1 - elif depthwise_init == 'global': - depth = len(self.transformer.layers) - init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init) - tr_layer.apply(init_fn) - - for linear in self.linears: - init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) - - @property - def special_token_id(self) -> int: - return self.card - - @property - def num_codebooks(self) -> int: - return self.n_q - - def forward(self, sequence: torch.Tensor, - conditions: tp.List[ConditioningAttributes], - condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor: - """Apply language model on sequence and conditions. - Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and - S the sequence steps, return the logits with shape [B, card, K, S]. - - Args: - indices (torch.Tensor): Indices of the codes to model. - conditions (list of ConditioningAttributes): Conditions to use when modeling - the given codes. Note that when evaluating multiple time with the same conditioning - you should pre-compute those and pass them as `condition_tensors`. - condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning - tensors, see `conditions`. - Returns: - torch.Tensor: Logits. - """ - B, K, S = sequence.shape - #assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks" - input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)]) # [B, K, S] -> [B, K, S, dim] -(sum)> [B, S, dim] - if condition_tensors is None: - assert not self._is_streaming, "Conditions tensors should be precomputed when streaming." - # apply dropout modules - conditions = self.cfg_dropout(conditions) - conditions = self.att_dropout(conditions) - tokenized = self.condition_provider.tokenize(conditions) - # encode conditions and fuse, both have a streaming cache to not recompute when generating. - condition_tensors = self.condition_provider(tokenized) - else: - assert not conditions, "Shouldn't pass both conditions and condition_tensors." - - # input_, cross_attention_input = self.fuser(input_, condition_tensors) - input_, in_attn_input, cross_attention_input = self.fuser(input_, condition_tensors) - - # out = self.transformer(input_, cross_attention_src=cross_attention_input) - out = self.transformer(input_, in_attn_src=in_attn_input, cross_attention_src=cross_attention_input) - if self.out_norm: - out = self.out_norm(out) - logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card] - - # remove the prefix from the model outputs - if len(self.fuser.fuse2cond['prepend']) > 0: - logits = logits[:, :, -S:] - - return logits # [B, K, S, card] - - def compute_predictions( - self, codes: torch.Tensor, - conditions: tp.List[ConditioningAttributes], - condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput: - """Given an input tensor of codes [B, K, T] and list of conditions, runs the model - forward using the specified codes interleaving pattern. 
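A minimal standalone illustration of the merge step in `LMModel.forward` above: one embedding table per codebook stream, summed over K to give a single [B, S, dim] transformer input. The dimensions are arbitrary and the snippet does not depend on the rest of the class.

import torch
from torch import nn

B, K, S, card, dim = 2, 4, 10, 1024, 128
emb = nn.ModuleList(nn.Embedding(card + 1, dim) for _ in range(K))  # +1 slot for the special token
sequence = torch.randint(0, card, (B, K, S))
input_ = sum(emb[k](sequence[:, k]) for k in range(K))  # [B, S, dim]
print(input_.shape)  # torch.Size([2, 10, 128])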
- - Args: - codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size, - K the number of codebooks and T the number of timesteps. - conditions (list of ConditioningAttributes): conditionings to use when modeling - the given codes. Note that when evaluating multiple time with the same conditioning - you should pre-compute those and pass them as `condition_tensors`. - condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning - tensors, see `conditions`. - Returns: - LMOutput: Language model outputs - logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes, - i.e. the first item corresponds to logits to predict the first code, meaning that - no additional shifting of codes and logits is required. - mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions. - Given the specified interleaving strategies, parts of the logits and codes should - not be considered as valid predictions because of invalid context. - """ - B, K, T = codes.shape - codes = codes.contiguous() - # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens - pattern = self.pattern_provider.get_pattern(T) - sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence( - codes, self.special_token_id, keep_only_valid_steps=True - ) - # apply model on pattern sequence - model = self if self._fsdp is None else self._fsdp - logits = model(sequence_codes, conditions, condition_tensors) # [B, K, S, card] - # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card] - # and provide the corresponding mask over invalid positions of tokens - logits = logits.permute(0, 3, 1, 2) # [B, card, K, S] - # note: we use nans as special token to make it obvious if we feed unexpected logits - logits, logits_indexes, logits_mask = pattern.revert_pattern_logits( - logits, float('nan'), keep_only_valid_steps=True - ) - logits = logits.permute(0, 2, 3, 1) # [B, K, T, card] - logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T] - return LMOutput(logits, logits_mask) - - def _sample_next_token(self, - sequence: torch.Tensor, - cfg_conditions: CFGConditions, - unconditional_state: State, - use_sampling: bool = False, - temp: float = 1.0, - top_k: int = 0, - top_p: float = 0.0, - cfg_coef: tp.Optional[float] = None) -> torch.Tensor: - """Sample next token from the model given a sequence and a set of conditions. The model supports - multiple sampling strategies (greedy sampling, softmax, top-k, top-p...). - - Args: - sequence (torch.Tensor): Current sequence of shape [B, K, S] - with K corresponding to the number of codebooks and S the number of sequence steps. - S = 1 in streaming mode, except for the first step that contains a bigger prompt. - condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used, - should be twice the batch size, being the concatenation of the conditions + null conditions. - use_sampling (bool): Whether to use a sampling strategy or not. - temp (float): Sampling temperature. - top_k (int): K for "top-k" sampling. - top_p (float): P for "top-p" sampling. - cfg_coef (float, optional): classifier free guidance coefficient - Returns: - next_token (torch.Tensor): Next token tensor of shape [B, K, 1]. 
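A hedged sketch of the interleaving step used by `compute_predictions` above, built with the `DelayedPatternProvider` seen in the builders excerpt; the module path is assumed from the imports in this diff and the sizes are arbitrary.

import torch
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider

B, K, T = 1, 4, 8
special_token = 2048  # LMModel uses its cardinality as the special token id
provider = DelayedPatternProvider(n_q=K)
pattern = provider.get_pattern(T)
codes = torch.randint(0, 2048, (B, K, T))
seq, indexes, mask = pattern.build_pattern_sequence(codes, special_token)
print(codes.shape, seq.shape)  # S > T: codebook k is shifted by k steps and padded with the special token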
- """ - B = sequence.shape[0] - cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef - model = self if self._fsdp is None else self._fsdp - if self.two_step_cfg and cfg_conditions != {}: - assert isinstance(cfg_conditions, tuple), type(cfg_conditions) - condition_tensors, null_condition_tensors = cfg_conditions - cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors) - state = self.get_streaming_state() - self.set_streaming_state(unconditional_state) - uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors) - unconditional_state.update(self.get_streaming_state()) - self.set_streaming_state(state) - logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef - else: - assert isinstance(cfg_conditions, dict) - condition_tensors = cfg_conditions - if condition_tensors: - # Preparing for CFG, predicting both conditional and unconditional logits. - sequence = torch.cat([sequence, sequence], dim=0) - all_logits = model( - sequence, - conditions=[], condition_tensors=condition_tensors) - if condition_tensors: - cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card] - logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef - else: - logits = all_logits - - logits = logits.permute(0, 1, 3, 2) # [B, K, card, T] - logits = logits[..., -1] # [B x K x card] - - # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error. - if use_sampling and temp > 0.0: - probs = torch.softmax(logits / temp, dim=-1) - if top_p > 0.0: - next_token = utils.sample_top_p(probs, p=top_p) - elif top_k > 0: - next_token = utils.sample_top_k(probs, k=top_k) - else: - next_token = utils.multinomial(probs, num_samples=1) - else: - next_token = torch.argmax(logits, dim=-1, keepdim=True) - - return next_token - - @torch.no_grad() - def generate(self, - prompt: tp.Optional[torch.Tensor] = None, - conditions: tp.List[ConditioningAttributes] = [], - num_samples: tp.Optional[int] = None, - max_gen_len: int = 256, - use_sampling: bool = True, - temp: float = 1.0, - top_k: int = 250, - top_p: float = 0.0, - cfg_coef: tp.Optional[float] = None, - two_step_cfg: tp.Optional[bool] = None, - remove_prompts: bool = False, - check: bool = False, - callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor: - """Generate tokens sampling from the model given a prompt or unconditionally. Generation can - be perform in a greedy fashion or using sampling with top K and top P strategies. - - Args: - prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T]. - conditions_tensors (list of ConditioningAttributes, optional): List of conditions. - num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given. - max_gen_len (int): Maximum generation length. - use_sampling (bool): Whether to use a sampling strategy or not. - temp (float): Sampling temperature. - top_k (int): K for "top-k" sampling. - top_p (float): P for "top-p" sampling. - cfg_coeff (float, optional): Classifier-free guidance coefficient. - two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation. - remove_prompts (bool): Whether to remove prompts from generation or not. - check (bool): Whether to apply further checks on generated sequence. - callback (Callback, optional): Callback function to report generation progress. - Returns: - torch.Tensor: Generated tokens. 
- """ - assert not self.training, "generation shouldn't be used in training mode." - first_param = next(iter(self.parameters())) - device = first_param.device - - # Checking all input shapes are consistent. - possible_num_samples = [] - if num_samples is not None: - possible_num_samples.append(num_samples) - elif prompt is not None: - possible_num_samples.append(prompt.shape[0]) - elif conditions: - possible_num_samples.append(len(conditions)) - else: - possible_num_samples.append(1) - assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes" - num_samples = possible_num_samples[0] - - # below we create set of conditions: one conditional and one unconditional - # to do that we merge the regular condition together with the null condition - # we then do 1 forward pass instead of 2. - # the reason for that is two-fold: - # 1. it is about x2 faster than doing 2 forward passes - # 2. avoid the streaming API treating the 2 passes as part of different time steps - # We also support doing two different passes, in particular to ensure that - # the padding structure is exactly the same between train and test. - # With a batch size of 1, this can be slower though. - cfg_conditions: CFGConditions - two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg - if conditions: - null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions) - if two_step_cfg: - cfg_conditions = ( - self.condition_provider(self.condition_provider.tokenize(conditions)), - self.condition_provider(self.condition_provider.tokenize(null_conditions)), - ) - else: - conditions = conditions + null_conditions - tokenized = self.condition_provider.tokenize(conditions) - cfg_conditions = self.condition_provider(tokenized) - else: - cfg_conditions = {} - - if prompt is None: - assert num_samples > 0 - prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device) - - B, K, T = prompt.shape - start_offset = T - assert start_offset < max_gen_len - - pattern = self.pattern_provider.get_pattern(max_gen_len) - # this token is used as default value for codes that are not generated yet - unknown_token = -1 - - # we generate codes up to the max_gen_len that will be mapped to the pattern sequence - gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device) - # filling the gen_codes with the prompt if needed - gen_codes[..., :start_offset] = prompt - # create the gen_sequence with proper interleaving from the pattern: [B, K, S] - gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id) - # retrieve the start_offset in the sequence: - # it is the first sequence step that contains the `start_offset` timestep - start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) - assert start_offset_sequence is not None - - with self.streaming(): - unconditional_state = self.get_streaming_state() - prev_offset = 0 - gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S] - for offset in range(start_offset_sequence, gen_sequence_len): - # get current sequence (note that the streaming API is providing the caching over previous offsets) - curr_sequence = gen_sequence[..., prev_offset:offset] - curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1) - if check: - # check coherence between mask and sequence - assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all() - # should never happen as gen_sequence is filled 
progressively - assert not (curr_sequence == unknown_token).any() - # sample next token from the model, next token shape is [B, K, 1] - next_token = self._sample_next_token( - curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, - cfg_coef=cfg_coef) - # ensure the tokens that should be masked are properly set to special_token_id - # as the model never output special_token_id - valid_mask = mask[..., offset:offset+1].expand(B, -1, -1) - next_token[~valid_mask] = self.special_token_id - # ensure we don't overwrite prompt tokens, we only write over unknown tokens - # (then mask tokens should be left as is as well, which is correct) - gen_sequence[..., offset:offset+1] = torch.where( - gen_sequence[..., offset:offset+1] == unknown_token, - next_token, gen_sequence[..., offset:offset+1] - ) - prev_offset = offset - if callback is not None: - callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) - unconditional_state.clear() - - # ensure sequence has been entirely filled - assert not (gen_sequence == unknown_token).any() - # ensure gen_sequence pattern and mask are matching - # which means the gen_sequence is valid according to the pattern - assert ( - gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id) - ).all() - # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps - out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) - - # sanity checks over the returned codes and corresponding masks - assert (out_codes[..., :max_gen_len] != unknown_token).all() - assert (out_mask[..., :max_gen_len] == 1).all() - - out_start_offset = start_offset if remove_prompts else 0 - out_codes = out_codes[..., out_start_offset:max_gen_len] - - # ensure the returned codes are all valid - assert (out_codes >= 0).all() and (out_codes <= self.card).all() - return out_codes diff --git a/audiocraft/audiocraft/models/loaders.py b/audiocraft/audiocraft/models/loaders.py deleted file mode 100644 index 9c7808a0588bd1a8084157b072bae42aa7efaf84..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/models/loaders.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Utility functions to load from the checkpoints. -Each checkpoint is a torch.saved dict with the following keys: -- 'xp.cfg': the hydra config as dumped during training. This should be used - to rebuild the object using the audiocraft.models.builders functions, -- 'model_best_state': a readily loadable best state for the model, including - the conditioner. The model obtained from `xp.cfg` should be compatible - with this state dict. In the case of a LM, the encodec model would not be - bundled along but instead provided separately. - -Those functions also support loading from a remote location with the Torch Hub API. -They also support overriding some parameters, in particular the device and dtype -of the returned model. -""" - -from pathlib import Path -from huggingface_hub import hf_hub_download -import typing as tp -import os - -from omegaconf import OmegaConf, DictConfig -import torch - -from . 
import builders -from .encodec import CompressionModel - - -def get_audiocraft_cache_dir() -> tp.Optional[str]: - return os.environ.get('AUDIOCRAFT_CACHE_DIR', None) - - -def _get_state_dict( - file_or_url_or_id: tp.Union[Path, str], - filename: tp.Optional[str] = None, - device='cpu', - cache_dir: tp.Optional[str] = None, -): - if cache_dir is None: - cache_dir = get_audiocraft_cache_dir() - # Return the state dict either from a file or url - file_or_url_or_id = str(file_or_url_or_id) - assert isinstance(file_or_url_or_id, str) - - if os.path.isfile(file_or_url_or_id): - return torch.load(file_or_url_or_id, map_location=device) - - if os.path.isdir(file_or_url_or_id): - file = f"{file_or_url_or_id}/{filename}" - return torch.load(file, map_location=device) - - elif file_or_url_or_id.startswith('https://'): - return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True) - - else: - assert filename is not None, "filename needs to be defined if using HF checkpoints" - - file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir) - return torch.load(file, map_location=device) - - -def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None): - return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir) - - -def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): - pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) - if 'pretrained' in pkg: - return CompressionModel.get_pretrained(pkg['pretrained'], device=device) - cfg = OmegaConf.create(pkg['xp.cfg']) - cfg.device = str(device) - model = builders.get_compression_model(cfg) - model.load_state_dict(pkg['best_state']) - model.eval() - return model - - -def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None): - return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir) - - -def _delete_param(cfg: DictConfig, full_name: str): - parts = full_name.split('.') - for part in parts[:-1]: - if part in cfg: - cfg = cfg[part] - else: - return - OmegaConf.set_struct(cfg, False) - if parts[-1] in cfg: - del cfg[parts[-1]] - OmegaConf.set_struct(cfg, True) - - -def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): - pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) - cfg = OmegaConf.create(pkg['xp.cfg']) - cfg.device = str(device) - if cfg.device == 'cpu': - cfg.dtype = 'float32' - else: - cfg.dtype = 'float16' - _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path') - _delete_param(cfg, 'conditioners.args.merge_text_conditions_p') - _delete_param(cfg, 'conditioners.args.drop_desc_p') - model = builders.get_lm_model(cfg) - model.load_state_dict(pkg['best_state']) - model.eval() - model.cfg = cfg - return model - - -def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None): - return _get_state_dict(file_or_url_or_id, filename="all_in_one.pt", cache_dir=cache_dir) - - -def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): - pkg = load_mbd_ckpt(file_or_url_or_id, cache_dir=cache_dir) - models = [] - processors = [] - cfgs = [] - sample_rate = pkg['sample_rate'] - for i in range(pkg['n_bands']): - cfg = pkg[i]['cfg'] - model = builders.get_diffusion_model(cfg) - model_dict = 
pkg[i]['model_state'] - model.load_state_dict(model_dict) - model.to(device) - processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate) - processor_dict = pkg[i]['processor_state'] - processor.load_state_dict(processor_dict) - processor.to(device) - models.append(model) - processors.append(processor) - cfgs.append(cfg) - return models, processors, cfgs diff --git a/audiocraft/audiocraft/models/multibanddiffusion.py b/audiocraft/audiocraft/models/multibanddiffusion.py deleted file mode 100644 index 1121d2fc660ab2ceed7deaaf87edba5337ab5472..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/models/multibanddiffusion.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Multi Band Diffusion models as described in -"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" -(paper link). -""" - -import typing as tp - -import torch -import julius - -from .unet import DiffusionUnet -from ..modules.diffusion_schedule import NoiseSchedule -from .encodec import CompressionModel -from ..solvers.compression import CompressionSolver -from .loaders import load_compression_model, load_diffusion_models - - -class DiffusionProcess: - """Sampling for a diffusion Model. - - Args: - model (DiffusionUnet): Diffusion U-Net model. - noise_schedule (NoiseSchedule): Noise schedule for diffusion process. - """ - def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None: - """ - """ - self.model = model - self.schedule = noise_schedule - - def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor, - step_list: tp.Optional[tp.List[int]] = None): - """Perform one diffusion process to generate one of the bands. - - Args: - condition (tensor): The embeddings form the compression model. - initial_noise (tensor): The initial noise to start the process/ - """ - return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list, - condition=condition) - - -class MultiBandDiffusion: - """Sample from multiple diffusion models. - - Args: - DPs (list of DiffusionProcess): Diffusion processes. - codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens. 
- """ - def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None: - self.DPs = DPs - self.codec_model = codec_model - self.device = next(self.codec_model.parameters()).device - - @property - def sample_rate(self) -> int: - return self.codec_model.sample_rate - - @staticmethod - def get_mbd_musicgen(device=None): - """Load our diffusion models trained for MusicGen.""" - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - path = 'https://dl.fbaipublicfiles.com/encodec/Diffusion/mbd_musicgen_32khz.th' - name = 'facebook/musicgen-small' - codec_model = load_compression_model(name, device=device) - models, processors, cfgs = load_diffusion_models(path, device=device) - DPs = [] - for i in range(len(models)): - schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device) - DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule)) - return MultiBandDiffusion(DPs=DPs, codec_model=codec_model) - - @staticmethod - def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True, - device: tp.Optional[tp.Union[torch.device, str]] = None, - n_q: tp.Optional[int] = None): - """Get the pretrained Models for MultibandDiffusion. - - Args: - bw (float): Bandwidth of the compression model. - pretrained (bool): Whether to use / download if necessary the models. - device (torch.device or str, optional): Device on which the models are loaded. - n_q (int, optional): Number of quantizers to use within the compression model. - """ - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available" - if n_q is not None: - assert n_q in [2, 4, 8] - assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \ - f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}" - n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw] - codec_model = CompressionSolver.model_from_checkpoint( - '//pretrained/facebook/encodec_24khz', device=device) - codec_model.set_num_codebooks(n_q) - codec_model = codec_model.to(device) - path = f'https://dl.fbaipublicfiles.com/encodec/Diffusion/mbd_comp_{n_q}.pt' - models, processors, cfgs = load_diffusion_models(path, device=device) - DPs = [] - for i in range(len(models)): - schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device) - DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule)) - return MultiBandDiffusion(DPs=DPs, codec_model=codec_model) - - return MultiBandDiffusion(DPs, codec_model) - - @torch.no_grad() - def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: - """Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform. - Args: - wav (torch.Tensor): The audio that we want to extract the conditioning from - sample_rate (int): sample rate of the audio""" - if sample_rate != self.sample_rate: - wav = julius.resample_frac(wav, sample_rate, self.sample_rate) - codes, scale = self.codec_model.encode(wav) - assert scale is None, "Scaled compression models not supported." 
- emb = self.get_emb(codes) - return emb - - @torch.no_grad() - def get_emb(self, codes: torch.Tensor): - """Get latent representation from the discrete codes - Argrs: - codes (torch.Tensor): discrete tokens""" - emb = self.codec_model.decode_latent(codes) - return emb - - def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None, - step_list: tp.Optional[tp.List[int]] = None): - """Generate Wavform audio from the latent embeddings of the compression model - Args: - emb (torch.Tensor): Conditioning embeddinds - size (none torch.Size): size of the output - if None this is computed from the typical upsampling of the model - step_list (optional list[int]): list of Markov chain steps, defaults to 50 linearly spaced step. - """ - if size is None: - upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate) - size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling]) - assert size[0] == emb.size(0) - out = torch.zeros(size).to(self.device) - for DP in self.DPs: - out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out)) - return out - - def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1): - """match the eq to the encodec output by matching the standard deviation of some frequency bands - Args: - wav (torch.Tensor): audio to equalize - ref (torch.Tensor):refenrence audio from which we match the spectrogram. - n_bands (int): number of bands of the eq - strictness (float): how strict the the matching. 0 is no matching, 1 is exact matching. - """ - split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device) - bands = split(wav) - bands_ref = split(ref) - out = torch.zeros_like(ref) - for i in range(n_bands): - out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness - return out - - def regenerate(self, wav: torch.Tensor, sample_rate: int): - """Regenerate a wavform through compression and diffusion regeneration. - Args: - wav (torch.Tensor): Original 'ground truth' audio - sample_rate (int): sample rate of the input (and output) wav - """ - if sample_rate != self.codec_model.sample_rate: - wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate) - emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate) - size = wav.size() - out = self.generate(emb, size=size) - if sample_rate != self.codec_model.sample_rate: - out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate) - return out - - def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32): - """Generate Waveform audio with diffusion from the discrete codes. - Args: - tokens (torch.Tensor): discrete codes - n_bands (int): bands for the eq matching. - """ - wav_encodec = self.codec_model.decode(tokens) - condition = self.get_emb(tokens) - wav_diffusion = self.generate(emb=condition, size=wav_encodec.size()) - return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands) diff --git a/audiocraft/audiocraft/models/musicgen.py b/audiocraft/audiocraft/models/musicgen.py deleted file mode 100644 index e04878cc794b53e6c8f67ee9d341550ccccf0bf3..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/models/musicgen.py +++ /dev/null @@ -1,583 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Main model for using MusicGen. 
This will combine all the required components -and provide easy access to the generation API. -""" - -import typing as tp -import warnings - -import torch -import numpy as np - -from .encodec import CompressionModel -from .lm import LMModel -from .builders import get_debug_compression_model, get_debug_lm_model -from .loaders import load_compression_model, load_lm_model -from ..data.audio_utils import convert_audio, convert_txtchord2chroma, convert_txtchord2chroma_24 -from ..modules.conditioners import ConditioningAttributes, WavCondition, ChordCondition, BeatCondition -from ..utils.autocast import TorchAutocast - - -MelodyList = tp.List[tp.Optional[torch.Tensor]] -MelodyType = tp.Union[torch.Tensor, MelodyList] - - -# backward compatible names mapping -_HF_MODEL_CHECKPOINTS_MAP = { - "small": "facebook/musicgen-small", - "medium": "facebook/musicgen-medium", - "large": "facebook/musicgen-large", - "melody": "facebook/musicgen-melody", -} - - -class MusicGen: - """MusicGen main model with convenient generation API. - - Args: - name (str): name of the model. - compression_model (CompressionModel): Compression model - used to map audio to invertible discrete representations. - lm (LMModel): Language model over discrete representations. - max_duration (float, optional): maximum duration the model can produce, - otherwise, inferred from the training params. - """ - def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, - max_duration: tp.Optional[float] = None): - self.name = name - self.compression_model = compression_model - self.lm = lm - if max_duration is None: - if hasattr(lm, 'cfg'): - max_duration = lm.cfg.dataset.segment_duration # type: ignore - else: - raise ValueError("You must provide max_duration when building directly MusicGen") - assert max_duration is not None - self.max_duration: float = max_duration - self.device = next(iter(lm.parameters())).device - self.generation_params: dict = {} - self.set_generation_params(duration=6, extend_stride=3) # 6 seconds by default - self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None - if self.device.type == 'cpu': - self.autocast = TorchAutocast(enabled=False) - else: - self.autocast = TorchAutocast( - enabled=True, device_type=self.device.type, dtype=torch.float16) - - @property - def frame_rate(self) -> float: - """Roughly the number of AR steps per seconds.""" - return self.compression_model.frame_rate - - @property - def sample_rate(self) -> int: - """Sample rate of the generated audio.""" - return self.compression_model.sample_rate - - @property - def audio_channels(self) -> int: - """Audio channels of the generated audio.""" - return self.compression_model.channels - - @staticmethod - def get_pretrained(name: str = 'facebook/musicgen-melody', device=None): - """Return pretrained model, we provide four models: - - facebook/musicgen-small (300M), text to music, - # see: https://huggingface.co/facebook/musicgen-small - - facebook/musicgen-medium (1.5B), text to music, - # see: https://huggingface.co/facebook/musicgen-medium - - facebook/musicgen-melody (1.5B) text to music and text+melody to music, - # see: https://huggingface.co/facebook/musicgen-melody - - facebook/musicgen-large (3.3B), text to music, - # see: https://huggingface.co/facebook/musicgen-large - """ - if device is None: - if torch.cuda.device_count(): - device = 'cuda' - else: - device = 'cpu' - - if name == 'debug': - # used only for unit tests - compression_model = get_debug_compression_model(device) - lm = 
get_debug_lm_model(device) - return MusicGen(name, compression_model, lm, max_duration=30) - - if name in _HF_MODEL_CHECKPOINTS_MAP: - warnings.warn( - "MusicGen pretrained model relying on deprecated checkpoint mapping. " + - f"Please use full pre-trained id instead: facebook/musicgen-{name}") - name = _HF_MODEL_CHECKPOINTS_MAP[name] - - lm = load_lm_model(name, device=device) - compression_model = load_compression_model(name, device=device) - if 'self_wav' in lm.condition_provider.conditioners: - lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True - - return MusicGen(name, compression_model, lm) - - def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, - top_p: float = 0.0, temperature: float = 1.0, - duration: float = 30.0, cfg_coef: float = 3.0, - two_step_cfg: bool = False, extend_stride: float = 18): - """Set the generation parameters for MusicGen. - - Args: - use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True. - top_k (int, optional): top_k used for sampling. Defaults to 250. - top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0. - temperature (float, optional): Softmax temperature parameter. Defaults to 1.0. - duration (float, optional): Duration of the generated waveform. Defaults to 30.0. - cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0. - two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance, - instead of batching together the two. This has some impact on how things - are padded but seems to have little impact in practice. - extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much - should we extend the audio each time. Larger values will mean less context is - preserved, and shorter value will require extra computations. - """ - assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration." - self.extend_stride = extend_stride - self.duration = duration - self.generation_params = { - 'use_sampling': use_sampling, - 'temp': temperature, - 'top_k': top_k, - 'top_p': top_p, - 'cfg_coef': cfg_coef, - 'two_step_cfg': two_step_cfg, - } - - def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None): - """Override the default progress callback.""" - self._progress_callback = progress_callback - - def generate_unconditional(self, num_samples: int, progress: bool = False, - return_tokens: bool = False) -> tp.Union[torch.Tensor, - tp.Tuple[torch.Tensor, torch.Tensor]]: - """Generate samples in an unconditional manner. - - Args: - num_samples (int): Number of samples to be generated. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - """ - descriptions: tp.List[tp.Optional[str]] = [None] * num_samples - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) - tokens = self._generate_tokens(attributes, prompt_tokens, progress) - if return_tokens: - return self.generate_audio(tokens), tokens - return self.generate_audio(tokens) - - def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \ - -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: - """Generate samples conditioned on text. - - Args: - descriptions (list of str): A list of strings used as text conditioning. 
- progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - """ - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None) - assert prompt_tokens is None - tokens = self._generate_tokens(attributes, prompt_tokens, progress) - if return_tokens: - return self.generate_audio(tokens), tokens - return self.generate_audio(tokens) - - def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType, - melody_sample_rate: int, progress: bool = False, - return_tokens: bool = False) -> tp.Union[torch.Tensor, - tp.Tuple[torch.Tensor, torch.Tensor]]: - """Generate samples conditioned on text and melody. - - Args: - descriptions (list of str): A list of strings used as text conditioning. - melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as - melody conditioning. Should have shape [B, C, T] with B matching the description length, - C=1 or 2. It can be [C, T] if there is a single description. It can also be - a list of [C, T] tensors. - melody_sample_rate: (int): Sample rate of the melody waveforms. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - """ - if isinstance(melody_wavs, torch.Tensor): - if melody_wavs.dim() == 2: - melody_wavs = melody_wavs[None] - if melody_wavs.dim() != 3: - raise ValueError("Melody wavs should have a shape [B, C, T].") - melody_wavs = list(melody_wavs) - else: - for melody in melody_wavs: - if melody is not None: - assert melody.dim() == 2, "One melody in the list has the wrong number of dims." - - melody_wavs = [ - convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels) - if wav is not None else None - for wav in melody_wavs] - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None, - melody_wavs=melody_wavs) - assert prompt_tokens is None - tokens = self._generate_tokens(attributes, prompt_tokens, progress) - if return_tokens: - return self.generate_audio(tokens), tokens - return self.generate_audio(tokens) - - def generate_with_chords(self, descriptions: tp.List[str], melody_chords: tp.Optional[tp.Union[MelodyList,tp.List[str]]] = None, - bpms: tp.Optional[tp.Union[float,int,tp.List[float],tp.List[int]]] = [120.], - meters: tp.Optional[tp.Union[float,int,tp.List[float],tp.List[int]]] = [4.], - progress: bool = False, return_tokens: bool = False) -> tp.Union[torch.Tensor, - tp.Tuple[torch.Tensor, torch.Tensor]]: - """Generate samples conditioned on text and melody. - - Args: - descriptions (list of str): A list of strings used as text conditioning. - melody_chords: (torch.Tensor or list of Tensor): A list of chords in chormagram or string type - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. 
- """ - - if isinstance(melody_chords[0], str): - # check the bpm, meter length - if len(bpms) == 1: - bpms *= len(melody_chords) - if len(meters) == 1: - meters *= len(melody_chords) - assert len(bpms) == len(melody_chords), "bpm length is not equal to chord length" - assert len(meters) == len(melody_chords), "meter length is not equal to chord length" - # convert str to chromagram - melody_chromas = [] - for melody_chord, bpm, meter in zip(melody_chords, bpms, meters): - melody_chroma = convert_txtchord2chroma(melody_chord, bpm, meter, self.duration).permute(1,0) # [C=12, T] - melody_chromas.append(melody_chroma) - melody_chromas = torch.stack(melody_chromas, dim=0) - assert melody_chromas.dim() == 3 - melody_chords = list(melody_chromas) - else: - for melody in melody_chords: - if melody is not None: - assert melody.dim() == 2, "One melody in the list has the wrong number of dims." - - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None, - melody_chords=melody_chords, bpms=bpms) - assert prompt_tokens is None - tokens = self._generate_tokens(attributes, prompt_tokens, progress) - if return_tokens: - return self.generate_audio(tokens), tokens - return self.generate_audio(tokens) - - def generate_with_chords_and_beats(self, descriptions: tp.List[str], melody_chords: tp.Optional[tp.Union[MelodyList,tp.List[str]]] = None, - bpms: tp.Optional[tp.Union[float,int,tp.List[float],tp.List[int]]] = [120.], - meters: tp.Optional[tp.Union[float,int,tp.List[float],tp.List[int]]] = [4.], - progress: bool = False, return_tokens: bool = False) -> tp.Union[torch.Tensor, - tp.Tuple[torch.Tensor, torch.Tensor]]: - """Generate samples conditioned on text and melody. - - Args: - descriptions (list of str): A list of strings used as text conditioning. - melody_chords: (torch.Tensor or list of Tensor): A list of chords in chormagram or string type - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - """ - - if isinstance(melody_chords[0], str): - # check the bpm, meter length - if len(bpms) == 1: - bpms *= len(melody_chords) - if len(meters) == 1: - meters *= len(melody_chords) - assert len(bpms) == len(melody_chords), "bpm length is not equal to chord length" - assert len(meters) == len(melody_chords), "meter length is not equal to chord length" - # convert str to chromagram - melody_chromas = [] - for melody_chord, bpm, meter in zip(melody_chords, bpms, meters): - melody_chroma = convert_txtchord2chroma(melody_chord, bpm, meter, self.duration).permute(1,0) # [C=24, T] - melody_chromas.append(melody_chroma) - melody_chromas = torch.stack(melody_chromas, dim=0) - assert melody_chromas.dim() == 3 - melody_chords = list(melody_chromas) - else: - for melody in melody_chords: - if melody is not None: - assert melody.dim() == 2, "One melody in the list has the wrong number of dims." 
- - fs = self.sample_rate / 640 - beats = [] - for bpm, meter in zip(bpms, meters): - beat = np.zeros(int(fs * self.duration)) - beat_gap = int(60 / bpm * fs) - beat[::beat_gap] = 1 - bar = np.zeros(int(fs * self.duration)) - bar[::beat_gap * meter] = 1 - kernel = np.array([0.05, 0.1, 0.3, 0.9, 0.3, 0.1, 0.05]) - beat = np.convolve(beat , kernel, 'same') - beat = beat + bar - beats.append(torch.tensor(beat).unsqueeze(0)) # [C, T] - beats = list(torch.stack(beats, dim=0)) # [B, C, T] - - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None, - melody_chords=melody_chords, beats=beats, bpms=bpms) - assert prompt_tokens is None - tokens = self._generate_tokens(attributes, prompt_tokens, progress) - if return_tokens: - return self.generate_audio(tokens), tokens - return self.generate_audio(tokens) - - def generate_for_eval(self, descriptions: tp.List[str], melody_chords: tp.List[torch.Tensor], beats: tp.List[torch.Tensor], - bpms: tp.List[float], progress: bool = False, return_tokens: bool = False) -> tp.Union[torch.Tensor, - tp.Tuple[torch.Tensor, torch.Tensor]]: - - # assert melody_chords.dim() == 3 - # assert beats.dim() == 3 - - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None, - melody_chords=melody_chords, beats=beats, bpms=bpms) - assert prompt_tokens is None - tokens = self._generate_tokens(attributes, prompt_tokens, progress) - if return_tokens: - return self.generate_audio(tokens), tokens - return self.generate_audio(tokens) - - - def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, - descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None, audio_channels=1, - progress: bool = False, return_tokens: bool = False) \ - -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: - """Generate samples conditioned on audio prompts. - - Args: - prompt (torch.Tensor): A batch of waveforms used for continuation. - Prompt should be [B, C, T], or [C, T] if only one sample is generated. - prompt_sample_rate (int): Sampling rate of the given audio waveforms. - descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - """ - if prompt.dim() == 2: - prompt = prompt[None] - if prompt.dim() != 3: - raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).") - prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, audio_channels) - if descriptions is None: - descriptions = [None] * len(prompt) - attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) - assert prompt_tokens is not None - tokens = self._generate_tokens(attributes, prompt_tokens, progress) - if return_tokens: - return self.generate_audio(tokens), tokens - return self.generate_audio(tokens) - - @torch.no_grad() - def _prepare_tokens_and_attributes( - self, - descriptions: tp.Sequence[tp.Optional[str]], - prompt: tp.Optional[torch.Tensor], - melody_wavs: tp.Optional[MelodyList] = None, - melody_chords: tp.Optional[MelodyList] = None, - beats : tp.Optional[MelodyList] = None, - bpms : tp.Optional[list] = None, - ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]: - """Prepare model inputs. - - Args: - descriptions (list of str): A list of strings used as text conditioning. - prompt (torch.Tensor): A batch of waveforms used for continuation. 
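The beat conditioning built just above can be restated in isolation. The helper below is not part of the diff; it only mirrors that construction: an impulse train at the conditioning frame rate (`sample_rate / 640`), smoothed with a small kernel, plus an unsmoothed downbeat train every `meter` beats.

import numpy as np

def beat_track(bpm: float, meter: int, duration: float, frame_rate: float) -> np.ndarray:
    n = int(frame_rate * duration)
    beat = np.zeros(n)
    gap = int(60 / bpm * frame_rate)      # frames per beat
    beat[::gap] = 1
    bar = np.zeros(n)
    bar[::gap * meter] = 1                # downbeat every `meter` beats
    kernel = np.array([0.05, 0.1, 0.3, 0.9, 0.3, 0.1, 0.05])
    return np.convolve(beat, kernel, 'same') + bar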
- melody_wavs (torch.Tensor, optional): A batch of waveforms - used as melody conditioning. Defaults to None. - """ - attributes = [ - ConditioningAttributes(text={'description': description}) - for description in descriptions] - - if melody_wavs is None: - for attr in attributes: - attr.wav['self_wav'] = WavCondition( - torch.zeros((1, 1, 1), device=self.device), - torch.tensor([0], device=self.device), - sample_rate=[self.sample_rate], - path=[None]) - else: - if 'self_wav' not in self.lm.condition_provider.conditioners: - raise RuntimeError("This model doesn't support melody conditioning. " - "Use the `melody` model.") - assert len(melody_wavs) == len(descriptions), \ - f"number of melody wavs must match number of descriptions! " \ - f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}" - for attr, melody in zip(attributes, melody_wavs): - if melody is None: - attr.wav['self_wav'] = WavCondition( - torch.zeros((1, 1, 1), device=self.device), - torch.tensor([0], device=self.device), - sample_rate=[self.sample_rate], - path=[None]) - else: - attr.wav['self_wav'] = WavCondition( - melody[None].to(device=self.device), - torch.tensor([melody.shape[-1]], device=self.device), - sample_rate=[self.sample_rate], - path=[None], - ) - - if melody_chords is None: - for attr in attributes: - attr.chord['chord'] = ChordCondition( - torch.zeros((1, 12, 1), device=self.device), - torch.tensor([0], device=self.device), - bpm=[None], - path=[None]) - else: - # if 'chord' not in self.lm.condition_provider.conditioners: - # raise RuntimeError("This model doesn't support chord conditioning. " - # "Use the `chord` model.") - assert len(melody_chords) == len(descriptions), \ - f"number of melody_chords must match number of descriptions! " \ - f"got melody len={len(melody_chords)}, and descriptions len={len(descriptions)}" - for attr, chord, bpm in zip(attributes, melody_chords, bpms): - if chord is None: - attr.chord['chord'] = ChordCondition( - torch.zeros((1, 1, 1), device=self.device), - torch.tensor([0], device=self.device), - bpm=[None], - path=[None]) - else: - attr.chord['chord'] = ChordCondition( - chord[None].to(device=self.device), - torch.tensor([chord.shape[-1]], device=self.device), - bpm=[bpm], - path=[None], - ) - - if beats is None: - for attr in attributes: - attr.beat['beat'] = BeatCondition( - torch.zeros((1, 1, 1), device=self.device), - torch.tensor([0], device=self.device), - bpm=[None], - path=[None]) - else: - # if 'beat' not in self.lm.condition_provider.conditioners: - # raise RuntimeError("This model doesn't support beat conditioning. " - # "Use the `beat` model.") - assert len(beats) == len(descriptions), \ - f"number of beats must match number of descriptions! " \ - f"got melody len={len(beats)}, and descriptions len={len(descriptions)}" - for attr, beat, bpm in zip(attributes, beats, bpms): - if beat is None: - attr.beat['beat'] = BeatCondition( - torch.zeros((1, 1, 1), device=self.device), - torch.tensor([0], device=self.device), - bpm=[None], - path=[None]) - else: - attr.beat['beat'] = BeatCondition( - beat[None].to(device=self.device), - torch.tensor([beat.shape[-1]], device=self.device), - bpm=[bpm], - path=[None], - ) - - if prompt is not None: - if descriptions is not None: - assert len(descriptions) == len(prompt), "Prompt and nb. 
descriptions doesn't match" - prompt = prompt.to(self.device) - prompt_tokens, scale = self.compression_model.encode(prompt) - assert scale is None - else: - prompt_tokens = None - return attributes, prompt_tokens - - def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], - prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor: - """Generate discrete audio tokens given audio prompt and/or conditions. - - Args: - attributes (list of ConditioningAttributes): Conditions used for generation (text/melody). - prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation. - progress (bool, optional): Flag to display progress of the generation process. Defaults to False. - Returns: - torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params. - """ - total_gen_len = int(self.duration * self.frame_rate) - max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate) - current_gen_offset: int = 0 - - def _progress_callback(generated_tokens: int, tokens_to_generate: int): - generated_tokens += current_gen_offset - if self._progress_callback is not None: - # Note that total_gen_len might be quite wrong depending on the - # codebook pattern used, but with delay it is almost accurate. - self._progress_callback(generated_tokens, total_gen_len) - else: - print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r') - - if prompt_tokens is not None: - assert max_prompt_len >= prompt_tokens.shape[-1], \ - "Prompt is longer than audio to generate" - - callback = None - if progress: - callback = _progress_callback - - if self.duration <= self.max_duration: - # generate by sampling from LM, simple case. - with self.autocast: - gen_tokens = self.lm.generate( - prompt_tokens, attributes, - callback=callback, max_gen_len=total_gen_len, **self.generation_params) - - else: - # now this gets a bit messier, we need to handle prompts, - # melody conditioning etc. - ref_wavs = [attr.wav['self_wav'] for attr in attributes] - all_tokens = [] - if prompt_tokens is None: - prompt_length = 0 - else: - all_tokens.append(prompt_tokens) - prompt_length = prompt_tokens.shape[-1] - - stride_tokens = int(self.frame_rate * self.extend_stride) - - while current_gen_offset + prompt_length < total_gen_len: - time_offset = current_gen_offset / self.frame_rate - chunk_duration = min(self.duration - time_offset, self.max_duration) - max_gen_len = int(chunk_duration * self.frame_rate) - for attr, ref_wav in zip(attributes, ref_wavs): - wav_length = ref_wav.length.item() - if wav_length == 0: - continue - # We will extend the wav periodically if it not long enough. - # we have to do it here rather than in conditioners.py as otherwise - # we wouldn't have the full wav. 
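To make the stride arithmetic of this extension loop concrete, a short numeric walk-through with assumed settings (none of these values come from the diff): frame_rate = 50 Hz, duration = 45 s, max_duration = 30 s, extend_stride = 18 s.

frame_rate, duration, max_duration, extend_stride = 50, 45.0, 30.0, 18.0
total_gen_len = int(duration * frame_rate)        # 2250 tokens to produce overall
stride_tokens = int(frame_rate * extend_stride)   # advance 900 tokens per outer iteration

# iteration 1: offset 0    -> chunk min(45, 30) = 30 s -> 1500 tokens, all kept;
#              re-prompt with the last 1500 - 900 = 600 tokens
# iteration 2: offset 900  -> chunk min(45 - 18, 30) = 27 s -> 1350 tokens, of which
#              the 750 after the 600-token prompt are new
# loop stops:  offset (1800) + prompt length (450) >= 2250, and 1500 + 750 = 2250
#              tokens have been concatenated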
- initial_position = int(time_offset * self.sample_rate) - wav_target_length = int(self.max_duration * self.sample_rate) - positions = torch.arange(initial_position, - initial_position + wav_target_length, device=self.device) - attr.wav['self_wav'] = WavCondition( - ref_wav[0][..., positions % wav_length], - torch.full_like(ref_wav[1], wav_target_length), - [self.sample_rate] * ref_wav[0].size(0), - [None], [0.]) - with self.autocast: - gen_tokens = self.lm.generate( - prompt_tokens, attributes, - callback=callback, max_gen_len=max_gen_len, **self.generation_params) - if prompt_tokens is None: - all_tokens.append(gen_tokens) - else: - all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:]) - prompt_tokens = gen_tokens[:, :, stride_tokens:] - prompt_length = prompt_tokens.shape[-1] - current_gen_offset += stride_tokens - - gen_tokens = torch.cat(all_tokens, dim=-1) - return gen_tokens - - def generate_audio(self, gen_tokens: torch.Tensor): - """Generate Audio from tokens""" - assert gen_tokens.dim() == 3 - with torch.no_grad(): - n_channel = gen_tokens.shape[1] - gen_audio = self.compression_model.decode(gen_tokens, None) - return gen_audio diff --git a/audiocraft/audiocraft/models/unet.py b/audiocraft/audiocraft/models/unet.py deleted file mode 100644 index db4a6df8e309c21fede37abdbe3c862932027641..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/models/unet.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Pytorch Unet Module used for diffusion. -""" - -from dataclasses import dataclass -import typing as tp - -import torch -from torch import nn -from torch.nn import functional as F -from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding - - -@dataclass -class Output: - sample: torch.Tensor - - -def get_model(cfg, channels: int, side: int, num_steps: int): - if cfg.model == 'unet': - return DiffusionUnet( - chin=channels, num_steps=num_steps, **cfg.diffusion_unet) - else: - raise RuntimeError('Not Implemented') - - -class ResBlock(nn.Module): - def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4, - dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, - dropout: float = 0.): - super().__init__() - stride = 1 - padding = dilation * (kernel - stride) // 2 - Conv = nn.Conv1d - Drop = nn.Dropout1d - self.norm1 = nn.GroupNorm(norm_groups, channels) - self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) - self.activation1 = activation() - self.dropout1 = Drop(dropout) - - self.norm2 = nn.GroupNorm(norm_groups, channels) - self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation) - self.activation2 = activation() - self.dropout2 = Drop(dropout) - - def forward(self, x): - h = self.dropout1(self.conv1(self.activation1(self.norm1(x)))) - h = self.dropout2(self.conv2(self.activation2(self.norm2(h)))) - return x + h - - -class DecoderLayer(nn.Module): - def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, - norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, - dropout: float = 0.): - super().__init__() - padding = (kernel - stride) // 2 - self.res_blocks = nn.Sequential( - *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) - for idx in range(res_blocks)]) - self.norm = nn.GroupNorm(norm_groups, chin) - 
ConvTr = nn.ConvTranspose1d - self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False) - self.activation = activation() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.res_blocks(x) - x = self.norm(x) - x = self.activation(x) - x = self.convtr(x) - return x - - -class EncoderLayer(nn.Module): - def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2, - norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU, - dropout: float = 0.): - super().__init__() - padding = (kernel - stride) // 2 - Conv = nn.Conv1d - self.conv = Conv(chin, chout, kernel, stride, padding, bias=False) - self.norm = nn.GroupNorm(norm_groups, chout) - self.activation = activation() - self.res_blocks = nn.Sequential( - *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout) - for idx in range(res_blocks)]) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, C, T = x.shape - stride, = self.conv.stride - pad = (stride - (T % stride)) % stride - x = F.pad(x, (0, pad)) - - x = self.conv(x) - x = self.norm(x) - x = self.activation(x) - x = self.res_blocks(x) - return x - - -class BLSTM(nn.Module): - """BiLSTM with same hidden units as input dim. - """ - def __init__(self, dim, layers=2): - super().__init__() - self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim) - self.linear = nn.Linear(2 * dim, dim) - - def forward(self, x): - x = x.permute(2, 0, 1) - x = self.lstm(x)[0] - x = self.linear(x) - x = x.permute(1, 2, 0) - return x - - -class DiffusionUnet(nn.Module): - def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2., - max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False, - bilstm: bool = False, transformer: bool = False, - codec_dim: tp.Optional[int] = None, **kwargs): - super().__init__() - self.encoders = nn.ModuleList() - self.decoders = nn.ModuleList() - self.embeddings: tp.Optional[nn.ModuleList] = None - self.embedding = nn.Embedding(num_steps, hidden) - if emb_all_layers: - self.embeddings = nn.ModuleList() - self.condition_embedding: tp.Optional[nn.Module] = None - for d in range(depth): - encoder = EncoderLayer(chin, hidden, **kwargs) - decoder = DecoderLayer(hidden, chin, **kwargs) - self.encoders.append(encoder) - self.decoders.insert(0, decoder) - if emb_all_layers and d > 0: - assert self.embeddings is not None - self.embeddings.append(nn.Embedding(num_steps, hidden)) - chin = hidden - hidden = min(int(chin * growth), max_channels) - self.bilstm: tp.Optional[nn.Module] - if bilstm: - self.bilstm = BLSTM(chin) - else: - self.bilstm = None - self.use_transformer = transformer - self.cross_attention = False - if transformer: - self.cross_attention = cross_attention - self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False, - cross_attention=cross_attention) - - self.use_codec = False - if codec_dim is not None: - self.conv_codec = nn.Conv1d(codec_dim, chin, 1) - self.use_codec = True - - def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None): - skips = [] - bs = x.size(0) - z = x - view_args = [1] - if type(step) is torch.Tensor: - step_tensor = step - else: - step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs) - - for idx, encoder in enumerate(self.encoders): - z = encoder(z) - if idx == 0: - z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z) - 
elif self.embeddings is not None: - z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z) - - skips.append(z) - - if self.use_codec: # insert condition in the bottleneck - assert condition is not None, "Model defined for conditionnal generation" - condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim - assert condition_emb.size(-1) <= 2 * z.size(-1), \ - f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}" - if not self.cross_attention: - - condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1)) - assert z.size() == condition_emb.size() - z += condition_emb - cross_attention_src = None - else: - cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C - B, T, C = cross_attention_src.shape - positions = torch.arange(T, device=x.device).view(1, -1, 1) - pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype) - cross_attention_src = cross_attention_src + pos_emb - if self.use_transformer: - z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1) - else: - if self.bilstm is None: - z = torch.zeros_like(z) - else: - z = self.bilstm(z) - - for decoder in self.decoders: - s = skips.pop(-1) - z = z[:, :, :s.shape[2]] - z = z + s - z = decoder(z) - - z = z[:, :, :x.shape[2]] - return Output(z) diff --git a/audiocraft/audiocraft/modules/__init__.py b/audiocraft/audiocraft/modules/__init__.py deleted file mode 100644 index 61418616ef18f0ecca56a007c43af4a731d98b9b..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
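A shape sketch for the `DiffusionUnet` defined above, assuming the module is still importable (it is removed by this diff) and using assumed sizes. With the options shown here (no transformer, no BiLSTM) the bottleneck is zeroed and the skip connections carry the signal.

import torch
from audiocraft.models.unet import DiffusionUnet   # path as it existed before this deletion

model = DiffusionUnet(chin=1, hidden=24, depth=3, num_steps=1000)
x = torch.randn(2, 1, 16000)          # [B, C, T] noisy input
out = model(x, step=10)               # one integer diffusion step for the whole batch
assert out.sample.shape == x.shape    # decoder output is trimmed back to the input length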
-"""Modules used for building the models.""" - -# flake8: noqa -from .conv import ( - NormConv1d, - NormConv2d, - NormConvTranspose1d, - NormConvTranspose2d, - StreamableConv1d, - StreamableConvTranspose1d, - pad_for_conv1d, - pad1d, - unpad1d, -) -from .lstm import StreamableLSTM -from .seanet import SEANetEncoder, SEANetDecoder -from .transformer import StreamingTransformer \ No newline at end of file diff --git a/audiocraft/audiocraft/modules/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 62712228f5a9fe15ea967b1a9c293231e2e3d057..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/__pycache__/activations.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/activations.cpython-311.pyc deleted file mode 100644 index 06b1745b1b13769b4b05528223029cf3ada9324a..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/activations.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/__pycache__/chroma.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/chroma.cpython-311.pyc deleted file mode 100644 index 6da8ea24700237cd7d5dd5f0c80e836749b01202..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/chroma.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/__pycache__/codebooks_patterns.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/codebooks_patterns.cpython-311.pyc deleted file mode 100644 index 7ff99fa78f252b1766c85e2bc8f41e630b5c3183..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/codebooks_patterns.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/__pycache__/conditioners.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/conditioners.cpython-311.pyc deleted file mode 100644 index d913486a7d5fc0f4c01ee66fad77e707e3aed0c1..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/conditioners.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/__pycache__/conv.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/conv.cpython-311.pyc deleted file mode 100644 index e7ee0d9c7ad787a81d18ce7a0ce7f259f6280e4e..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/conv.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/__pycache__/diffusion_schedule.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/diffusion_schedule.cpython-311.pyc deleted file mode 100644 index 8afe3f7fc4cd9a7b52d94a15695bc987a810a0ce..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/diffusion_schedule.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/__pycache__/lstm.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/lstm.cpython-311.pyc deleted file mode 100644 index 4c7c5865c61429551ca24a3011a37a9dc72de1a9..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/lstm.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/__pycache__/rope.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/rope.cpython-311.pyc deleted file mode 100644 
index dff1b5155c978387270255eec10d6635362e69c1..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/rope.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/__pycache__/seanet.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/seanet.cpython-311.pyc deleted file mode 100644 index b725c6e7cfbf4070315262f6f54e11a16e0a7c4e..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/seanet.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/__pycache__/streaming.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/streaming.cpython-311.pyc deleted file mode 100644 index b737da74c537c940197ee703224e0f04d60f8853..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/streaming.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/__pycache__/transformer.cpython-311.pyc b/audiocraft/audiocraft/modules/__pycache__/transformer.cpython-311.pyc deleted file mode 100644 index 209eb91453bc7c634746352dfb78521e8b5574fe..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/modules/__pycache__/transformer.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/modules/activations.py b/audiocraft/audiocraft/modules/activations.py deleted file mode 100644 index 2d83d7c4c2dc84c64b724eadbe06157507d4f20d..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/activations.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn -from torch import Tensor -from typing import Union, Callable - - -class CustomGLU(nn.Module): - """Custom Gated Linear Unit activation. - Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half - of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation - function (i.e. sigmoid, swish, etc.). - - Args: - activation (nn.Module): The custom activation to apply in the Gated Linear Unit - dim (int): the dimension on which to split the input. Default: -1 - - Shape: - - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional - dimensions - - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` - - Examples:: - >>> m = CustomGLU(nn.Sigmoid()) - >>> input = torch.randn(4, 2) - >>> output = m(input) - """ - def __init__(self, activation: nn.Module, dim: int = -1): - super(CustomGLU, self).__init__() - self.dim = dim - self.activation = activation - - def forward(self, x: Tensor): - assert x.shape[self.dim] % 2 == 0 # M = N / 2 - a, b = torch.chunk(x, 2, dim=self.dim) - return a * self.activation(b) - - -class SwiGLU(CustomGLU): - """SiLU Gated Linear Unit activation. - Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is - the first half of the input matrices, :math:`b` is the second half. - - Args: - dim (int): the dimension on which to split the input. Default: -1 - """ - def __init__(self, dim: int = -1): - super(SwiGLU, self).__init__(nn.SiLU(), dim) - - -class GeGLU(CustomGLU): - """GeLU Gated Linear Unit activation. - Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is - the first half of the input matrices, :math:`b` is the second half. 
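A quick check of the gated-linear-unit behaviour described here, assuming the module is importable from an older checkout; sizes are arbitrary.

import torch
from audiocraft.modules.activations import SwiGLU   # path as it existed before this deletion

glu = SwiGLU(dim=-1)       # splits the last dimension in half and returns a * SiLU(b)
x = torch.randn(4, 8)
y = glu(x)                 # -> torch.Size([4, 4])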
- - Args: - dim (int): the dimension on which to split the input. Default: -1 - """ - def __init__(self, dim: int = -1): - super(GeGLU, self).__init__(nn.GELU(), dim) - - -class ReGLU(CustomGLU): - """ReLU Gated Linear Unit activation. - Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is - the first half of the input matrices, :math:`b` is the second half. - - Args: - dim (int): the dimension on which to split the input. Default: -1 - """ - def __init__(self, dim: int = -1): - super(ReGLU, self).__init__(nn.ReLU(), dim) - - -def get_activation_fn( - activation: Union[str, Callable[[Tensor], Tensor]] -) -> Union[str, Callable[[Tensor], Tensor]]: - """Helper function to map an activation string to the activation class. - If the supplied activation is not a string that is recognized, the activation is passed back. - - Args: - activation (str, or Callable[[Tensor], Tensor]): Activation to check - """ - if isinstance(activation, str): - if activation == "reglu": - return ReGLU() - elif activation == "geglu": - return GeGLU() - elif activation == "swiglu": - return SwiGLU() - return activation diff --git a/audiocraft/audiocraft/modules/chroma.py b/audiocraft/audiocraft/modules/chroma.py deleted file mode 100644 index e84fb66b4a4aaefb0b3ccac8a9a44c3b20e48f61..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/chroma.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -import typing as tp - -from einops import rearrange -from librosa import filters -import torch -from torch import nn -import torch.nn.functional as F -import torchaudio - - -class ChromaExtractor(nn.Module): - """Chroma extraction and quantization. - - Args: - sample_rate (int): Sample rate for the chroma extraction. - n_chroma (int): Number of chroma bins for the chroma extraction. - radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12). - nfft (int, optional): Number of FFT. - winlen (int, optional): Window length. - winhop (int, optional): Window hop size. - argmax (bool, optional): Whether to use argmax. Defaults to False. - norm (float, optional): Norm for chroma normalization. Defaults to inf. 
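A minimal sketch of running the `ChromaExtractor` defined here on a dummy waveform; the sample rate and length are assumed, and the import path is the pre-deletion one.

import torch
from audiocraft.modules.chroma import ChromaExtractor

extractor = ChromaExtractor(sample_rate=32000, n_chroma=12, radix2_exp=12)
wav = torch.randn(1, 1, 32000)     # one second of (random) audio, [B, C, T]
chroma = extractor(wav)            # [B, n_frames, 12], normalised per frame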
- """ - def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None, - winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False, - norm: float = torch.inf): - super().__init__() - self.winlen = winlen or 2 ** radix2_exp - self.nfft = nfft or self.winlen - self.winhop = winhop or (self.winlen // 4) - self.sample_rate = sample_rate - self.n_chroma = n_chroma - self.norm = norm - self.argmax = argmax - self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, - n_chroma=self.n_chroma)), persistent=False) - self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen, - hop_length=self.winhop, power=2, center=True, - pad=0, normalized=True) - - def forward(self, wav: torch.Tensor) -> torch.Tensor: - T = wav.shape[-1] - # in case we are getting a wav that was dropped out (nullified) - # from the conditioner, make sure wav length is no less that nfft - if T < self.nfft: - pad = self.nfft - T - r = 0 if pad % 2 == 0 else 1 - wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0) - assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}" - - spec = self.spec(wav).squeeze(1) - raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec) - norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6) - norm_chroma = rearrange(norm_chroma, 'b d t -> b t d') - - if self.argmax: - idx = norm_chroma.argmax(-1, keepdim=True) - norm_chroma[:] = 0 - norm_chroma.scatter_(dim=-1, index=idx, value=1) - - return norm_chroma diff --git a/audiocraft/audiocraft/modules/codebooks_patterns.py b/audiocraft/audiocraft/modules/codebooks_patterns.py deleted file mode 100644 index 1bfc767dce8d804dd1058a92924713af599be808..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/codebooks_patterns.py +++ /dev/null @@ -1,542 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from collections import namedtuple -from dataclasses import dataclass -from functools import lru_cache -import logging -import typing as tp - -from abc import ABC, abstractmethod -import torch - -LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index) -PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates -logger = logging.getLogger(__name__) - - -@dataclass -class Pattern: - """Base implementation of a pattern over a sequence with multiple codebooks. - - The codebook pattern consists in a layout, defining for each sequence step - the list of coordinates of each codebook timestep in the resulting interleaved sequence. - The first item of the pattern is always an empty list in order to properly insert a special token - to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern - and ``timesteps`` the number of timesteps corresponding to the original sequence. 
- - The pattern provides convenient methods to build and revert interleaved sequences from it: - ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T] - to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size, - K being the number of codebooks, T the number of original timesteps and S the number of sequence steps - for the output sequence. The unfilled positions are replaced with a special token and the built sequence - is returned along with a mask indicating valid tokens. - ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment - of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask - to fill and specify invalid positions if needed. - See the dedicated methods for more details. - """ - # Pattern layout, for each sequence step, we have a list of coordinates - # corresponding to the original codebook timestep and position. - # The first list is always an empty list in order to properly insert - # a special token to start with. - layout: PatternLayout - timesteps: int - n_q: int - - def __post_init__(self): - assert len(self.layout) > 0 - assert self.layout[0] == [] - self._validate_layout() - self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes) - self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes) - logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout)) - - def _validate_layout(self): - """Runs checks on the layout to ensure a valid pattern is defined. - A pattern is considered invalid if: - - Multiple timesteps for a same codebook are defined in the same sequence step - - The timesteps for a given codebook are not in ascending order as we advance in the sequence - (this would mean that we have future timesteps before past timesteps). - """ - q_timesteps = {q: 0 for q in range(self.n_q)} - for s, seq_coords in enumerate(self.layout): - if len(seq_coords) > 0: - qs = set() - for coord in seq_coords: - qs.add(coord.q) - last_q_timestep = q_timesteps[coord.q] - assert coord.t >= last_q_timestep, \ - f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}" - q_timesteps[coord.q] = coord.t - # each sequence step contains at max 1 coordinate per codebook - assert len(qs) == len(seq_coords), \ - f"Multiple entries for a same codebook are found at step {s}" - - @property - def num_sequence_steps(self): - return len(self.layout) - 1 - - @property - def max_delay(self): - max_t_in_seq_coords = 0 - for seq_coords in self.layout[1:]: - for coords in seq_coords: - max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1) - return max_t_in_seq_coords - self.timesteps - - @property - def valid_layout(self): - valid_step = len(self.layout) - self.max_delay - return self.layout[:valid_step] - - def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None): - """Get codebook coordinates in the layout that corresponds to the specified timestep t - and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step - and the actual codebook coordinates. 
- """ - assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" - if q is not None: - assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks" - coords = [] - for s, seq_codes in enumerate(self.layout): - for code in seq_codes: - if code.t == t and (q is None or code.q == q): - coords.append((s, code)) - return coords - - def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]: - return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)] - - def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]: - steps_with_timesteps = self.get_steps_with_timestep(t, q) - return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None - - def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool, - device: tp.Union[torch.device, str] = 'cpu'): - """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps. - - Args: - timesteps (int): Maximum number of timesteps steps to consider. - keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps. - device (torch.device or str): Device for created tensors. - Returns: - indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S]. - mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S]. - """ - # assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" - assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern" - # use the proper layout based on whether we limit ourselves to valid steps only or not, - # note that using the valid_layout will result in a truncated sequence up to the valid steps - ref_layout = self.valid_layout if keep_only_valid_steps else self.layout - # single item indexing being super slow with pytorch vs. numpy, so we use numpy here - indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy() - mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy() - # fill indexes with last sequence step value that will correspond to our special token - # the last value is n_q * timesteps as we have flattened z and append special token as the last token - # which will correspond to the index: n_q * timesteps - indexes[:] = n_q * timesteps - # iterate over the pattern and fill scattered indexes and mask - for s, sequence_coords in enumerate(ref_layout): - for coords in sequence_coords: - if coords.t < timesteps: - indexes[coords.q, s] = coords.t + coords.q * timesteps - mask[coords.q, s] = 1 - indexes = torch.from_numpy(indexes).to(device) - mask = torch.from_numpy(mask).to(device) - return indexes, mask - - def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): - """Build sequence corresponding to the pattern from the input tensor z. - The sequence is built using up to sequence_steps if specified, and non-pattern - coordinates are filled with the special token. - - Args: - z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T]. - special_token (int): Special token used to fill non-pattern coordinates in the new sequence. - keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. - Steps that are beyond valid steps will be replaced by the special_token in that case. 
- Returns: - values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S - corresponding either to the sequence_steps if provided, otherwise to the length of the pattern. - indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S]. - mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S]. - """ - B, K, T = z.shape - indexes, mask = self._build_pattern_sequence_scatter_indexes( - T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device) - ) - z = z.view(B, -1) - # we append the special token as the last index of our flattened z tensor - z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1) - values = z[:, indexes.view(-1)] - values = values.view(B, K, indexes.shape[-1]) - return values, indexes, mask - - def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int, - keep_only_valid_steps: bool = False, - is_model_output: bool = False, - device: tp.Union[torch.device, str] = 'cpu'): - """Builds scatter indexes required to retrieve the original multi-codebook sequence - from interleaving pattern. - - Args: - sequence_steps (int): Sequence steps. - n_q (int): Number of codebooks. - keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. - Steps that are beyond valid steps will be replaced by the special_token in that case. - is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not. - device (torch.device or str): Device for created tensors. - Returns: - indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T]. - mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. - """ - ref_layout = self.valid_layout if keep_only_valid_steps else self.layout - # TODO(jade): Do we want to further truncate to only valid timesteps here as well? - timesteps = self.timesteps - #assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" - assert sequence_steps <= len(ref_layout), \ - f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" - - # ensure we take the appropriate indexes to keep the model output from the first special token as well - if is_model_output: - ref_layout = ref_layout[1:] - - # single item indexing being super slow with pytorch vs. numpy, so we use numpy here - indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy() - mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy() - # fill indexes with last sequence step value that will correspond to our special token - indexes[:] = n_q * sequence_steps - for s, sequence_codes in enumerate(ref_layout): - if s < sequence_steps: - for code in sequence_codes: - if code.t < timesteps: - indexes[code.q, code.t] = s + code.q * sequence_steps - mask[code.q, code.t] = 1 - indexes = torch.from_numpy(indexes).to(device) - mask = torch.from_numpy(mask).to(device) - return indexes, mask - - def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): - """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving. - The sequence is reverted using up to timesteps if specified, and non-pattern coordinates - are filled with the special token. - - Args: - s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S]. 
- special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence. - Returns: - values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T - corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise. - indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T]. - mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. - """ - B, K, S = s.shape - indexes, mask = self._build_reverted_sequence_scatter_indexes( - S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device) - ) - s = s.view(B, -1) - # we append the special token as the last index of our flattened z tensor - s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1) - values = s[:, indexes.view(-1)] - values = values.view(B, K, indexes.shape[-1]) - return values, indexes, mask - - def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False): - """Revert model logits obtained on a sequence built from the pattern - back to a tensor matching the original sequence. - - This method is similar to ``revert_pattern_sequence`` with the following specificities: - 1. It is designed to work with the extra cardinality dimension - 2. We return the logits for the first sequence item that matches the special_token and - which matching target in the original sequence is the first item of the sequence, - while we skip the last logits as there is no matching target - """ - B, card, K, S = logits.shape - indexes, mask = self._build_reverted_sequence_scatter_indexes( - S, K, keep_only_valid_steps, is_model_output=True, device=logits.device - ) - logits = logits.reshape(B, card, -1) - # we append the special token as the last index of our flattened z tensor - logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S] - values = logits[:, :, indexes.view(-1)] - values = values.view(B, card, K, indexes.shape[-1]) - return values, indexes, mask - - -class CodebooksPatternProvider(ABC): - """Abstraction around providing pattern for interleaving codebooks. - - The CodebooksPatternProvider abstraction allows to implement various strategies to - define interleaving pattern of sequences composed of multiple codebooks. For a given - number of codebooks `n_q`, the pattern provider can generate a specified pattern - corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern - can be used to construct a new sequence from the original codes respecting the specified - pattern. The pattern is defined as a list of list of code coordinates, code coordinate - being a tuple with the original timestep and codebook to build the new sequence. - Note that all patterns must start with an empty list that is then used to insert a first - sequence step of special tokens in the newly generated sequence. - - Args: - n_q (int): number of codebooks. - cached (bool): if True, patterns for a given length are cached. In general - that should be true for efficiency reason to avoid synchronization points. 
- """ - def __init__(self, n_q: int, cached: bool = True, stereo: bool = False): - assert n_q > 0 - if stereo: - self.n_q = n_q // 2 - else: - self.n_q = n_q - self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore - - @abstractmethod - def get_pattern(self, timesteps: int) -> Pattern: - """Builds pattern with specific interleaving between codebooks. - - Args: - timesteps (int): Total number of timesteps. - """ - raise NotImplementedError() - - -class DelayedPatternProvider(CodebooksPatternProvider): - """Provider for delayed pattern across delayed codebooks. - Codebooks are delayed in the sequence and sequence steps will contain codebooks - from different timesteps. - - Example: - Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence: - [[1, 2, 3, 4], - [1, 2, 3, 4], - [1, 2, 3, 4]] - The resulting sequence obtained from the returned pattern is: - [[S, 1, 2, 3, 4], - [S, S, 1, 2, 3], - [S, S, S, 1, 2]] - (with S being a special token) - - Args: - n_q (int): Number of codebooks. - delays (list of int, optional): Delay for each of the codebooks. - If delays not defined, each codebook is delayed by 1 compared to the previous one. - flatten_first (int): Flatten the first N timesteps. - empty_initial (int): Prepend with N empty list of coordinates. - """ - def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None, - flatten_first: int = 0, empty_initial: int = 0): - super().__init__(n_q) - if delays is None: - delays = list(range(n_q)) - self.delays = delays - self.flatten_first = flatten_first - self.empty_initial = empty_initial - # assert len(self.delays) == self.n_q - assert sorted(self.delays) == self.delays - - def get_pattern(self, timesteps: int) -> Pattern: - out: PatternLayout = [[]] - max_delay = max(self.delays) - if self.empty_initial: - out += [[] for _ in range(self.empty_initial)] - if self.flatten_first: - for t in range(min(timesteps, self.flatten_first)): - for q in range(self.n_q): - out.append([LayoutCoord(t, q)]) - for t in range(self.flatten_first, timesteps + max_delay): - v = [] - for q, delay in enumerate(self.delays): - t_for_q = t - delay - if t_for_q >= self.flatten_first: - v.append(LayoutCoord(t_for_q, q)) - out.append(v) - return Pattern(out, n_q=self.n_q, timesteps=timesteps) - - -class ParallelPatternProvider(DelayedPatternProvider): - """Provider for parallel pattern across codebooks. - This pattern provider is a special case of the delayed pattern with actually no delay, - hence delays=repeat(0, n_q). - - Args: - n_q (int): Number of codebooks. - """ - def __init__(self, n_q: int): - super().__init__(n_q, [0] * n_q) - - -class UnrolledPatternProvider(CodebooksPatternProvider): - """Provider for unrolling codebooks pattern. - This pattern provider enables to represent the codebook flattened completely or only to some extend - while also specifying a given delay between the flattened codebooks representation, allowing to - unroll the codebooks in the sequence. - - Example: - 1. Flattening of the codebooks. - By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q), - taking n_q = 3 and timesteps = 4: - [[1, 2, 3, 4], - [1, 2, 3, 4], - [1, 2, 3, 4]] - will result into: - [[S, S, 1, S, S, 2, S, S, 3, S, S, 4], - [S, 1, S, S, 2, S, S, 3, S, S, 4, S], - [1, S, S, 2, S, S, 3, S, S, 4, S, S]] - 2. Partial flattening of the codebooks. 
The ``flattening`` parameter allows to specify the inner step - for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example - taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]: - [[1, 2, 3, 4], - [1, 2, 3, 4], - [1, 2, 3, 4]] - will result into: - [[S, 1, S, S, 2, S, S, 3, S, S, 4, S], - [S, 1, S, S, 2, S, S, 3, S, S, 4, S], - [1, S, S, 2, S, S, 3, S, S, 4, S, S]] - 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks - allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the - same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1] - and delays = [0, 3, 3]: - [[1, 2, 3, 4], - [1, 2, 3, 4], - [1, 2, 3, 4]] - will result into: - [[S, S, S, 1, S, 2, S, 3, S, 4], - [S, S, S, 1, S, 2, S, 3, S, 4], - [1, 2, 3, S, 4, S, 5, S, 6, S]] - - Args: - n_q (int): Number of codebooks. - flattening (list of int, optional): Flattening schema over the codebooks. If not defined, - the codebooks will be flattened to 1 codebook per step, meaning that the sequence will - have n_q extra steps for each timestep. - delays (list of int, optional): Delay for each of the codebooks. If not defined, - no delay is added and therefore will default to [0] * ``n_q``. - Note that two codebooks that will be flattened to the same inner step - should have the same delay, otherwise the pattern is considered as invalid. - """ - FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay']) - - def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None, - delays: tp.Optional[tp.List[int]] = None): - super().__init__(n_q) - if flattening is None: - flattening = list(range(n_q)) - if delays is None: - delays = [0] * n_q - assert len(flattening) == n_q - assert len(delays) == n_q - assert sorted(flattening) == flattening - assert sorted(delays) == delays - self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening) - self.max_delay = max(delays) - - def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]): - """Build a flattened codebooks representation as a dictionary of inner step - and the actual codebook indices corresponding to the flattened codebook. For convenience, we - also store the delay associated to the flattened codebook to avoid maintaining an extra mapping. - """ - flattened_codebooks: dict = {} - for q, (inner_step, delay) in enumerate(zip(flattening, delays)): - if inner_step not in flattened_codebooks: - flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay) - else: - flat_codebook = flattened_codebooks[inner_step] - assert flat_codebook.delay == delay, ( - "Delay and flattening between codebooks is inconsistent: ", - "two codebooks flattened to the same position should have the same delay." - ) - flat_codebook.codebooks.append(q) - flattened_codebooks[inner_step] = flat_codebook - return flattened_codebooks - - @property - def _num_inner_steps(self): - """Number of inner steps to unroll between timesteps in order to flatten the codebooks. - """ - return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1 - - def num_virtual_steps(self, timesteps: int) -> int: - return timesteps * self._num_inner_steps + 1 - - def get_pattern(self, timesteps: int) -> Pattern: - """Builds pattern for delay across codebooks. - - Args: - timesteps (int): Total number of timesteps. 
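A round-trip sketch for the delayed pattern described earlier in this file, assuming the module is importable from an older checkout. For `DelayedPatternProvider` every original position survives the interleaving, so the revert recovers the input exactly.

import torch
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider

provider = DelayedPatternProvider(n_q=3)
pattern = provider.get_pattern(timesteps=4)
z = torch.arange(1, 13).view(1, 3, 4)               # [B, K, T] codebook indices
seq, _, _ = pattern.build_pattern_sequence(z, special_token=0)
# seq: [1, 3, 7] = 1 leading special step + 4 timesteps + max_delay (2) trailing steps
rec, _, mask = pattern.revert_pattern_sequence(seq, special_token=0)
assert torch.equal(rec, z)                          # mask is all ones for this pattern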
- """ - # the PatternLayout is built as a tuple of sequence position and list of coordinates - # so that it can be reordered properly given the required delay between codebooks of given timesteps - indexed_out: list = [(-1, [])] - max_timesteps = timesteps + self.max_delay - for t in range(max_timesteps): - # for each timestep, we unroll the flattened codebooks, - # emitting the sequence step with the corresponding delay - for step in range(self._num_inner_steps): - if step in self._flattened_codebooks: - # we have codebooks at this virtual step to emit - step_codebooks = self._flattened_codebooks[step] - t_for_q = t + step_codebooks.delay - coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks] - if t_for_q < max_timesteps and t < max_timesteps: - indexed_out.append((t_for_q, coords)) - else: - # there is no codebook in this virtual step so we emit an empty list - indexed_out.append((t, [])) - out = [coords for _, coords in sorted(indexed_out)] - return Pattern(out, n_q=self.n_q, timesteps=timesteps) - - -class VALLEPattern(CodebooksPatternProvider): - """Almost VALL-E style pattern. - We further allow some delays for the codebooks other than the first one. - - Args: - n_q (int): Number of codebooks. - delays (list of int, optional): Delay for each of the codebooks. - If delays not defined, each codebook is delayed by 1 compared to the previous one. - """ - def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): - super().__init__(n_q) - if delays is None: - delays = [0] * (n_q - 1) - self.delays = delays - assert len(self.delays) == self.n_q - 1 - assert sorted(self.delays) == self.delays - - def get_pattern(self, timesteps: int) -> Pattern: - out: PatternLayout = [[]] - for t in range(timesteps): - out.append([LayoutCoord(t, 0)]) - max_delay = max(self.delays) - for t in range(timesteps + max_delay): - v = [] - for q, delay in enumerate(self.delays): - t_for_q = t - delay - if t_for_q >= 0: - v.append(LayoutCoord(t_for_q, q + 1)) - out.append(v) - return Pattern(out, n_q=self.n_q, timesteps=timesteps) - - -class MusicLMPattern(CodebooksPatternProvider): - """Almost MusicLM style pattern. This is equivalent to full flattening - but in a different order. - - Args: - n_q (int): Number of codebooks. - group_by (int): Number of codebooks to group together. - """ - def __init__(self, n_q: int, group_by: int = 2): - super().__init__(n_q) - self.group_by = group_by - - def get_pattern(self, timesteps: int) -> Pattern: - out: PatternLayout = [[]] - for offset in range(0, self.n_q, self.group_by): - for t in range(timesteps): - for q in range(offset, offset + self.group_by): - out.append([LayoutCoord(t, q)]) - return Pattern(out, n_q=self.n_q, timesteps=timesteps) diff --git a/audiocraft/audiocraft/modules/conditioners.py b/audiocraft/audiocraft/modules/conditioners.py deleted file mode 100644 index 5d657979c401d806209f7f1af6df1062b7321277..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/conditioners.py +++ /dev/null @@ -1,1678 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-import pretty_midi -from collections import defaultdict -from copy import deepcopy -from dataclasses import dataclass, field -from itertools import chain -import logging -import math -from pathlib import Path -import random -import re -import typing as tp -import warnings - -import einops -from num2words import num2words -import spacy -from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore -import torch -from torch import nn -import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence - -from .chroma import ChromaExtractor -from .streaming import StreamingModule -from .transformer import create_sin_embedding -from ..data.audio import audio_read -from ..data.audio_dataset import SegmentInfo -from ..data.audio_utils import convert_audio -from ..environment import AudioCraftEnvironment -from ..quantization import ResidualVectorQuantizer -from ..utils.autocast import TorchAutocast -from ..utils.cache import EmbeddingCache -from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once - - -logger = logging.getLogger(__name__) -TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist) -ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask - - -class WavCondition(tp.NamedTuple): - wav: torch.Tensor - length: torch.Tensor - sample_rate: tp.List[int] - path: tp.List[tp.Optional[str]] = [] - seek_time: tp.List[tp.Optional[float]] = [] - - -class ChordCondition(tp.NamedTuple): - chord: torch.Tensor - length: torch.Tensor - bpm: tp.List[tp.Optional[float]] = [] - path: tp.List[tp.Optional[str]] = [] - seek_frame: tp.List[tp.Optional[float]] = [] - - -class BeatCondition(tp.NamedTuple): - beat: torch.Tensor - length: torch.Tensor - bpm: tp.List[tp.Optional[float]] = [] - path: tp.List[tp.Optional[str]] = [] - seek_frame: tp.List[tp.Optional[float]] = [] - - -class JointEmbedCondition(tp.NamedTuple): - wav: torch.Tensor - text: tp.List[tp.Optional[str]] - length: torch.Tensor - sample_rate: tp.List[int] - path: tp.List[tp.Optional[str]] = [] - seek_time: tp.List[tp.Optional[float]] = [] - - -@dataclass -class ConditioningAttributes: - text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict) - wav: tp.Dict[str, WavCondition] = field(default_factory=dict) - beat: tp.Dict[str, BeatCondition] = field(default_factory=dict) - chord: tp.Dict[str, ChordCondition] = field(default_factory=dict) - joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict) - - def __getitem__(self, item): - return getattr(self, item) - - @property - def text_attributes(self): - return self.text.keys() - - @property - def wav_attributes(self): - return self.wav.keys() - - @property - def beat_attributes(self): - return self.beat.keys() - - @property - def chord_attributes(self): - return self.chord.keys() - - @property - def joint_embed_attributes(self): - return self.joint_embed.keys() - - @property - def attributes(self): - return { - "text": self.text_attributes, - "wav": self.wav_attributes, - "beat" : self.beat_attributes, - "chord": self.chord_attributes, - "joint_embed": self.joint_embed_attributes, - } - - def to_flat_dict(self): - return { - **{f"text.{k}": v for k, v in self.text.items()}, - **{f"wav.{k}": v for k, v in self.wav.items()}, - **{f"beat.{k}": v for k, v in self.beat.items()}, - **{f"chord.{k}": v for k, v in self.chord.items()}, - **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()} - } - - @classmethod - def from_flat_dict(cls, x): - out 
= cls() - for k, v in x.items(): - kind, att = k.split(".") - out[kind][att] = v - return out - - -class SegmentWithAttributes(SegmentInfo): - """Base class for all dataclasses that are used for conditioning. - All child classes should implement `to_condition_attributes` that converts - the existing attributes to a dataclass of type ConditioningAttributes. - """ - def to_condition_attributes(self) -> ConditioningAttributes: - raise NotImplementedError() - - -def nullify_condition(condition: ConditionType, dim: int = 1): - """Transform an input condition to a null condition. - The way it is done by converting it to a single zero vector similarly - to how it is done inside WhiteSpaceTokenizer and NoopTokenizer. - - Args: - condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor]) - dim (int): The dimension that will be truncated (should be the time dimension) - WARNING!: dim should not be the batch dimension! - Returns: - ConditionType: A tuple of null condition and mask - """ - assert dim != 0, "dim cannot be the batch dimension!" - assert isinstance(condition, tuple) and \ - isinstance(condition[0], torch.Tensor) and \ - isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!" - cond, mask = condition - B = cond.shape[0] - last_dim = cond.dim() - 1 - out = cond.transpose(dim, last_dim) - out = 0. * out[..., :1] - out = out.transpose(dim, last_dim) - mask = torch.zeros((B, 1), device=out.device).int() - assert cond.dim() == out.dim() - return out, mask - - -def nullify_wav(cond: WavCondition) -> WavCondition: - """Transform a WavCondition to a nullified WavCondition. - It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes. - - Args: - cond (WavCondition): Wav condition with wav, tensor of shape [B, T]. - Returns: - WavCondition: Nullified wav condition. - """ - null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1) - return WavCondition( - wav=null_wav, - length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device), - sample_rate=cond.sample_rate, - path=[None] * cond.wav.shape[0], - seek_time=[None] * cond.wav.shape[0], - ) - -def nullify_chord(cond: ChordCondition) -> ChordCondition: - """Transform a ChordCondition to a nullified ChordCondition. - It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes. - - Args: - cond (ChordCondition): Chord condition with chord, tensor of shape [B, C, T]. - Returns: - ChordCondition: Nullified chord condition. - """ - null_chord, _ = nullify_condition((cond.chord, torch.zeros_like(cond.chord)), dim=cond.chord.dim() - 1) - return ChordCondition( - chord=null_chord, - length=torch.tensor([0] * cond.chord.shape[0], device=cond.chord.device), - bpm=[None] * cond.chord.shape[0], - path=[None] * cond.chord.shape[0], - seek_frame=[None] * cond.chord.shape[0], - ) - - -def nullify_beat(cond: BeatCondition) -> BeatCondition: - """ - Args: - cond (ChordCondition): Chord condition with chord, tensor of shape [B, C, T]. - Returns: - ChordCondition: Nullified chord condition. 
- """ - null_beat, _ = nullify_condition((cond.beat, torch.zeros_like(cond.beat)), dim=cond.beat.dim() - 1) - return BeatCondition( - beat=null_beat, - length=torch.tensor([0] * cond.beat.shape[0], device=cond.beat.device), - bpm=[None] * cond.beat.shape[0], - path=[None] * cond.beat.shape[0], - seek_frame=[None] * cond.beat.shape[0], - ) - - -def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition: - """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0, - and replacing metadata by dummy attributes. - - Args: - cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T]. - """ - null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1) - return JointEmbedCondition( - wav=null_wav, text=[None] * len(embed.text), - length=torch.LongTensor([0]).to(embed.wav.device), - sample_rate=embed.sample_rate, - path=[None] * embed.wav.shape[0], - seek_time=[0] * embed.wav.shape[0], - ) - - -class Tokenizer: - """Base tokenizer implementation - (in case we want to introduce more advances tokenizers in the future). - """ - def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError() - - -class WhiteSpaceTokenizer(Tokenizer): - """This tokenizer should be used for natural language descriptions. - For example: - ["he didn't, know he's going home.", 'shorter sentence'] => - [[78, 62, 31, 4, 78, 25, 19, 34], - [59, 77, 0, 0, 0, 0, 0, 0]] - """ - PUNCTUATION = "?:!.,;" - - def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm", - lemma: bool = True, stopwords: bool = True) -> None: - self.n_bins = n_bins - self.pad_idx = pad_idx - self.lemma = lemma - self.stopwords = stopwords - try: - self.nlp = spacy.load(language) - except IOError: - spacy.cli.download(language) # type: ignore - self.nlp = spacy.load(language) - - @tp.no_type_check - def __call__(self, texts: tp.List[tp.Optional[str]], - return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]: - """Take a list of strings and convert them to a tensor of indices. - - Args: - texts (list[str]): List of strings. - return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False. - Returns: - tuple[torch.Tensor, torch.Tensor]: - - Indices of words in the LUT. 
- - And a mask indicating where the padding tokens are - """ - output, lengths = [], [] - texts = deepcopy(texts) - for i, text in enumerate(texts): - # if current sample doesn't have a certain attribute, replace with pad token - if text is None: - output.append(torch.Tensor([self.pad_idx])) - lengths.append(0) - continue - - # convert numbers to words - text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore - # normalize text - text = self.nlp(text) # type: ignore - # remove stopwords - if self.stopwords: - text = [w for w in text if not w.is_stop] # type: ignore - # remove punctuation - text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore - # lemmatize if needed - text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore - - texts[i] = " ".join(text) - lengths.append(len(text)) - # convert to tensor - tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text]) - output.append(tokens) - - mask = length_to_mask(torch.IntTensor(lengths)).int() - padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t() - if return_text: - return padded_output, mask, texts # type: ignore - return padded_output, mask - - -class NoopTokenizer(Tokenizer): - """This tokenizer should be used for global conditioners such as: artist, genre, key, etc. - The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split - strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will - split it to ["Jeff", "Buckley"] and return an index per word. - - For example: - ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101] - ["Metal", "Rock", "Classical"] => [0, 223, 51] - """ - def __init__(self, n_bins: int, pad_idx: int = 0): - self.n_bins = n_bins - self.pad_idx = pad_idx - - def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - output, lengths = [], [] - for text in texts: - # if current sample doesn't have a certain attribute, replace with pad token - if text is None: - output.append(self.pad_idx) - lengths.append(0) - else: - output.append(hash_trick(text, self.n_bins)) - lengths.append(1) - - tokens = torch.LongTensor(output).unsqueeze(1) - mask = length_to_mask(torch.IntTensor(lengths)).int() - return tokens, mask - - -class BaseConditioner(nn.Module): - """Base model for all conditioner modules. - We allow the output dim to be different than the hidden dim for two reasons: - 1) keep our LUTs small when the vocab is large; - 2) make all condition dims consistent. - - Args: - dim (int): Hidden dim of the model. - output_dim (int): Output dim of the conditioner. - """ - def __init__(self, dim: int, output_dim: int): - super().__init__() - self.dim = dim - self.output_dim = output_dim - self.output_proj = nn.Linear(dim, output_dim) - - def tokenize(self, *args, **kwargs) -> tp.Any: - """Should be any part of the processing that will lead to a synchronization - point, e.g. BPE tokenization with transfer to the GPU. - - The returned value will be saved and return later when calling forward(). - """ - raise NotImplementedError() - - def forward(self, inputs: tp.Any) -> ConditionType: - """Gets input that should be used as conditioning (e.g, genre, description or a waveform). - Outputs a ConditionType, after the input data was embedded as a dense vector. - - Returns: - ConditionType: - - A tensor of size [B, T, D] where B is the batch size, T is the length of the - output embedding and D is the dimension of the embedding. 
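[Editor's note] A minimal sketch (not from the original file; it uses toy sizes and Python's built-in hash in place of `hash_trick`) of how the two tokenizers above differ: the whitespace tokenizer yields one index per word, while the noop tokenizer yields a single index per full string.

import torch
from torch.nn.utils.rnn import pad_sequence

def toy_hash(word: str, n_bins: int) -> int:
    # stand-in for audiocraft's hash_trick: bucket a string into [0, n_bins)
    return hash(word) % n_bins

n_bins = 128
texts = ["Jeff Buckley", "Queen"]

# WhiteSpaceTokenizer-like: one index per word, padded to the longest entry
per_word = [torch.tensor([toy_hash(w, n_bins) for w in t.split()]) for t in texts]
padded = pad_sequence(per_word, batch_first=True, padding_value=0)
print(padded.shape)  # torch.Size([2, 2]) -> "Jeff Buckley" contributes two tokens

# NoopTokenizer-like: one index per full string
whole = torch.tensor([[toy_hash(t, n_bins)] for t in texts])
print(whole.shape)   # torch.Size([2, 1])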
- - And a mask indicating where the padding tokens. - """ - raise NotImplementedError() - - -class TextConditioner(BaseConditioner): - ... - - -class LUTConditioner(TextConditioner): - """Lookup table TextConditioner. - - Args: - n_bins (int): Number of bins. - dim (int): Hidden dim of the model (text-encoder/LUT). - output_dim (int): Output dim of the conditioner. - tokenizer (str): Name of the tokenizer. - pad_idx (int, optional): Index for padding token. Defaults to 0. - """ - def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0): - super().__init__(dim, output_dim) - self.embed = nn.Embedding(n_bins, dim) - self.tokenizer: Tokenizer - if tokenizer == 'whitespace': - self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx) - elif tokenizer == 'noop': - self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx) - else: - raise ValueError(f"unrecognized tokenizer `{tokenizer}`.") - - def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - device = self.embed.weight.device - tokens, mask = self.tokenizer(x) - tokens, mask = tokens.to(device), mask.to(device) - return tokens, mask - - def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType: - tokens, mask = inputs - embeds = self.embed(tokens) - embeds = self.output_proj(embeds) - embeds = (embeds * mask.unsqueeze(-1)) - return embeds, mask - - -class T5Conditioner(TextConditioner): - """T5-based TextConditioner. - - Args: - name (str): Name of the T5 model. - output_dim (int): Output dim of the conditioner. - finetune (bool): Whether to fine-tune T5 at train time. - device (str): Device for T5 Conditioner. - autocast_dtype (tp.Optional[str], optional): Autocast dtype. - word_dropout (float, optional): Word dropout probability. - normalize_text (bool, optional): Whether to apply text normalization. - """ - MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", - "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", - "google/flan-t5-xl", "google/flan-t5-xxl"] - MODELS_DIMS = { - "t5-small": 512, - "t5-base": 768, - "t5-large": 1024, - "t5-3b": 1024, - "t5-11b": 1024, - "google/flan-t5-small": 512, - "google/flan-t5-base": 768, - "google/flan-t5-large": 1024, - "google/flan-t5-3b": 1024, - "google/flan-t5-11b": 1024, - } - - def __init__(self, name: str, output_dim: int, finetune: bool, device: str, - autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0., - normalize_text: bool = False): - assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})" - super().__init__(self.MODELS_DIMS[name], output_dim) - self.device = device - self.name = name - self.finetune = finetune - self.word_dropout = word_dropout - if autocast_dtype is None or self.device == 'cpu': - self.autocast = TorchAutocast(enabled=False) - if self.device != 'cpu': - logger.warning("T5 has no autocast, this might lead to NaN") - else: - dtype = getattr(torch, autocast_dtype) - assert isinstance(dtype, torch.dtype) - logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}") - self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype) - # Let's disable logging temporarily because T5 will vomit some errors otherwise. 
- # thanks https://gist.github.com/simon-weber/7853144 - previous_level = logging.root.manager.disable - logging.disable(logging.ERROR) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - try: - self.t5_tokenizer = T5Tokenizer.from_pretrained(name) - t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune) - finally: - logging.disable(previous_level) - if finetune: - self.t5 = t5 - else: - # this makes sure that the t5 models is not part - # of the saved checkpoint - self.__dict__['t5'] = t5.to(device) - - self.normalize_text = normalize_text - if normalize_text: - self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True) - - def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]: - # if current sample doesn't have a certain attribute, replace with empty string - entries: tp.List[str] = [xi if xi is not None else "" for xi in x] - if self.normalize_text: - _, _, entries = self.text_normalizer(entries, return_text=True) - if self.word_dropout > 0. and self.training: - new_entries = [] - for entry in entries: - words = [word for word in entry.split(" ") if random.random() >= self.word_dropout] - new_entries.append(" ".join(words)) - entries = new_entries - - empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""]) - - inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device) - mask = inputs['attention_mask'] - mask[empty_idx, :] = 0 # zero-out index where the input is non-existant - return inputs - - def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType: - mask = inputs['attention_mask'] - with torch.set_grad_enabled(self.finetune), self.autocast: - embeds = self.t5(**inputs).last_hidden_state - embeds = self.output_proj(embeds.to(self.output_proj.weight)) - embeds = (embeds * mask.unsqueeze(-1)) - return embeds, mask - - -class WaveformConditioner(BaseConditioner): - """Base class for all conditioners that take a waveform as input. - Classes that inherit must implement `_get_wav_embedding` that outputs - a continuous tensor, and `_downsampling_factor` that returns the down-sampling - factor of the embedding model. - - Args: - dim (int): The internal representation dimension. - output_dim (int): Output dimension. - device (tp.Union[torch.device, str]): Device. - """ - def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]): - super().__init__(dim, output_dim) - self.device = device - - def tokenize(self, x: WavCondition) -> WavCondition: - wav, length, sample_rate, path, seek_time = x - assert length is not None - return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time) - - def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: - """Gets as input a WavCondition and returns a dense embedding.""" - raise NotImplementedError() - - def _downsampling_factor(self): - """Returns the downsampling factor of the embedding model.""" - raise NotImplementedError() - - def forward(self, x: WavCondition) -> ConditionType: - """Extract condition embedding and mask from a waveform and its metadata. - Args: - x (WavCondition): Waveform condition containing raw waveform and metadata. 
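[Editor's note] A small, standalone illustration (not part of the original code; the sentence and probability are made up) of the word-dropout idea used in the T5 `tokenize` above: during training, each word is independently kept with probability 1 - word_dropout.

import random

def drop_words(entry: str, word_dropout: float, seed: int = 0) -> str:
    # independently keep each word with probability (1 - word_dropout)
    rng = random.Random(seed)
    kept = [w for w in entry.split(" ") if rng.random() >= word_dropout]
    return " ".join(kept)

print(drop_words("a calm piano melody with soft drums", word_dropout=0.3))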
- Returns: - ConditionType: a dense vector representing the conditioning along with its mask - """ - wav, lengths, *_ = x - with torch.no_grad(): - embeds = self._get_wav_embedding(x) - embeds = embeds.to(self.output_proj.weight) - embeds = self.output_proj(embeds) - - if lengths is not None: - lengths = lengths / self._downsampling_factor() - mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore - else: - mask = torch.ones_like(embeds) - embeds = (embeds * mask.unsqueeze(2).to(self.device)) - - return embeds, mask - - -class ChromaStemConditioner(WaveformConditioner): - """Chroma conditioner based on stems. - The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as - the drums and bass often dominate the chroma leading to the chroma features - not containing information about the melody. - - Args: - output_dim (int): Output dimension for the conditioner. - sample_rate (int): Sample rate for the chroma extractor. - n_chroma (int): Number of chroma bins for the chroma extractor. - radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12). - duration (int): duration used during training. This is later used for correct padding - in case we are using chroma as prefix. - match_len_on_eval (bool, optional): if True then all chromas are padded to the training - duration. Defaults to False. - eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as - conditions during eval (for cases where we don't want to leak test conditions like MusicCaps). - Defaults to None. - n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0. - device (tp.Union[torch.device, str], optional): Device for the conditioner. - **kwargs: Additional parameters for the chroma extractor. - """ - def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int, - duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None, - n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None, - device: tp.Union[torch.device, str] = 'cpu', **kwargs): - from demucs import pretrained - super().__init__(dim=n_chroma, output_dim=output_dim, device=device) - self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32) - self.sample_rate = sample_rate - self.match_len_on_eval = match_len_on_eval - self.duration = duration - self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device) - stem_sources: list = self.demucs.sources # type: ignore - self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device) - self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, - radix2_exp=radix2_exp, **kwargs).to(device) - self.chroma_len = self._get_chroma_len() - self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs) - self.cache = None - if cache_path is not None: - self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, - compute_embed_fn=self._get_full_chroma_for_cache, - extract_embed_fn=self._extract_chroma_chunk) - - def _downsampling_factor(self) -> int: - return self.chroma.winhop - - def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]: - """Load pre-defined waveforms from a json. - These waveforms will be used for chroma extraction during evaluation. 
- This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps). - """ - if path is None: - return None - - logger.info(f"Loading evaluation wavs from {path}") - from audiocraft.data.audio_dataset import AudioDataset - dataset: AudioDataset = AudioDataset.from_meta( - path, segment_duration=self.duration, min_audio_duration=self.duration, - sample_rate=self.sample_rate, channels=1) - - if len(dataset) > 0: - eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device) - logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner") - return eval_wavs - else: - raise ValueError("Could not find evaluation wavs, check lengths of wavs") - - def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None: - self.eval_wavs = eval_wavs - - def has_eval_wavs(self) -> bool: - return self.eval_wavs is not None - - def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor: - """Sample wavs from a predefined list.""" - assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided." - total_eval_wavs = len(self.eval_wavs) - out = self.eval_wavs - if num_samples > total_eval_wavs: - out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1) - return out[torch.randperm(len(out))][:num_samples] - - def _get_chroma_len(self) -> int: - """Get length of chroma during training.""" - dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device) - dummy_chr = self.chroma(dummy_wav) - return dummy_chr.shape[1] - - @torch.no_grad() - def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: - """Get parts of the wav that holds the melody, extracting the main stems from the wav.""" - from demucs.apply import apply_model - from demucs.audio import convert_audio - with self.autocast: - wav = convert_audio( - wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore - stems = apply_model(self.demucs, wav, device=self.device) - stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning - mix_wav = stems.sum(1) # merge extracted stems to single waveform - mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore - return mix_wav - - @torch.no_grad() - def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor: - """Extract chroma features from the waveform.""" - with self.autocast: - return self.chroma(wav) - - @torch.no_grad() - def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: - """Compute wav embedding, applying stem and chroma extraction.""" - # avoid 0-size tensors when we are working with null conds - if wav.shape[-1] == 1: - return self._extract_chroma(wav) - stems = self._get_stemmed_wav(wav, sample_rate) - chroma = self._extract_chroma(stems) - return chroma - - @torch.no_grad() - def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor: - """Extract chroma from the whole audio waveform at the given path.""" - wav, sr = audio_read(path) - wav = wav[None].to(self.device) - wav = convert_audio(wav, sr, self.sample_rate, to_channels=1) - chroma = self._compute_wav_embedding(wav, self.sample_rate)[0] - return chroma - - def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor: - """Extract a chunk of chroma from the full chroma derived from the full waveform.""" - wav_length = x.wav.shape[-1] - seek_time = x.seek_time[idx] 
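[Editor's note] To make the chunk indexing below concrete, a toy computation with hypothetical numbers (the 32 kHz sample rate and 640-sample chroma hop are assumptions, not values taken from this file): the chroma frame rate is sample_rate / hop, a 10 s chunk covers 500 chroma frames, and seek_time = 5 s maps to a start index of frame 250.

sample_rate = 32_000          # assumed audio sample rate
winhop = 640                  # assumed chroma hop size (the downsampling factor)
frame_rate = sample_rate / winhop                            # 50 chroma frames per second

wav_length = 10 * sample_rate                                # a 10 second chunk of audio
seek_time = 5.0                                              # chunk starts 5 s into the track

target_length = int(frame_rate * wav_length / sample_rate)   # 500 frames to keep
index = int(frame_rate * seek_time)                          # start reading at frame 250
print(index, target_length)                                  # 250 500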
- assert seek_time is not None, ( - "WavCondition seek_time is required " - "when extracting chroma chunks from pre-computed chroma.") - full_chroma = full_chroma.float() - frame_rate = self.sample_rate / self._downsampling_factor() - target_length = int(frame_rate * wav_length / self.sample_rate) - index = int(frame_rate * seek_time) - out = full_chroma[index: index + target_length] - out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0] - return out.to(self.device) - - @torch.no_grad() - def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: - """Get the wav embedding from the WavCondition. - The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly - or will rely on the embedding cache to load the pre-computed embedding if relevant. - """ - sampled_wav: tp.Optional[torch.Tensor] = None - if not self.training and self.eval_wavs is not None: - warn_once(logger, "Using precomputed evaluation wavs!") - sampled_wav = self._sample_eval_wavs(len(x.wav)) - - no_undefined_paths = all(p is not None for p in x.path) - no_nullified_cond = x.wav.shape[-1] > 1 - if sampled_wav is not None: - chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate) - elif self.cache is not None and no_undefined_paths and no_nullified_cond: - paths = [Path(p) for p in x.path if p is not None] - chroma = self.cache.get_embed_from_cache(paths, x) - else: - assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal." - chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0]) - - if self.match_len_on_eval: - B, T, C = chroma.shape - if T > self.chroma_len: - chroma = chroma[:, :self.chroma_len] - logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})") - elif T < self.chroma_len: - n_repeat = int(math.ceil(self.chroma_len / T)) - chroma = chroma.repeat(1, n_repeat, 1) - chroma = chroma[:, :self.chroma_len] - logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})") - - return chroma - - def tokenize(self, x: WavCondition) -> WavCondition: - """Apply WavConditioner tokenization and populate cache if needed.""" - x = super().tokenize(x) - no_undefined_paths = all(p is not None for p in x.path) - if self.cache is not None and no_undefined_paths: - paths = [Path(p) for p in x.path if p is not None] - self.cache.populate_embed_cache(paths, x) - return x - -class ChordProgressionConditioner(BaseConditioner): - """Chord progression conditioning supporting chord progression conditioning. - - Args: - dim (int): Dimension. - output_dim (int): Output dimension. - device (str): Device. - attribute (str): Attribute used by the conditioner. - autocast_dtype (str): Autocast for the conditioner. 
- """ - - def __init__(self, output_dim: int, device: str, name: str): - n_chroma = 12 - # n_chroma = 24 - super().__init__(dim=n_chroma, output_dim=output_dim) - self.device = device - - def forward(self, x: ChordCondition) -> ConditionType: - chord, lengths, *_ = x - embeds = chord.to(self.output_proj.weight) # chrod is already a tensor, [N, C] - embeds = self.output_proj(embeds) - - if lengths is not None: - mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore - else: - mask = torch.ones_like(embeds) - embeds = (embeds * mask.unsqueeze(2).to(self.device)) - - return embeds, mask - - def tokenize(self, x: ChordCondition) -> ChordCondition: - """Apply ChordConditioner tokenization and populate cache if needed.""" - chord, length, bpm, path, seek_frame = x - chord = F.pad(chord, (0, length[0] - chord.shape[-1])) # [B, C, t] -> [B, C, T] - chord = chord.permute(0, 2, 1) # [B, T, C] - x = ChordCondition(chord.to(self.device), length.to(self.device), bpm, path, seek_frame) - return x - -class BeatConditioner(BaseConditioner): - """Beat conditioning supporting beat conditioning. - - Args: - dim (int): Dimension. - output_dim (int): Output dimension. - device (str): Device. - attribute (str): Attribute used by the conditioner. - autocast_dtype (str): Autocast for the conditioner. - """ - - def __init__(self, output_dim: int, device: str, name: str): - beat_channel = 1 - super().__init__(dim=beat_channel, output_dim=output_dim) - self.device = device - - def forward(self, x: BeatCondition) -> ConditionType: - beat, lengths, *_ = x - embeds = beat.to(self.output_proj.weight) # chrod is already a tensor, [N, C] - embeds = self.output_proj(embeds) - - if lengths is not None: - mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore - else: - mask = torch.ones_like(embeds) - embeds = (embeds * mask.unsqueeze(2).to(self.device)) - - return embeds, mask - - def tokenize(self, x: BeatCondition) -> BeatCondition: - """Apply ChordConditioner tokenization and populate cache if needed.""" - beat, length, bpm, path, seek_frame = x - beat = F.pad(beat, (0, length[0] - beat.shape[-1])) # [B, C, t] -> [B, C, T] - beat = beat.permute(0, 2, 1) # [B, T, C] - x = BeatCondition(beat.to(self.device), length.to(self.device), bpm, path, seek_frame) - return x - - -class JointEmbeddingConditioner(BaseConditioner): - """Joint embedding conditioning supporting both audio or text conditioning. - - Args: - dim (int): Dimension. - output_dim (int): Output dimension. - device (str): Device. - attribute (str): Attribute used by the conditioner. - autocast_dtype (str): Autocast for the conditioner. - quantize (bool): Whether to quantize the CLAP embedding. - n_q (int): Number of residual quantizers (used if quantize is true). - bins (int): Quantizers' codebooks size (used if quantize is true). - kwargs: Additional parameters for residual vector quantizer. 
- """ - def __init__(self, dim: int, output_dim: int, device: str, attribute: str, - autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True, - n_q: int = 12, bins: int = 1024, **kwargs): - super().__init__(dim=dim, output_dim=output_dim) - self.device = device - self.attribute = attribute - if autocast_dtype is None or device == 'cpu': - self.autocast = TorchAutocast(enabled=False) - logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.") - else: - dtype = getattr(torch, autocast_dtype) - assert isinstance(dtype, torch.dtype) - logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.") - self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype) - # residual vector quantizer to discretize the conditioned embedding - self.quantizer: tp.Optional[ResidualVectorQuantizer] = None - if quantize: - self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs) - - def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: - """Get joint embedding in latent space from the inputs. - - Returns: - tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding - and corresponding empty indexes. - """ - raise NotImplementedError() - - def forward(self, x: JointEmbedCondition) -> ConditionType: - with self.autocast: - embed, empty_idx = self._get_embed(x) - if self.quantizer is not None: - embed = embed.view(-1, self.dim, 1) - q_res = self.quantizer(embed, frame_rate=1) - out_embed = q_res.x.view(-1, self.dim) - else: - out_embed = embed - out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim) - mask = torch.ones(*out_embed.shape[:2], device=out_embed.device) - mask[empty_idx, :] = 0 # zero-out index where the input is non-existant - out_embed = (out_embed * mask.unsqueeze(-1)) - return out_embed, mask - - def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: - return x - - -class CLAPEmbeddingConditioner(JointEmbeddingConditioner): - """Joint Embedding conditioner based on pre-trained CLAP model. - - This CLAP-based conditioner supports a caching mechanism - over the computed embeddings for faster training. - - Args: - dim (int): Dimension. - output_dim (int): Output dimension. - device (str): Device. - attribute (str): Attribute used by the conditioner. - quantize (bool): Whether to quantize the CLAP embedding. - n_q (int): Number of residual quantizers (used if quantize is true). - bins (int): Quantizers' codebooks size (used if quantize is true). - checkpoint (str): Path to CLAP checkpoint. - model_arch (str): CLAP model architecture. - enable_fusion (bool): Enable fusion for CLAP model. - sample_rate (int): Sample rate used by CLAP model. - max_audio_length (float): Maximum audio length for CLAP model. - audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence. - normalize (bool): Whether to normalize the CLAP embedding. - text_p (float): Probability of using text representation instead of audio at train time. - batch_size (Optional[int]): Batch size for CLAP embedding computation. - autocast_dtype (str): Autocast for the conditioner. - cache_path (Optional[str]): Path for pre-computed embeddings caching. - kwargs: Additional parameters for residual vector quantizer. 
- """ - def __init__(self, dim: int, output_dim: int, device: str, attribute: str, - quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str, - enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int, - normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None, - autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs): - try: - import laion_clap # type: ignore - except ImportError: - raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'") - checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint) - clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base') - clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch) - load_clap_state_dict(clap_model, checkpoint) - clap_model.eval() - clap_model.to(device) - super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute, - autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins, - **kwargs) - self.checkpoint = checkpoint - self.enable_fusion = enable_fusion - self.model_arch = model_arch - self.clap: laion_clap.CLAP_Module - self.clap_tokenize: RobertaTokenizer - self.clap_sample_rate = sample_rate - self.clap_max_frames = int(self.clap_sample_rate * max_audio_length) - self.clap_stride = int(self.clap_sample_rate * audio_stride) - self.batch_size = batch_size or 1 - self.normalize = normalize - self.text_p = text_p - self.__dict__['clap_tokenize'] = clap_tokenize - self.__dict__['clap'] = clap_model - self.wav_cache, self.text_cache = None, None - if cache_path is not None: - self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, - compute_embed_fn=self._get_wav_embedding_for_cache, - extract_embed_fn=self._extract_wav_embedding_chunk) - self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device, - compute_embed_fn=self._get_text_embedding_for_cache) - - def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: - # we use the default params from CLAP module here as well - return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt") - - def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor: - """Compute text embedding from CLAP model on a given a batch of text. - - Args: - text (list[str]): List of text for the batch, with B items. - Returns: - torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension. - """ - with torch.no_grad(): - embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True) - return embed.view(embed.size(0), 1, embed.size(-1)) - - def _get_text_embedding_for_cache(self, path: tp.Union[Path, str], - x: JointEmbedCondition, idx: int) -> torch.Tensor: - """Get text embedding function for the cache.""" - text = x.text[idx] - text = text if text is not None else "" - return self._compute_text_embedding([text])[0] - - def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor: - """Preprocess wav to expected format by CLAP model. - - Args: - wav (torch.Tensor): Audio wav, of shape [B, C, T]. - length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. - sample_rates (list[int]): Sample rates for each sample in the batch - Returns: - torch.Tensor: Audio wav of shape [B, T]. 
- """ - assert wav.dim() == 3, "Expecting wav to be [B, C, T]" - if sample_rates is not None: - _wav = [] - for i, audio in enumerate(wav): - sr = sample_rates[i] - audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1) - _wav.append(audio) - wav = torch.stack(_wav, dim=0) - wav = wav.mean(dim=1) - return wav - - def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor, - sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor: - """Compute audio wave embedding from CLAP model. - - Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences, - we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and - average the resulting embeddings. - - Args: - wav (torch.Tensor): Audio wav, of shape [B, C, T]. - length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. - sample_rates (list[int]): Sample rates for each sample in the batch. - reduce_mean (bool): Whether to get the average tensor. - Returns: - torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension. - """ - with torch.no_grad(): - wav = self._preprocess_wav(wav, length, sample_rates) - B, T = wav.shape - if T >= self.clap_max_frames: - wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T] - else: - wav = wav.view(-1, 1, T) # [B, F, T] with F=1 - wav = einops.rearrange(wav, 'b f t -> (b f) t') - embed_list = [] - for i in range(0, wav.size(0), self.batch_size): - _wav = wav[i:i+self.batch_size, ...] - _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True) - embed_list.append(_embed) - embed = torch.cat(embed_list, dim=0) - embed = einops.rearrange(embed, '(b f) d -> b f d', b=B) - if reduce_mean: - embed = embed.mean(dim=1, keepdim=True) - return embed # [B, F, D] with F=1 if reduce_mean is True - - def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path], - x: JointEmbedCondition, idx: int) -> torch.Tensor: - """Compute audio wave embedding for the cache. - The embedding is computed on a given audio read from file. - - Args: - path (str or Path): Path to the full audio file. - Returns: - torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension. - """ - wav, sr = audio_read(path) # [C, T] - wav = wav.unsqueeze(0).to(self.device) # [1, C, T] - wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device) - embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D] - return embed.squeeze(0) # [F, D] - - def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor: - """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding. - - Args: - full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D]. - x (JointEmbedCondition): Joint embedding condition for the full batch. - idx (int): Index considered for the given embedding to extract. - Returns: - torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D]. - """ - sample_rate = x.sample_rate[idx] - seek_time = x.seek_time[idx] - seek_time = 0. 
if seek_time is None else seek_time - clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate - end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate - start_offset = int(seek_time * sample_rate // clap_stride) - end_offset = int(end_seek_time * sample_rate // clap_stride) - wav_embed = full_embed[start_offset:end_offset, ...] - wav_embed = wav_embed.mean(dim=0, keepdim=True) - return wav_embed.to(self.device) # [F, D] - - def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor: - """Get CLAP embedding from a batch of text descriptions.""" - no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout - if self.text_cache is not None and no_nullified_cond: - assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided" - paths = [Path(p) for p in x.path if p is not None] - embed = self.text_cache.get_embed_from_cache(paths, x) - else: - text = [xi if xi is not None else "" for xi in x.text] - embed = self._compute_text_embedding(text) - if self.normalize: - embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) - return embed - - def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor: - """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates).""" - no_undefined_paths = all(p is not None for p in x.path) - no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout - if self.wav_cache is not None and no_undefined_paths and no_nullified_cond: - paths = [Path(p) for p in x.path if p is not None] - embed = self.wav_cache.get_embed_from_cache(paths, x) - else: - embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True) - if self.normalize: - embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) - return embed - - def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: - # Trying to limit as much as possible sync points when the cache is warm. - no_undefined_paths = all(p is not None for p in x.path) - if self.wav_cache is not None and no_undefined_paths: - assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" - paths = [Path(p) for p in x.path if p is not None] - self.wav_cache.populate_embed_cache(paths, x) - if self.text_cache is not None and no_undefined_paths: - assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" - paths = [Path(p) for p in x.path if p is not None] - self.text_cache.populate_embed_cache(paths, x) - return x - - def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: - """Extract shared latent representation from either the wav or the text using CLAP.""" - # decide whether to use text embedding at train time or not - use_text_embed = random.random() < self.text_p - if self.training and not use_text_embed: - embed = self._get_wav_embedding(x) - empty_idx = torch.LongTensor([]) # we assume we always have the audio wav - else: - embed = self._get_text_embedding(x) - empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""]) - return embed, empty_idx - - -def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes: - """Utility function for nullifying an attribute inside an ConditioningAttributes object. - If the condition is of type "wav", then nullify it using `nullify_condition` function. 
- If the condition is of any other type, set its value to None. - Works in-place. - """ - if condition_type not in ['text', 'wav', 'beat', 'chord', 'joint_embed']: - raise ValueError( - "dropout_condition got an unexpected condition type!" - f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'" - ) - - if condition not in getattr(sample, condition_type): - raise ValueError( - "dropout_condition received an unexpected condition!" - f" expected wav={sample.wav.keys()} and text={sample.text.keys()}" - f" but got '{condition}' of type '{condition_type}'!" - ) - - if condition_type == 'wav': - wav_cond = sample.wav[condition] - sample.wav[condition] = nullify_wav(wav_cond) - elif condition_type == 'beat': - beat_cond = sample.beat[condition] - sample.beat[condition] = nullify_beat(beat_cond) - elif condition_type == 'chord': - chord_cond = sample.chord[condition] - sample.chord[condition] = nullify_chord(chord_cond) - elif condition_type == 'joint_embed': - embed = sample.joint_embed[condition] - sample.joint_embed[condition] = nullify_joint_embed(embed) - else: - sample.text[condition] = None - - return sample - - -class DropoutModule(nn.Module): - """Base module for all dropout modules.""" - def __init__(self, seed: int = 1234): - super().__init__() - self.rng = torch.Generator() - self.rng.manual_seed(seed) - - -class AttributeDropout(DropoutModule): - """Dropout with a given probability per attribute. - This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes - to be dropped out separately. For example, "artist" can be dropped while "genre" remains. - This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre" - must also be dropped. - - Args: - p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example: - ... - "genre": 0.1, - "artist": 0.5, - "wav": 0.25, - ... - active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False. - seed (int, optional): Random seed. - """ - def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234): - super().__init__(seed=seed) - self.active_on_eval = active_on_eval - # construct dict that return the values from p otherwise 0 - self.p = {} - for condition_type, probs in p.items(): - self.p[condition_type] = defaultdict(lambda: 0, probs) - - def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: - """ - Args: - samples (list[ConditioningAttributes]): List of conditions. - Returns: - list[ConditioningAttributes]: List of conditions after certain attributes were set to None. - """ - if not self.training and not self.active_on_eval: - return samples - - samples = deepcopy(samples) - for condition_type, ps in self.p.items(): # for condition types [text, wav] - for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre]) - if torch.rand(1, generator=self.rng).item() < p: - for sample in samples: - dropout_condition(sample, condition_type, condition) - return samples - - def __repr__(self): - return f"AttributeDropout({dict(self.p)})" - - -class ClassifierFreeGuidanceDropout(DropoutModule): - """Classifier Free Guidance dropout. - All attributes are dropped with the same probability. - - Args: - p (float): Probability to apply condition dropout during training. - seed (int): Random seed. 
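[Editor's note] A standalone toy (plain dicts instead of ConditioningAttributes; probabilities are made up) contrasting the two dropout behaviours: AttributeDropout draws one coin per attribute, while the classifier-free-guidance dropout described here draws a single coin that nullifies every attribute or none.

import torch

rng = torch.Generator().manual_seed(1234)
sample = {"genre": "Rock", "artist": "Queen"}
p_attr = {"genre": 0.5, "artist": 0.25}

# AttributeDropout-style: one independent draw per attribute
per_attr = {k: None if torch.rand(1, generator=rng).item() < p_attr[k] else v
            for k, v in sample.items()}

# ClassifierFreeGuidanceDropout-style: a single draw drops all attributes or none
drop_all = torch.rand(1, generator=rng).item() < 0.3
cfg = {k: None for k in sample} if drop_all else dict(sample)

print(per_attr, cfg)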
- """ - def __init__(self, p: float, seed: int = 1234): - super().__init__(seed=seed) - self.p = p - - def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: - """ - Args: - samples (list[ConditioningAttributes]): List of conditions. - Returns: - list[ConditioningAttributes]: List of conditions after all attributes were set to None. - """ - if not self.training: - return samples - - # decide on which attributes to drop in a batched fashion - drop = torch.rand(1, generator=self.rng).item() < self.p - if not drop: - return samples - - # nullify conditions of all attributes - samples = deepcopy(samples) - for condition_type in ["wav", "text", "beat", "chord"]: - for sample in samples: - for condition in sample.attributes[condition_type]: - dropout_condition(sample, condition_type, condition) - return samples - - def __repr__(self): - return f"ClassifierFreeGuidanceDropout(p={self.p})" - - -class ConditioningProvider(nn.Module): - """Prepare and provide conditions given all the supported conditioners. - - Args: - conditioners (dict): Dictionary of conditioners. - device (torch.device or str, optional): Device for conditioners and output condition types. - """ - def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"): - super().__init__() - self.device = device - self.conditioners = nn.ModuleDict(conditioners) - - @property - def joint_embed_conditions(self): - return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)] - - @property - def has_joint_embed_conditions(self): - return len(self.joint_embed_conditions) > 0 - - @property - def text_conditions(self): - return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)] - - @property - def wav_conditions(self): - return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)] - - @property - def beat_conditions(self): - return [k for k, v in self.conditioners.items() if isinstance(v, BeatConditioner)] - - @property - def chord_conditions(self): - return [k for k, v in self.conditioners.items() if isinstance(v, ChordProgressionConditioner)] - - @property - def has_wav_condition(self): - return len(self.wav_conditions) > 0 - - def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]: - """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly. - This should be called before starting any real GPU work to avoid synchronization points. - This will return a dict matching conditioner names to their arbitrary tokenized representations. - - Args: - inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing - text and wav conditions. - """ - assert all([isinstance(x, ConditioningAttributes) for x in inputs]), ( - "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]", - f" but types were {set([type(x) for x in inputs])}" - ) - - output = {} - text = self._collate_text(inputs) - beats = self._collate_beats(inputs) - chords = self._collate_chords(inputs) - wavs = self._collate_wavs(inputs) - joint_embeds = self._collate_joint_embeds(inputs) - - assert set(text.keys() | wavs.keys() | chords.keys() | beats.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), ( - f"Got an unexpected attribute! 
Expected {self.conditioners.keys()}, ", - f"got {text.keys(), wavs.keys(), chords.keys(), beats.keys(), joint_embeds.keys()}" - ) - - for attribute, batch in chain(text.items(), wavs.items(), chords.items(), beats.items(), joint_embeds.items()): - output[attribute] = self.conditioners[attribute].tokenize(batch) - return output - - def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]: - """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations. - The output is for example: - { - "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), - "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), - ... - } - - Args: - tokenized (dict): Dict of tokenized representations as returned by `tokenize()`. - """ - output = {} - for attribute, inputs in tokenized.items(): - condition, mask = self.conditioners[attribute](inputs) - output[attribute] = (condition, mask) - return output - - def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]: - """Given a list of ConditioningAttributes objects, compile a dictionary where the keys - are the attributes and the values are the aggregated input per attribute. - For example: - Input: - [ - ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...), - ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...), - ] - Output: - { - "genre": ["Rock", "Hip-hop"], - "description": ["A rock song with a guitar solo", "A hip-hop verse"] - } - - Args: - samples (list of ConditioningAttributes): List of ConditioningAttributes samples. - Returns: - dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch. - """ - out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list) - texts = [x.text for x in samples] - for text in texts: - for condition in self.text_conditions: - out[condition].append(text[condition]) - return out - - def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]: - """Generate a dict where the keys are attributes by which we fetch similar wavs, - and the values are Tensors of wavs according to said attributes. - - *Note*: by the time the samples reach this function, each sample should have some waveform - inside the "wav" attribute. It should be either: - 1. A real waveform - 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset) - 3. A null waveform due to it being dropped in a dropout module (nullified by dropout) - - Args: - samples (list of ConditioningAttributes): List of ConditioningAttributes samples. - Returns: - dict[str, WavCondition]: A dictionary mapping an attribute name to wavs. 
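[Editor's note] A minimal, self-contained version of what `_collate_text` does (toy data, plain dicts instead of ConditioningAttributes): per-sample attribute dicts are regrouped into one list per attribute, mirroring the Input/Output example in the docstring above.

from collections import defaultdict

samples = [
    {"genre": "Rock", "description": "A rock song with a guitar solo"},
    {"genre": "Hip-hop", "description": "A hip-hop verse"},
]
text_conditions = ["genre", "description"]

out = defaultdict(list)
for text in samples:
    for condition in text_conditions:
        out[condition].append(text.get(condition))
print(dict(out))
# {'genre': ['Rock', 'Hip-hop'], 'description': ['A rock song with a guitar solo', 'A hip-hop verse']}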
- """ - wavs = defaultdict(list) - lengths = defaultdict(list) - sample_rates = defaultdict(list) - paths = defaultdict(list) - seek_times = defaultdict(list) - out: tp.Dict[str, WavCondition] = {} - - for sample in samples: - for attribute in self.wav_conditions: - wav, length, sample_rate, path, seek_time = sample.wav[attribute] - assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]" - assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1" - # mono-channel conditioning - wav = wav.mean(1, keepdim=True) # [1, 1, T] - wavs[attribute].append(wav.flatten()) # [T] - lengths[attribute].append(length) - sample_rates[attribute].extend(sample_rate) - paths[attribute].extend(path) - seek_times[attribute].extend(seek_time) - - # stack all wavs to a single tensor - for attribute in self.wav_conditions: - stacked_wav, _ = collate(wavs[attribute], dim=0) - out[attribute] = WavCondition( - stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute], - paths[attribute], seek_times[attribute]) - - return out - - def _collate_chords(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, ChordCondition]: - """Generate a dict where the keys are attributes by which we fetch similar wavs, - and the values are Tensors of wavs according to said attributes. - - *Note*: by the time the samples reach this function, each sample should have some waveform - inside the "wav" attribute. It should be either: - 1. A real waveform - 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset) - 3. A null waveform due to it being dropped in a dropout module (nullified by dropout) - - Args: - samples (list of ConditioningAttributes): List of ConditioningAttributes samples. - Returns: - dict[str, WavCondition]: A dictionary mapping an attribute name to wavs. - """ - chords = defaultdict(list) - lengths = defaultdict(list) - bpms = defaultdict(list) - paths = defaultdict(list) - seek_frames = defaultdict(list) - out: tp.Dict[str, ChordCondition] = {} - - for sample in samples: # sample = ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...) - for attribute in self.chord_conditions: # self.chord_conditions = ['chord'] - chord, length, bpm, path, seek_frame = sample.chord[attribute] - assert chord.dim() == 3, f"Got chord with dim={chord.dim()}, but expected 3 [1, C, T]" - assert chord.size(0) == 1, f"Got chord [B, C, T] with shape={chord.shape}, but expected B == 1" - chords[attribute].append(chord.squeeze(0)) # [1, C, T] -> [N * [C, T]] - lengths[attribute].append(length) # [N, 1] - bpms[attribute].extend(bpm) # [N] - paths[attribute].extend(path) # [N] - seek_frames[attribute].extend(seek_frame) # [N] - - # stack all chords to a single tensor - for attribute in self.chord_conditions: - stacked_chord, _ = collate(chords[attribute], dim=1) # tensor padded here - out[attribute] = ChordCondition( - stacked_chord, torch.cat(lengths[attribute]), bpms[attribute], - paths[attribute], seek_frames[attribute]) - # print(f"chords shape: {chords[attribute][0].shape}") - # print(f"stack chords shape: {stacked_chord.shape}") - return out - - def _collate_beats(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, ChordCondition]: - """Generate a dict where the keys are attributes by which we fetch similar wavs, - and the values are Tensors of wavs according to said attributes. 
-
-        Args:
-            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
-        Returns:
-            dict[str, BeatCondition]: A dictionary mapping an attribute name to beat conditions.
-        """
-        beats = defaultdict(list)
-        lengths = defaultdict(list)
-        bpms = defaultdict(list)
-        paths = defaultdict(list)
-        seek_frames = defaultdict(list)
-        out: tp.Dict[str, BeatCondition] = {}
-
-        for sample in samples:
-            for attribute in self.beat_conditions:
-                beat, length, bpm, path, seek_frame = sample.beat[attribute]
-                assert beat.dim() == 3, f"Got beat with dim={beat.dim()}, but expected 3 [1, C, T]"
-                assert beat.size(0) == 1, f"Got beat [B, C, T] with shape={beat.shape}, but expected B == 1"
-                beats[attribute].append(beat.squeeze(0))  # [1, C, T] -> [N * [C, T]]
-                lengths[attribute].append(length)  # [N, 1]
-                bpms[attribute].extend(bpm)  # [N]
-                paths[attribute].extend(path)  # [N]
-                seek_frames[attribute].extend(seek_frame)  # [N]
-
-        # stack all beats to a single padded tensor
-        for attribute in self.beat_conditions:
-            stacked_beat, _ = collate(beats[attribute], dim=1)  # tensor padded here
-            out[attribute] = BeatCondition(
-                stacked_beat, torch.cat(lengths[attribute]), bpms[attribute],
-                paths[attribute], seek_frames[attribute])
-        return out
-
-    def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
-        """Generate a dict where the keys are attributes by which we compute joint embeddings,
-        and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
-
-        Args:
-            samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
-        Returns:
-            A dictionary mapping an attribute name to joint embeddings.
- """ - texts = defaultdict(list) - wavs = defaultdict(list) - lengths = defaultdict(list) - sample_rates = defaultdict(list) - paths = defaultdict(list) - seek_times = defaultdict(list) - channels: int = 0 - - out = {} - for sample in samples: - for attribute in self.joint_embed_conditions: - wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute] - assert wav.dim() == 3 - if channels == 0: - channels = wav.size(1) - else: - assert channels == wav.size(1), "not all audio has same number of channels in batch" - assert wav.size(0) == 1, "Expecting single-wav batch in the collate method" - wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T] - wavs[attribute].append(wav) - texts[attribute].extend(text) - lengths[attribute].append(length) - sample_rates[attribute].extend(sample_rate) - paths[attribute].extend(path) - seek_times[attribute].extend(seek_time) - - for attribute in self.joint_embed_conditions: - stacked_texts = texts[attribute] - stacked_paths = paths[attribute] - stacked_seek_times = seek_times[attribute] - stacked_wavs = pad_sequence(wavs[attribute]).to(self.device) - stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels) - stacked_sample_rates = sample_rates[attribute] - stacked_lengths = torch.cat(lengths[attribute]).to(self.device) - assert stacked_lengths.size(0) == stacked_wavs.size(0) - assert len(stacked_sample_rates) == stacked_wavs.size(0) - assert len(stacked_texts) == stacked_wavs.size(0) - out[attribute] = JointEmbedCondition( - text=stacked_texts, wav=stacked_wavs, - length=stacked_lengths, sample_rate=stacked_sample_rates, - path=stacked_paths, seek_time=stacked_seek_times) - - return out - - -class ConditionFuser(StreamingModule): - """Condition fuser handles the logic to combine the different conditions - to the actual model input. - - Args: - fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse - each condition. For example: - { - "prepend": ["description"], - "sum": ["genre", "bpm"], - "cross": ["description"], - } - cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention. - cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used. - """ - FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate", "concat"] - - def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False, - cross_attention_pos_emb_scale: float = 1.0, in_attn: bool = False): - super().__init__() - assert all( - [k in self.FUSING_METHODS for k in fuse2cond.keys()] - ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}" - self.cross_attention_pos_emb = cross_attention_pos_emb - self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale - self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond - self.cond2fuse: tp.Dict[str, str] = {} - self.in_attn = in_attn - - for fuse_method, conditions in fuse2cond.items(): - for condition in conditions: - if not condition in self.cond2fuse.keys(): - self.cond2fuse[condition] = [fuse_method] - else: - self.cond2fuse[condition].append(fuse_method) - - - def forward( - self, - input: torch.Tensor, - conditions: tp.Dict[str, ConditionType] - ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: - """Fuse the conditions to the provided model input. - - Args: - input (torch.Tensor): Transformer input. - conditions (dict[str, ConditionType]): Dict of conditions. 
- Returns: - tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input - after the conditions have been fused. The second output tensor is the tensor - used for cross-attention or None if no cross attention inputs exist. - """ - - B, T, _ = input.shape # [B, T, C] - if self.in_attn: - in_attn_cond = torch.zeros_like(input) - else: - in_attn_cond = None - - if 'offsets' in self._streaming_state: - first_step = False - offsets = self._streaming_state['offsets'] - else: - first_step = True - offsets = torch.zeros(B, dtype=torch.long, device=input.device) - - assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \ - f"given conditions contain unknown attributes for fuser, " \ - f"expected {self.cond2fuse.keys()}, got {conditions.keys()}" - cross_attention_output = None - - for cond_type, (cond, cond_mask) in conditions.items(): - fuse_methods = self.cond2fuse[cond_type] - for op in fuse_methods: - if op == 'sum': - cond_sum = cond[:, offsets[0]:offsets[0]+T] - if cond_sum.shape[1] != 0: - if cond_sum.shape[1] < T: - cond_sum = F.pad(cond_sum, (0, 0, 0, T-cond_sum.shape[1]), "constant", 0) # pad last special token dim - input[:, -cond_sum.shape[1]:, :] = input[:, -cond_sum.shape[1]:, :] + cond_sum - if self.in_attn: - in_attn_cond += cond_sum - - elif op == 'input_interpolate': - cond = einops.rearrange(cond, "b t d -> b d t") - cond = F.interpolate(cond, size=input.shape[1]) - input += einops.rearrange(cond, "b d t -> b t d") - - elif op == 'prepend': - if cond_type == 'chord': - cond_prepend = torch.zeros(cond.shape[0], 235, cond.shape[2], device=cond.device) # original musicgen melody has 235 length chroma - if cond.shape[1] == 1500: # if condition not dropout - for i in range(235): - cond_prepend[:, i, :] = cond[:, round(i * (1500/235)), :] # n_frame of chord = 30*50 into 235 time steps - else: - cond_prepend = cond - - if first_step: - input = torch.cat([cond_prepend, input], dim=1) - - elif op == 'cross': - if cross_attention_output is not None: - cross_attention_output = torch.cat([cross_attention_output, cond], dim=1) - else: - cross_attention_output = cond - else: - raise ValueError(f"unknown op ({op})") - - - if self.cross_attention_pos_emb and cross_attention_output is not None: - positions = torch.arange( - cross_attention_output.shape[1], - device=cross_attention_output.device - ).view(1, -1, 1) - pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1]) - cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb - - if self._is_streaming: - self._streaming_state['offsets'] = offsets + T - - return input, in_attn_cond, cross_attention_output \ No newline at end of file diff --git a/audiocraft/audiocraft/modules/conv.py b/audiocraft/audiocraft/modules/conv.py deleted file mode 100644 index d115cbf8729b642ed78608bd00a4d0fd5afae6fd..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/conv.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
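The ConditionFuser above routes each named condition to one or more fusing ops ('sum', 'prepend', 'cross', ...). A minimal standalone sketch of that routing on dummy tensors; the condition names, shapes, and `fuse2cond` mapping below are illustrative and not taken from a real configuration:

```python
import torch

# Illustrative routing table; real configs come from the model's config files.
fuse2cond = {"prepend": ["chord"], "cross": ["description"], "sum": ["beat"]}

# Invert the mapping the same way ConditionFuser.__init__ does: a condition
# may be routed to several fusing ops.
cond2fuse: dict = {}
for fuse_method, conditions in fuse2cond.items():
    for condition in conditions:
        cond2fuse.setdefault(condition, []).append(fuse_method)

B, T, C = 2, 8, 16
x = torch.zeros(B, T, C)                   # transformer input [B, T, C]
conds = {                                  # applied in this (dict) order
    "description": torch.randn(B, 4, C),   # -> cross-attention input
    "beat": torch.randn(B, T, C),          # -> summed onto the input
    "chord": torch.randn(B, 6, C),         # -> prepended along time
}

cross = None
for name, cond in conds.items():
    for op in cond2fuse[name]:
        if op == "sum":
            x = x + cond
        elif op == "prepend":
            x = torch.cat([cond, x], dim=1)
        elif op == "cross":
            cross = cond if cross is None else torch.cat([cross, cond], dim=1)

print(x.shape, cross.shape)   # torch.Size([2, 14, 16]) torch.Size([2, 4, 16])
```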
- -import math -import typing as tp -import warnings - -import torch -from torch import nn -from torch.nn import functional as F -from torch.nn.utils import spectral_norm, weight_norm - - -CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', - 'time_group_norm']) - - -def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): - assert norm in CONV_NORMALIZATIONS - if norm == 'weight_norm': - return weight_norm(module) - elif norm == 'spectral_norm': - return spectral_norm(module) - else: - # We already check was in CONV_NORMALIZATION, so any other choice - # doesn't need reparametrization. - return module - - -def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs): - """Return the proper normalization module. If causal is True, this will ensure the returned - module is causal, or return an error if the normalization doesn't support causal evaluation. - """ - assert norm in CONV_NORMALIZATIONS - if norm == 'time_group_norm': - if causal: - raise ValueError("GroupNorm doesn't support causal evaluation.") - assert isinstance(module, nn.modules.conv._ConvNd) - return nn.GroupNorm(1, module.out_channels, **norm_kwargs) - else: - return nn.Identity() - - -def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, - padding_total: int = 0) -> int: - """See `pad_for_conv1d`.""" - length = x.shape[-1] - n_frames = (length - kernel_size + padding_total) / stride + 1 - ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) - return ideal_length - length - - -def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): - """Pad for a convolution to make sure that the last window is full. - Extra padding is added at the end. This is required to ensure that we can rebuild - an output of the same length, as otherwise, even with padding, some time steps - might get removed. - For instance, with total padding = 4, kernel size = 4, stride = 2: - 0 0 1 2 3 4 5 0 0 # (0s are padding) - 1 2 3 # (output frames of a convolution, last 0 is never used) - 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) - 1 2 3 4 # once you removed padding, we are missing one time step ! - """ - extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) - return F.pad(x, (0, extra_padding)) - - -def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): - """Tiny wrapper around F.pad, just to allow for reflect padding on small input. - If this is the case, we insert extra 0 padding to the right before the reflection happen. - """ - length = x.shape[-1] - padding_left, padding_right = paddings - assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - if mode == 'reflect': - max_pad = max(padding_left, padding_right) - extra_pad = 0 - if length <= max_pad: - extra_pad = max_pad - length + 1 - x = F.pad(x, (0, extra_pad)) - padded = F.pad(x, paddings, mode, value) - end = padded.shape[-1] - extra_pad - return padded[..., :end] - else: - return F.pad(x, paddings, mode, value) - - -def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): - """Remove padding from x, handling properly zero padding. 
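A small numeric check of the padding arithmetic above (pure PyTorch; kernel size, stride, and length are arbitrary): with kernel_size=4 and stride=3, an 11-sample signal needs 2 extra samples on the right so the last window is full.

```python
import math
import torch
import torch.nn.functional as F

def extra_padding(length: int, kernel_size: int, stride: int, padding_total: int = 0) -> int:
    # Same formula as get_extra_padding_for_conv1d above.
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length

x = torch.randn(1, 1, 11)
extra = extra_padding(x.shape[-1], kernel_size=4, stride=3)   # -> 2
x = F.pad(x, (0, extra))                                      # [1, 1, 13]
frames = (x.shape[-1] - 4) // 3 + 1                           # 4 full windows, last sample covered
print(extra, x.shape, frames)
```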
Only for 1d!""" - padding_left, padding_right = paddings - assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - assert (padding_left + padding_right) <= x.shape[-1] - end = x.shape[-1] - padding_right - return x[..., padding_left: end] - - -class NormConv1d(nn.Module): - """Wrapper around Conv1d and normalization applied to this conv - to provide a uniform interface across normalization approaches. - """ - def __init__(self, *args, causal: bool = False, norm: str = 'none', - norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): - super().__init__() - self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) - self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) - self.norm_type = norm - - def forward(self, x): - x = self.conv(x) - x = self.norm(x) - return x - - -class NormConv2d(nn.Module): - """Wrapper around Conv2d and normalization applied to this conv - to provide a uniform interface across normalization approaches. - """ - def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): - super().__init__() - self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) - self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) - self.norm_type = norm - - def forward(self, x): - x = self.conv(x) - x = self.norm(x) - return x - - -class NormConvTranspose1d(nn.Module): - """Wrapper around ConvTranspose1d and normalization applied to this conv - to provide a uniform interface across normalization approaches. - """ - def __init__(self, *args, causal: bool = False, norm: str = 'none', - norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): - super().__init__() - self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) - self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) - self.norm_type = norm - - def forward(self, x): - x = self.convtr(x) - x = self.norm(x) - return x - - -class NormConvTranspose2d(nn.Module): - """Wrapper around ConvTranspose2d and normalization applied to this conv - to provide a uniform interface across normalization approaches. - """ - def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): - super().__init__() - self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) - self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) - - def forward(self, x): - x = self.convtr(x) - x = self.norm(x) - return x - - -class StreamableConv1d(nn.Module): - """Conv1d with some builtin handling of asymmetric or causal padding - and normalization. 
- """ - def __init__(self, in_channels: int, out_channels: int, - kernel_size: int, stride: int = 1, dilation: int = 1, - groups: int = 1, bias: bool = True, causal: bool = False, - norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, - pad_mode: str = 'reflect'): - super().__init__() - # warn user on unusual setup between dilation and stride - if stride > 1 and dilation > 1: - warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1" - f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).") - self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, - dilation=dilation, groups=groups, bias=bias, causal=causal, - norm=norm, norm_kwargs=norm_kwargs) - self.causal = causal - self.pad_mode = pad_mode - - def forward(self, x): - B, C, T = x.shape - kernel_size = self.conv.conv.kernel_size[0] - stride = self.conv.conv.stride[0] - dilation = self.conv.conv.dilation[0] - kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations - padding_total = kernel_size - stride - extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) - if self.causal: - # Left padding for causal - x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) - else: - # Asymmetric padding required for odd strides - padding_right = padding_total // 2 - padding_left = padding_total - padding_right - x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) - return self.conv(x) - - -class StreamableConvTranspose1d(nn.Module): - """ConvTranspose1d with some builtin handling of asymmetric or causal padding - and normalization. - """ - def __init__(self, in_channels: int, out_channels: int, - kernel_size: int, stride: int = 1, causal: bool = False, - norm: str = 'none', trim_right_ratio: float = 1., - norm_kwargs: tp.Dict[str, tp.Any] = {}): - super().__init__() - self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, - causal=causal, norm=norm, norm_kwargs=norm_kwargs) - self.causal = causal - self.trim_right_ratio = trim_right_ratio - assert self.causal or self.trim_right_ratio == 1., \ - "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" - assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. - - def forward(self, x): - kernel_size = self.convtr.convtr.kernel_size[0] - stride = self.convtr.convtr.stride[0] - padding_total = kernel_size - stride - - y = self.convtr(x) - - # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be - # removed at the very end, when keeping only the right length for the output, - # as removing it here would require also passing the length at the matching layer - # in the encoder. 
- if self.causal: - # Trim the padding on the right according to the specified ratio - # if trim_right_ratio = 1.0, trim everything from right - padding_right = math.ceil(padding_total * self.trim_right_ratio) - padding_left = padding_total - padding_right - y = unpad1d(y, (padding_left, padding_right)) - else: - # Asymmetric padding required for odd strides - padding_right = padding_total // 2 - padding_left = padding_total - padding_right - y = unpad1d(y, (padding_left, padding_right)) - return y diff --git a/audiocraft/audiocraft/modules/diffusion_schedule.py b/audiocraft/audiocraft/modules/diffusion_schedule.py deleted file mode 100644 index 74ca6e3f2e7c4ff904d96dade315b0b46856778d..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/diffusion_schedule.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Functions for Noise Schedule, defines diffusion process, reverse process and data processor. -""" - -from collections import namedtuple -import random -import typing as tp -import julius -import torch - -TrainingItem = namedtuple("TrainingItem", "noisy noise step") - - -def betas_from_alpha_bar(alpha_bar): - alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]]) - return 1 - alphas - - -class SampleProcessor(torch.nn.Module): - def project_sample(self, x: torch.Tensor): - """Project the original sample to the 'space' where the diffusion will happen.""" - return x - - def return_sample(self, z: torch.Tensor): - """Project back from diffusion space to the actual sample space.""" - return z - - -class MultiBandProcessor(SampleProcessor): - """ - MultiBand sample processor. The input audio is splitted across - frequency bands evenly distributed in mel-scale. - - Each band will be rescaled to match the power distribution - of Gaussian noise in that band, using online metrics - computed on the first few samples. - - Args: - n_bands (int): Number of mel-bands to split the signal over. - sample_rate (int): Sample rate of the audio. - num_samples (int): Number of samples to use to fit the rescaling - for each band. The processor won't be stable - until it has seen that many samples. - power_std (float or list/tensor): The rescaling factor computed to match the - power of Gaussian noise in each band is taken to - that power, i.e. `1.` means full correction of the energy - in each band, and values less than `1` means only partial - correction. Can be used to balance the relative importance - of low vs. high freq in typical audio signals. 
- """ - def __init__(self, n_bands: int = 8, sample_rate: float = 24_000, - num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.): - super().__init__() - self.n_bands = n_bands - self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands) - self.num_samples = num_samples - self.power_std = power_std - if isinstance(power_std, list): - assert len(power_std) == n_bands - power_std = torch.tensor(power_std) - self.register_buffer('counts', torch.zeros(1)) - self.register_buffer('sum_x', torch.zeros(n_bands)) - self.register_buffer('sum_x2', torch.zeros(n_bands)) - self.register_buffer('sum_target_x2', torch.zeros(n_bands)) - self.counts: torch.Tensor - self.sum_x: torch.Tensor - self.sum_x2: torch.Tensor - self.sum_target_x2: torch.Tensor - - @property - def mean(self): - mean = self.sum_x / self.counts - return mean - - @property - def std(self): - std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() - return std - - @property - def target_std(self): - target_std = self.sum_target_x2 / self.counts - return target_std - - def project_sample(self, x: torch.Tensor): - assert x.dim() == 3 - bands = self.split_bands(x) - if self.counts.item() < self.num_samples: - ref_bands = self.split_bands(torch.randn_like(x)) - self.counts += len(x) - self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1) - self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1) - self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1) - rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size - bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1) - return bands.sum(dim=0) - - def return_sample(self, x: torch.Tensor): - assert x.dim() == 3 - bands = self.split_bands(x) - rescale = (self.std / self.target_std) ** self.power_std - bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1) - return bands.sum(dim=0) - - -class NoiseSchedule: - """Noise schedule for diffusion. - - Args: - beta_t0 (float): Variance of the first diffusion step. - beta_t1 (float): Variance of the last diffusion step. - beta_exp (float): Power schedule exponent - num_steps (int): Number of diffusion step. - variance (str): choice of the sigma value for the denoising eq. 
Choices: "beta" or "beta_tilde" - clip (float): clipping value for the denoising steps - rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1) - repartition (str): shape of the schedule only power schedule is supported - sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution - noise_scale (float): Scaling factor for the noise - """ - def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta', - clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1, - repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None, - sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs): - - self.beta_t0 = beta_t0 - self.beta_t1 = beta_t1 - self.variance = variance - self.num_steps = num_steps - self.clip = clip - self.sample_processor = sample_processor - self.rescale = rescale - self.n_bands = n_bands - self.noise_scale = noise_scale - assert n_bands is None - if repartition == "power": - self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps, - device=device, dtype=torch.float) ** beta_exp - else: - raise RuntimeError('Not implemented') - self.rng = random.Random(1234) - - def get_beta(self, step: tp.Union[int, torch.Tensor]): - if self.n_bands is None: - return self.betas[step] - else: - return self.betas[:, step] # [n_bands, len(step)] - - def get_initial_noise(self, x: torch.Tensor): - if self.n_bands is None: - return torch.randn_like(x) - return torch.randn((x.size(0), self.n_bands, x.size(2))) - - def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor: - """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step.""" - if step is None: - return (1 - self.betas).cumprod(dim=-1) # works for simgle and multi bands - if type(step) is int: - return (1 - self.betas[:step + 1]).prod() - else: - return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1) - - def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem: - """Create a noisy data item for diffusion model training: - - Args: - x (torch.Tensor): clean audio data torch.tensor(bs, 1, T) - tensor_step (bool): If tensor_step = false, only one step t is sample, - the whole batch is diffused to the same step and t is int. - If tensor_step = true, t is a tensor of size (x.size(0),) - every element of the batch is diffused to a independently sampled. - """ - step: tp.Union[int, torch.Tensor] - if tensor_step: - bs = x.size(0) - step = torch.randint(0, self.num_steps, size=(bs,), device=x.device) - else: - step = self.rng.randrange(self.num_steps) - alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1] - - x = self.sample_processor.project_sample(x) - noise = torch.randn_like(x) - noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale - return TrainingItem(noisy, noise, step) - - def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None, - condition: tp.Optional[torch.Tensor] = None, return_list: bool = False): - """Full ddpm reverse process. - - Args: - model (nn.Module): Diffusion model. - initial (tensor): Initial Noise. - condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation). - return_list (bool): Whether to return the whole process or only the sampled point. 
- """ - alpha_bar = self.get_alpha_bar(step=self.num_steps - 1) - current = initial - iterates = [initial] - for step in range(self.num_steps)[::-1]: - with torch.no_grad(): - estimate = model(current, step, condition=condition).sample - alpha = 1 - self.betas[step] - previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt() - previous_alpha_bar = self.get_alpha_bar(step=step - 1) - if step == 0: - sigma2 = 0 - elif self.variance == 'beta': - sigma2 = 1 - alpha - elif self.variance == 'beta_tilde': - sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha) - elif self.variance == 'none': - sigma2 = 0 - else: - raise ValueError(f'Invalid variance type {self.variance}') - - if sigma2 > 0: - previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale - if self.clip: - previous = previous.clamp(-self.clip, self.clip) - current = previous - alpha_bar = previous_alpha_bar - if step == 0: - previous *= self.rescale - if return_list: - iterates.append(previous.cpu()) - - if return_list: - return iterates - else: - return self.sample_processor.return_sample(previous) - - def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None, - condition: tp.Optional[torch.Tensor] = None, return_list: bool = False): - """Reverse process that only goes through Markov chain states in step_list.""" - if step_list is None: - step_list = list(range(1000))[::-50] + [0] - alpha_bar = self.get_alpha_bar(step=self.num_steps - 1) - alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu() - betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled) - current = initial * self.noise_scale - iterates = [current] - for idx, step in enumerate(step_list[:-1]): - with torch.no_grad(): - estimate = model(current, step, condition=condition).sample * self.noise_scale - alpha = 1 - betas_subsampled[-1 - idx] - previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt() - previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1]) - if step == step_list[-2]: - sigma2 = 0 - previous_alpha_bar = torch.tensor(1.0) - else: - sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha) - if sigma2 > 0: - previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale - if self.clip: - previous = previous.clamp(-self.clip, self.clip) - current = previous - alpha_bar = previous_alpha_bar - if step == 0: - previous *= self.rescale - if return_list: - iterates.append(previous.cpu()) - if return_list: - return iterates - else: - return self.sample_processor.return_sample(previous) diff --git a/audiocraft/audiocraft/modules/lstm.py b/audiocraft/audiocraft/modules/lstm.py deleted file mode 100644 index c0866175950c1ca4f6cca98649525e6481853bba..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/lstm.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from torch import nn - - -class StreamableLSTM(nn.Module): - """LSTM without worrying about the hidden state, nor the layout of the data. - Expects input as convolutional layout. 
- """ - def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): - super().__init__() - self.skip = skip - self.lstm = nn.LSTM(dimension, dimension, num_layers) - - def forward(self, x): - x = x.permute(2, 0, 1) - y, _ = self.lstm(x) - if self.skip: - y = y + x - y = y.permute(1, 2, 0) - return y diff --git a/audiocraft/audiocraft/modules/rope.py b/audiocraft/audiocraft/modules/rope.py deleted file mode 100644 index 503e6748df2bb72b3c864c20b37cba5498ffdd21..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/rope.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import typing as tp - -from torch import nn -import torch - - -class XPos(nn.Module): - """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1). - This applies an exponential decay to the RoPE rotation matrix. - - Args: - dim (int): Embedding dimension. - smoothing (float): Smoothing factor applied to the decay rates. - base_scale (int): Base decay rate, given in terms of scaling time. - device (torch.device, optional): Device on which to initialize the module. - dtype (torch.dtype): dtype to use to generate the embedding. - """ - def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512, - device=None, dtype: torch.dtype = torch.float32): - super().__init__() - assert dim % 2 == 0 - assert dtype in [torch.float64, torch.float32] - self.dtype = dtype - self.base_scale = base_scale - - half_dim = dim // 2 - adim = torch.arange(half_dim, device=device, dtype=dtype) - decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing) - self.register_buffer("decay_rates", decay_rates) - self.decay: tp.Optional[torch.Tensor] = None - - def get_decay(self, start: int, end: int): - """Create complex decay tensor, cache values for fast computation.""" - if self.decay is None or end > self.decay.shape[0]: - assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker. - idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype) - power = idx / self.base_scale - scale = self.decay_rates ** power.unsqueeze(-1) - self.decay = torch.polar(scale, torch.zeros_like(scale)) - return self.decay[start:end] # [T, C/2] - - -class RotaryEmbedding(nn.Module): - """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864). - - Args: - dim (int): Embedding dimension (twice the number of frequencies). - max_period (float): Maximum period of the rotation frequencies. - xpos (bool): Use xPos, applies an exponential decay to rotation matrix. - scale (float): Scale of positional embedding, set to 0 to deactivate. - device (torch.device, optional): Device on which to initialize the module. - dtype (torch.dtype): dtype to use to generate the embedding. 
- """ - def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False, - scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32): - super().__init__() - assert dim % 2 == 0 - self.scale = scale - assert dtype in [torch.float64, torch.float32] - self.dtype = dtype - - adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)] - frequencies = 1.0 / (max_period ** (adim / dim)) - self.register_buffer("frequencies", frequencies) - self.rotation: tp.Optional[torch.Tensor] = None - - self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None - - def get_rotation(self, start: int, end: int): - """Create complex rotation tensor, cache values for fast computation.""" - if self.rotation is None or end > self.rotation.shape[0]: - assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker. - idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype) - angles = torch.outer(idx, self.frequencies) - self.rotation = torch.polar(torch.ones_like(angles), angles) - return self.rotation[start:end] - - def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False): - """Apply rope rotation to query or key tensor.""" - T = x.shape[1] - rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2) - - if self.xpos: - decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2) - else: - decay = 1.0 - - if invert_decay: - decay = decay ** -1 - - x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2)) - scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale) - x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2) - - return x_out.type_as(x) - - def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0): - """ Apply rope rotation to both query and key tensors. - Supports streaming mode, in which query and key are not expected to have the same shape. - In streaming mode, key will be of length [P + C] with P the cached past timesteps, but - query will be [C] (typically C == 1). - - Args: - query (torch.Tensor): Query to rotate. - key (torch.Tensor): Key to rotate. - start (int): Start index of the sequence for time offset. - """ - query_timesteps = query.shape[1] - key_timesteps = key.shape[1] - streaming_offset = key_timesteps - query_timesteps - - query_out = self.rotate(query, start + streaming_offset) - key_out = self.rotate(key, start, invert_decay=True) - - return query_out, key_out diff --git a/audiocraft/audiocraft/modules/seanet.py b/audiocraft/audiocraft/modules/seanet.py deleted file mode 100644 index 3e5998e9153afb6e68ea410d565e00ea835db248..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/seanet.py +++ /dev/null @@ -1,258 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import typing as tp - -import numpy as np -import torch.nn as nn - -from .conv import StreamableConv1d, StreamableConvTranspose1d -from .lstm import StreamableLSTM - - -class SEANetResnetBlock(nn.Module): - """Residual block from SEANet model. - - Args: - dim (int): Dimension of the input/output. - kernel_sizes (list): List of kernel sizes for the convolutions. - dilations (list): List of dilations for the convolutions. - activation (str): Activation function. - activation_params (dict): Parameters to provide to the activation function. 
- norm (str): Normalization method. - norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. - causal (bool): Whether to use fully causal convolution. - pad_mode (str): Padding mode for the convolutions. - compress (int): Reduced dimensionality in residual branches (from Demucs v3). - true_skip (bool): Whether to use true skip connection or a simple - (streamable) convolution as the skip connection. - """ - def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], - activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, - pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): - super().__init__() - assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' - act = getattr(nn, activation) - hidden = dim // compress - block = [] - for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): - in_chs = dim if i == 0 else hidden - out_chs = dim if i == len(kernel_sizes) - 1 else hidden - block += [ - act(**activation_params), - StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, - norm=norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode), - ] - self.block = nn.Sequential(*block) - self.shortcut: nn.Module - if true_skip: - self.shortcut = nn.Identity() - else: - self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode) - - def forward(self, x): - return self.shortcut(x) + self.block(x) - - -class SEANetEncoder(nn.Module): - """SEANet encoder. - - Args: - channels (int): Audio channels. - dimension (int): Intermediate representation dimension. - n_filters (int): Base width for the model. - n_residual_layers (int): nb of residual layers. - ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of - upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here - that must match the decoder order. We use the decoder order as some models may only employ the decoder. - activation (str): Activation function. - activation_params (dict): Parameters to provide to the activation function. - norm (str): Normalization method. - norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. - kernel_size (int): Kernel size for the initial convolution. - last_kernel_size (int): Kernel size for the initial convolution. - residual_kernel_size (int): Kernel size for the residual layers. - dilation_base (int): How much to increase the dilation with each layer. - causal (bool): Whether to use fully causal convolution. - pad_mode (str): Padding mode for the convolutions. - true_skip (bool): Whether to use true skip connection or a simple - (streamable) convolution as the skip connection in the residual network blocks. - compress (int): Reduced dimensionality in residual branches (from Demucs v3). - lstm (int): Number of LSTM layers at the end of the encoder. - disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. - For the encoder, it corresponds to the N first blocks. 
- """ - def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, - ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, - last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, - pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, - disable_norm_outer_blocks: int = 0): - super().__init__() - self.channels = channels - self.dimension = dimension - self.n_filters = n_filters - self.ratios = list(reversed(ratios)) - del ratios - self.n_residual_layers = n_residual_layers - self.hop_length = np.prod(self.ratios) - self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks - self.disable_norm_outer_blocks = disable_norm_outer_blocks - assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ - "Number of blocks for which to disable norm is invalid." \ - "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." - - act = getattr(nn, activation) - mult = 1 - model: tp.List[nn.Module] = [ - StreamableConv1d(channels, mult * n_filters, kernel_size, - norm='none' if self.disable_norm_outer_blocks >= 1 else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) - ] - # Downsample to raw audio scale - for i, ratio in enumerate(self.ratios): - block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm - # Add residual layers - for j in range(n_residual_layers): - model += [ - SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1], - dilations=[dilation_base ** j, 1], - norm=block_norm, norm_params=norm_params, - activation=activation, activation_params=activation_params, - causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)] - - # Add downsampling layers - model += [ - act(**activation_params), - StreamableConv1d(mult * n_filters, mult * n_filters * 2, - kernel_size=ratio * 2, stride=ratio, - norm=block_norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode), - ] - mult *= 2 - - if lstm: - model += [StreamableLSTM(mult * n_filters, num_layers=lstm)] - - model += [ - act(**activation_params), - StreamableConv1d(mult * n_filters, dimension, last_kernel_size, - norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) - ] - - self.model = nn.Sequential(*model) - - def forward(self, x): - return self.model(x) - - -class SEANetDecoder(nn.Module): - """SEANet decoder. - - Args: - channels (int): Audio channels. - dimension (int): Intermediate representation dimension. - n_filters (int): Base width for the model. - n_residual_layers (int): nb of residual layers. - ratios (Sequence[int]): kernel size and stride ratios. - activation (str): Activation function. - activation_params (dict): Parameters to provide to the activation function. - final_activation (str): Final activation function after all convolutions. - final_activation_params (dict): Parameters to provide to the activation function. - norm (str): Normalization method. - norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. - kernel_size (int): Kernel size for the initial convolution. - last_kernel_size (int): Kernel size for the initial convolution. 
- residual_kernel_size (int): Kernel size for the residual layers. - dilation_base (int): How much to increase the dilation with each layer. - causal (bool): Whether to use fully causal convolution. - pad_mode (str): Padding mode for the convolutions. - true_skip (bool): Whether to use true skip connection or a simple. - (streamable) convolution as the skip connection in the residual network blocks. - compress (int): Reduced dimensionality in residual branches (from Demucs v3). - lstm (int): Number of LSTM layers at the end of the encoder. - disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. - For the decoder, it corresponds to the N last blocks. - trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. - If equal to 1.0, it means that all the trimming is done at the right. - """ - def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, - ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, - last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, - pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, - disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0): - super().__init__() - self.dimension = dimension - self.channels = channels - self.n_filters = n_filters - self.ratios = ratios - del ratios - self.n_residual_layers = n_residual_layers - self.hop_length = np.prod(self.ratios) - self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks - self.disable_norm_outer_blocks = disable_norm_outer_blocks - assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ - "Number of blocks for which to disable norm is invalid." \ - "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
- - act = getattr(nn, activation) - mult = int(2 ** len(self.ratios)) - model: tp.List[nn.Module] = [ - StreamableConv1d(dimension, mult * n_filters, kernel_size, - norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) - ] - - if lstm: - model += [StreamableLSTM(mult * n_filters, num_layers=lstm)] - - # Upsample to raw audio scale - for i, ratio in enumerate(self.ratios): - block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm - # Add upsampling layers - model += [ - act(**activation_params), - StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2, - kernel_size=ratio * 2, stride=ratio, - norm=block_norm, norm_kwargs=norm_params, - causal=causal, trim_right_ratio=trim_right_ratio), - ] - # Add residual layers - for j in range(n_residual_layers): - model += [ - SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1], - dilations=[dilation_base ** j, 1], - activation=activation, activation_params=activation_params, - norm=block_norm, norm_params=norm_params, causal=causal, - pad_mode=pad_mode, compress=compress, true_skip=true_skip)] - - mult //= 2 - - # Add final layers - model += [ - act(**activation_params), - StreamableConv1d(n_filters, channels, last_kernel_size, - norm='none' if self.disable_norm_outer_blocks >= 1 else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) - ] - # Add optional final activation to decoder (eg. tanh) - if final_activation is not None: - final_act = getattr(nn, final_activation) - final_activation_params = final_activation_params or {} - model += [ - final_act(**final_activation_params) - ] - self.model = nn.Sequential(*model) - - def forward(self, z): - y = self.model(z) - return y diff --git a/audiocraft/audiocraft/modules/streaming.py b/audiocraft/audiocraft/modules/streaming.py deleted file mode 100644 index fba06936294ca15d72acd2d44f9dbda39a638107..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/streaming.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Streaming module API that should be implemented by all Streaming components, -""" - -from contextlib import contextmanager -import typing as tp -from torch import nn -import torch - - -State = tp.Dict[str, torch.Tensor] - - -class StreamingModule(nn.Module): - """Common API for streaming components. - - Each streaming component has a streaming state, which is just a dict[str, Tensor]. - By convention, the first dim of each tensor must be the batch size. - Don't use dots in the key names, as this would clash with submodules - (like in state_dict). - - If `self._is_streaming` is True, the component should use and remember - the proper state inside `self._streaming_state`. - - To set a streaming component in streaming state, use - - with module.streaming(): - ... - - This will automatically reset the streaming state when exiting the context manager. - This also automatically propagates to all streaming children module. - - Some module might also implement the `StreamingModule.flush` method, although - this one is trickier, as all parents module must be StreamingModule and implement - it as well for it to work properly. See `StreamingSequential` after. 
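Assuming the audiocraft package is installed, an encoder/decoder pair built with matching ratios round-trips shapes exactly; a sketch with randomly initialized weights:

```python
import torch
from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder  # assumes the audiocraft package is installed

encoder = SEANetEncoder(channels=1, dimension=128, n_filters=32, ratios=[8, 5, 4, 2])
decoder = SEANetDecoder(channels=1, dimension=128, n_filters=32, ratios=[8, 5, 4, 2])

wav = torch.randn(2, 1, 24_000)        # 1 s of mono audio at 24 kHz
latent = encoder(wav)                  # [2, 128, 75]  (24_000 / 320 latent frames)
recon = decoder(latent)                # back to [2, 1, 24_000]
print(latent.shape, recon.shape)
```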
- """ - def __init__(self) -> None: - super().__init__() - self._streaming_state: State = {} - self._is_streaming = False - - def _apply_named_streaming(self, fn: tp.Any): - for name, module in self.named_modules(): - if isinstance(module, StreamingModule): - fn(name, module) - - def _set_streaming(self, streaming: bool): - def _set_streaming(name, module): - module._is_streaming = streaming - self._apply_named_streaming(_set_streaming) - - @contextmanager - def streaming(self): - """Context manager to enter streaming mode. Reset streaming state on exit.""" - self._set_streaming(True) - try: - yield - finally: - self._set_streaming(False) - self.reset_streaming() - - def reset_streaming(self): - """Reset the streaming state.""" - def _reset(name: str, module: StreamingModule): - module._streaming_state.clear() - - self._apply_named_streaming(_reset) - - def get_streaming_state(self) -> State: - """Return the streaming state, including that of sub-modules.""" - state: State = {} - - def _add(name: str, module: StreamingModule): - if name: - name += "." - for key, value in module._streaming_state.items(): - state[name + key] = value - - self._apply_named_streaming(_add) - return state - - def set_streaming_state(self, state: State): - """Set the streaming state, including that of sub-modules.""" - state = dict(state) - - def _set(name: str, module: StreamingModule): - if name: - name += "." - module._streaming_state.clear() - for key, value in list(state.items()): - # complexity is not ideal here, but probably fine. - if key.startswith(name): - local_key = key[len(name):] - if '.' not in local_key: - module._streaming_state[local_key] = value - del state[key] - - self._apply_named_streaming(_set) - assert len(state) == 0, list(state.keys()) - - def flush(self, x: tp.Optional[torch.Tensor] = None): - """Flush any remaining outputs that were waiting for completion. - Typically, for convolutions, this will add the final padding - and process the last buffer. - - This should take an optional argument `x`, which will be provided - if a module before this one in the streaming pipeline has already - spitted out a flushed out buffer. - """ - if x is None: - return None - else: - return self(x) - - -class StreamingSequential(StreamingModule, nn.Sequential): - """A streaming compatible alternative of `nn.Sequential`. - """ - def flush(self, x: tp.Optional[torch.Tensor] = None): - for module in self: - if isinstance(module, StreamingModule): - x = module.flush(x) - elif x is not None: - x = module(x) - return x diff --git a/audiocraft/audiocraft/modules/transformer.py b/audiocraft/audiocraft/modules/transformer.py deleted file mode 100644 index cdc45cf87ad44e2bed3c7f5499429c87d81797c0..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/modules/transformer.py +++ /dev/null @@ -1,752 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Transformer model, with streaming support, xformer attention support -and easy causal attention with a potentially finite receptive field. - -See `StreamingTransformer` for more information. - -Unlike regular PyTorch Transformer, we make the hard choice that batches are first. 
-""" - -import typing as tp - -from einops import rearrange -import torch -import torch.nn as nn -from torch.nn import functional as F -from torch.utils.checkpoint import checkpoint as torch_checkpoint -from xformers import ops - -from .rope import RotaryEmbedding -from .streaming import StreamingModule - -_efficient_attention_backend: str = 'torch' - - -def set_efficient_attention_backend(backend: str = 'torch'): - # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster). - global _efficient_attention_backend - assert _efficient_attention_backend in ['xformers', 'torch'] - _efficient_attention_backend = backend - - -def _get_attention_time_dimension() -> int: - if _efficient_attention_backend == 'torch': - return 2 - else: - return 1 - - -def _is_profiled() -> bool: - # Return true if we are currently running with a xformers profiler activated. - try: - from xformers.profiler import profiler - except ImportError: - return False - return profiler._Profiler._CURRENT_PROFILER is not None - - -def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module: - """Create normalization module for transformer encoder layer. - - Args: - norm_type (str): Normalization method. - dim (int): Dimension of the normalized layer. - **kwargs (dict): Additional parameters for normalization layer. - Returns: - nn.Module: Normalization module. - """ - if norm_type == 'layer_norm': - return nn.LayerNorm(dim, eps=1e-5, **kwargs) - else: - raise ValueError(f"Unknown norm type: {norm_type}") - - -def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000, - dtype: torch.dtype = torch.float32) -> torch.Tensor: - """Create sinusoidal positional embedding, with shape `[B, T, C]`. - - Args: - positions (torch.Tensor): LongTensor of positions. - dim (int): Dimension of the embedding. - max_period (float): Maximum period of the cosine/sine functions. - dtype (torch.dtype or str): dtype to use to generate the embedding. - Returns: - torch.Tensor: Sinusoidal positional embedding. - """ - # We aim for BTC format - assert dim % 2 == 0 - half_dim = dim // 2 - positions = positions.to(dtype) - adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1) - max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point - phase = positions / (max_period_tensor ** (adim / (half_dim - 1))) - return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1) - - -def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers.""" - if n_rep == 1: - return x - if _efficient_attention_backend == 'torch': - bs, n_kv_heads, slen, head_dim = x.shape - return ( - x[:, :, None, :, :] - .expand(bs, n_kv_heads, n_rep, slen, head_dim) - .reshape(bs, n_kv_heads * n_rep, slen, head_dim) - ) - else: - bs, slen, n_kv_heads, head_dim = x.shape - return ( - x[:, :, :, None, :] - .expand(bs, slen, n_kv_heads, n_rep, head_dim) - .reshape(bs, slen, n_kv_heads * n_rep, head_dim) - ) - - -class LayerScale(nn.Module): - """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). - This rescales diagonally the residual outputs close to 0, with a learnt scale. - - Args: - channels (int): Number of channels. - init (float): Initial scale. - channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`. - device (torch.device or str, optional): Device on which to initialize the module. 
- dtype (torch.dtype, optional): dtype to use to initialize the module. - """ - def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True, - device=None, dtype=None): - super().__init__() - self.channel_last = channel_last - self.scale = nn.Parameter( - torch.full((channels,), init, - requires_grad=True, device=device, dtype=dtype)) - - def forward(self, x: torch.Tensor): - if self.channel_last: - return self.scale * x - else: - return self.scale[:, None] * x - - -class StreamingMultiheadAttention(StreamingModule): - """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation. - - Args: - embed_dim (int): Dimension to project to. - num_heads (int): Number of heads. - dropout (float): Dropout level. - bias (bool): Use bias in projections. - causal (bool): Causal mask applied automatically. - past_context (int, optional): Receptive field for the causal mask, infinite if None. - custom (bool): Use custom MHA implementation, for testing / benchmarking. - memory_efficient (bool): Use xformers based memory efficient attention. - attention_as_float32 (bool): Perform the attention as float32 - (especially important with memory_efficient as autocast won't do this automatically). - rope (`RotaryEmbedding`, optional): Rope embedding to use. - cross_attention: Should be true when used as a cross attention. - All keys and values must be available at once, streaming is only for the queries. - Cannot be used with `causal` or `rope` (as it wouldn't make sens to - interpret the time steps in the keys relative to those in the queries). - safe_streaming (bool): Bug fix, will go away with xformers update. - qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product. - kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads). - This will lead to faster decoding time on A100 or other GPUs with tensorcore. - device (torch.device, optional): Device on which to initialize. - dtype (torch.dtype, optional): dtype to use. - """ - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, - causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False, - memory_efficient: bool = False, attention_as_float32: bool = False, - rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False, - safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1, - device=None, dtype=None): - super().__init__() - factory_kwargs = {'device': device, 'dtype': dtype} - if past_context is not None: - assert causal - - self.embed_dim = embed_dim - self.causal = causal - self.past_context = past_context - self.memory_efficient = memory_efficient - self.attention_as_float32 = attention_as_float32 - self.rope = rope - self.cross_attention = cross_attention - self.safe_streaming = safe_streaming - self.num_heads = num_heads - self.dropout = dropout - self.kv_repeat = kv_repeat - if cross_attention: - assert not causal, "Causal cannot work with cross attention." - assert rope is None, "Rope cannot work with cross attention." 
- - if memory_efficient: - _verify_xformers_memory_efficient_compat() - - self.custom = _is_custom(custom, memory_efficient) - if self.custom: - out_dim = embed_dim - assert num_heads % kv_repeat == 0 - assert not cross_attention or kv_repeat == 1 - num_kv = num_heads // kv_repeat - kv_dim = (embed_dim // num_heads) * num_kv - out_dim += 2 * kv_dim - in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs) - # We try to follow the default PyTorch MHA convention, to easily compare results. - self.in_proj_weight = in_proj.weight - self.in_proj_bias = in_proj.bias - if bias: - self.in_proj_bias.data.zero_() # Following Pytorch convention - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) - if bias: - self.out_proj.bias.data.zero_() - else: - assert not qk_layer_norm - assert kv_repeat == 1 - self.mha = nn.MultiheadAttention( - embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True, - **factory_kwargs) - self.qk_layer_norm = qk_layer_norm - if qk_layer_norm: - assert self.custom - assert kv_repeat == 1 - ln_dim = embed_dim - self.q_layer_norm = nn.LayerNorm(ln_dim) - self.k_layer_norm = nn.LayerNorm(ln_dim) - - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): - if not self.custom: - # Support compat with regular MHA - keys = [n for n, _ in self.mha.named_parameters()] - for key in keys: - if prefix + key in state_dict: - state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key) - super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) - - def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype): - # Return a causal mask, accounting for potentially stored past keys/values - # We actually return a bias for the attention score, as this has the same - # convention both in the builtin MHA in Pytorch, and Xformers functions. - time_dim = _get_attention_time_dimension() - if self.memory_efficient: - from xformers.ops import LowerTriangularMask - if current_steps == 1: - # If we only have one step, then we do not need a mask. - return None - elif 'past_keys' in self._streaming_state: - raise RuntimeError("Not supported at the moment") - else: - # Then we can safely use a lower triangular mask - return LowerTriangularMask() - if self._streaming_state: - past_keys = self._streaming_state['past_keys'] - past_steps = past_keys.shape[time_dim] - else: - past_steps = 0 - - queries_pos = torch.arange( - past_steps, current_steps + past_steps, device=device).view(-1, 1) - keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1) - delta = queries_pos - keys_pos - valid = delta >= 0 - if self.past_context is not None: - valid &= (delta <= self.past_context) - return torch.where( - valid, - torch.zeros([], device=device, dtype=dtype), - torch.full([], float('-inf'), device=device, dtype=dtype)) - - def _complete_kv(self, k, v): - time_dim = _get_attention_time_dimension() - if self.cross_attention: - # With cross attention we assume all keys and values - # are already available, and streaming is with respect - # to the queries only. - return k, v - # Complete the key/value pair using the streaming state. 
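The mask above is an additive attention bias: 0 where a query may attend, -inf where it may not, optionally limited to `past_context` steps. A small standalone sketch of the same construction:

```python
import torch

current_steps, past_steps, past_context = 4, 3, 2
device, dtype = torch.device('cpu'), torch.float32

queries_pos = torch.arange(past_steps, past_steps + current_steps, device=device).view(-1, 1)
keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
delta = queries_pos - keys_pos
valid = (delta >= 0) & (delta <= past_context)       # causal + finite receptive field
bias = torch.where(valid,
                   torch.zeros([], device=device, dtype=dtype),
                   torch.full([], float('-inf'), device=device, dtype=dtype))
print(bias)   # [4, 7]: each of the 4 new queries sees itself plus at most 2 past keys
```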
- if self._streaming_state: - pk = self._streaming_state['past_keys'] - nk = torch.cat([pk, k], dim=time_dim) - if v is k: - nv = nk - else: - pv = self._streaming_state['past_values'] - nv = torch.cat([pv, v], dim=time_dim) - else: - nk = k - nv = v - - assert nk.shape[time_dim] == nv.shape[time_dim] - offset = 0 - if self.past_context is not None: - offset = max(0, nk.shape[time_dim] - self.past_context) - if self._is_streaming: - self._streaming_state['past_keys'] = nk[:, offset:] - if v is not k: - self._streaming_state['past_values'] = nv[:, offset:] - if 'offset' in self._streaming_state: - self._streaming_state['offset'] += offset - else: - self._streaming_state['offset'] = torch.tensor(0) - return nk, nv - - def _apply_rope(self, query: torch.Tensor, key: torch.Tensor): - # TODO: fix and verify layout. - assert _efficient_attention_backend == 'xformers', "Rope not supported with torch attn." - # Apply rope embeddings to query and key tensors. - assert self.rope is not None - if 'past_keys' in self._streaming_state: - past_keys_offset = self._streaming_state['past_keys'].shape[1] - else: - past_keys_offset = 0 - if 'offset' in self._streaming_state: - past_context_offset = int(self._streaming_state['offset'].item()) - else: - past_context_offset = 0 - streaming_offset = past_context_offset + past_keys_offset - return self.rope.rotate_qk(query, key, start=streaming_offset) - - def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - key_padding_mask=None, need_weights=False, attn_mask=None, - average_attn_weights=True, is_causal=False): - assert attn_mask is None - assert not is_causal, ("New param added in torch 2.0.1 not supported, " - "use the causal args in the constructor.") - - time_dim = _get_attention_time_dimension() - if time_dim == 2: - layout = "b h t d" - else: - layout = "b t h d" - dtype = query.dtype - if self._is_streaming: - assert self.causal or self.cross_attention, \ - "Streaming only available for causal or cross attention" - - if self.causal: - # At the moment we specialize only for the self-attention case. - assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value" - assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value" - attn_mask = self._get_mask(query.shape[1], query.device, query.dtype) - - if self.custom: - # custom implementation - assert need_weights is False - assert key_padding_mask is None - if self.cross_attention: - # Different queries, keys, values, we have to split the weights manually - # before applying the linear. - dim = self.in_proj_weight.shape[0] // 3 - if self.in_proj_bias is None: - bias_q, bias_k, bias_v = None, None, None - else: - bias_q = self.in_proj_bias[:dim] - bias_k = self.in_proj_bias[dim: 2 * dim] - bias_v = self.in_proj_bias[2 * dim:] - q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q) - # TODO: when streaming, we could actually save k, v and check that the shapes actually match. - k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k) - v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v) - if self.qk_layer_norm is True: - q = self.q_layer_norm(q) - k = self.k_layer_norm(k) - q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]] - else: - if not _is_profiled(): - # profiling breaks that property somehow.
- assert query is key, "specialized implementation" - assert value is key, "specialized implementation" - projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias) - if self.kv_repeat == 1: - if time_dim == 2: - bound_layout = "b h p t d" - else: - bound_layout = "b t p h d" - packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads) - q, k, v = ops.unbind(packed, dim=2) - else: - embed_dim = self.embed_dim - per_head_dim = (embed_dim // self.num_heads) - kv_heads = self.num_heads // self.kv_repeat - q = projected[:, :, :embed_dim] - start = embed_dim - end = start + per_head_dim * kv_heads - k = projected[:, :, start: end] - v = projected[:, :, end:] - q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads) - k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads) - v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads) - - if self.qk_layer_norm is True: - assert self.kv_repeat == 1 - q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]] - q = self.q_layer_norm(q) - k = self.k_layer_norm(k) - q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]] - if self.rope: - q, k = self._apply_rope(q, k) - k, v = self._complete_kv(k, v) - if self.kv_repeat > 1: - k = expand_repeated_kv(k, self.kv_repeat) - v = expand_repeated_kv(v, self.kv_repeat) - if self.attention_as_float32: - q, k, v = [x.float() for x in [q, k, v]] - if self.memory_efficient: - p = self.dropout if self.training else 0 - if _efficient_attention_backend == 'torch': - x = torch.nn.functional.scaled_dot_product_attention( - q, k, v, is_causal=attn_mask is not None, dropout_p=p) - else: - x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p) - else: - # We include the dot product as float32, for consistency - # with the other implementations that include that step - # as part of the attention. Note that when using `autocast`, - # the einsums would be done as bfloat16, but the softmax - # would be done as bfloat16, so `attention_as_float32` will - # extend a bit the range of operations done in float32, - # although this should make no difference. - q = q / q.shape[-1] ** 0.5 - key_layout = layout.replace('t', 'k') - query_layout = layout - if self._is_streaming and self.safe_streaming and q.device.type == 'cuda': - with torch.autocast(device_type=q.device.type, dtype=torch.float32): - pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k) - else: - pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k) - if attn_mask is not None: - pre_w = pre_w + attn_mask - w = torch.softmax(pre_w, dim=-1) - w = F.dropout(w, self.dropout, training=self.training).to(v) - # Key and value have the same format. - x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v) - x = x.to(dtype) - x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads) - x = self.out_proj(x) - else: - key, value = self._complete_kv(key, value) - if self.attention_as_float32: - query, key, value = [x.float() for x in [query, key, value]] - x, _ = self.mha( - query, key, value, key_padding_mask, - need_weights, attn_mask, average_attn_weights) - x = x.to(dtype) - - return x, None - - -class StreamingTransformerLayer(nn.TransformerEncoderLayer): - """TransformerLayer with Streaming / Causal support. - This also integrates cross_attention, when passing `cross_attention=True`, - rather than having two separate classes like in PyTorch. - - Args: - d_model (int): Dimension of the data. - num_heads (int): Number of heads. 
- dim_feedforward (int): Intermediate dimension of FF module. - dropout (float): Dropout both for MHA and FF. - bias_ff (bool): Use bias for FF. - bias_attn (bool): Use bias for MHA. - causal (bool): Causal mask applied automatically. - past_context (int, optional): Receptive field for the causal mask, infinite if None. - custom (bool): Use custom MHA implementation, for testing / benchmarking. - memory_efficient (bool): Use xformers based memory efficient attention. - attention_as_float32 (bool): Perform the attention as float32 - (especially important with memory_efficient as autocast won't do this automatically). - qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention. - qk_layer_norm_cross (bool): Same for the cross attention. - cross_attention (bool): If True, expect to get secondary input for cross-attention. - Cross attention will use the default MHA, as it typically won't require - special treatment. - layer_scale (float, optional): If not None, LayerScale will be used with - the given value as initial scale. - rope (`RotaryEmbedding`, optional): Rope embedding to use. - attention_dropout (float, optional): If not None, separate the value of the dimension dropout - in FFN and of the attention dropout. - kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads). - This will lead to faster decoding time on A100 or other GPUs with tensorcore. - device (torch.device, optional): Device on which to initialize. - dtype (torch.dtype, optional): dtype to use. - **kwargs: See `nn.TransformerEncoderLayer`. - """ - def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1, - bias_ff: bool = True, bias_attn: bool = True, causal: bool = False, - past_context: tp.Optional[int] = None, custom: bool = False, - memory_efficient: bool = False, attention_as_float32: bool = False, - qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False, - cross_attention: bool = False, layer_scale: tp.Optional[float] = None, - rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None, - kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs): - super().__init__(d_model, num_heads, dim_feedforward, dropout, - device=device, dtype=dtype, batch_first=True, **kwargs) - factory_kwargs = {'device': device, 'dtype': dtype} - # Redefine self_attn to our streaming multi-head attention - attn_kwargs: tp.Dict[str, tp.Any] = { - 'embed_dim': d_model, - 'num_heads': num_heads, - 'dropout': dropout if attention_dropout is None else attention_dropout, - 'bias': bias_attn, - 'custom': custom, - 'memory_efficient': memory_efficient, - 'attention_as_float32': attention_as_float32, - } - self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention( - causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm, - kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore - # Redefine feedforward layers to expose bias parameter - self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs) - self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs) - - self.layer_scale_1: nn.Module - self.layer_scale_2: nn.Module - if layer_scale is None: - self.layer_scale_1 = nn.Identity() - self.layer_scale_2 = nn.Identity() - else: - self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) - self.layer_scale_2 = LayerScale(d_model, layer_scale, 
**factory_kwargs) - - self.cross_attention: tp.Optional[nn.Module] = None - if cross_attention: - self.cross_attention = StreamingMultiheadAttention( - cross_attention=True, qk_layer_norm=qk_layer_norm_cross, - **attn_kwargs, **factory_kwargs) - # Norm and dropout - self.dropout_cross = nn.Dropout(dropout) - # eps value matching that used in PyTorch reference implementation. - self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs) - self.layer_scale_cross: nn.Module - if layer_scale is None: - self.layer_scale_cross = nn.Identity() - else: - self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs) - self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore - self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore - - def _cross_attention_block(self, src: torch.Tensor, - cross_attention_src: torch.Tensor) -> torch.Tensor: - assert self.cross_attention is not None - # queries are from src, keys and values from cross_attention_src. - x = self.cross_attention( - src, cross_attention_src, cross_attention_src, need_weights=False)[0] - return self.dropout_cross(x) # type: ignore - - def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None, # type: ignore - src_key_padding_mask: tp.Optional[torch.Tensor] = None, - cross_attention_src: tp.Optional[torch.Tensor] = None): - if self.cross_attention is None: - assert cross_attention_src is None - else: - assert cross_attention_src is not None - x = src - if self.norm_first: - x = x + self.layer_scale_1( - self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)) - if cross_attention_src is not None: - x = x + self.layer_scale_cross( - self._cross_attention_block( - self.norm_cross(x), cross_attention_src)) - x = x + self.layer_scale_2(self._ff_block(self.norm2(x))) - else: - x = self.norm1(x + self.layer_scale_1( - self._sa_block(x, src_mask, src_key_padding_mask))) - if cross_attention_src is not None: - x = self.norm_cross( - x + self.layer_scale_cross( - self._cross_attention_block(src, cross_attention_src))) - x = self.norm2(x + self.layer_scale_2(self._ff_block(x))) - return x - - -class StreamingTransformer(StreamingModule): - """Transformer with Streaming / Causal support. - - Args: - d_model (int): Dimension of the data. - num_heads (int): Number of heads. - dim_feedforward (int): Intermediate dimension of FF module. - dropout (float): Dropout both for MHA and FF. - bias_ff (bool): Use bias for FF. - bias_attn (bool): Use bias for MHA. - causal (bool): Causal mask applied automatically. - past_context (int, optional): Receptive field for the causal mask, infinite if None. - custom (bool): Use custom MHA implementation, for testing / benchmarking. - memory_efficient (bool): Use xformers based memory efficient attention. - attention_as_float32 (bool): Perform the attention as float32 - (especially important with memory_efficient as autocast won't do this automatically). - cross_attention (bool): If True, expect to get secondary input for cross-attention. - layer_scale (float, optional): If not None, LayerScale will be used - with the given value as initial scale. - positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope). - max_period (float): Maximum period of the time embedding. - positional_scale (float): Scale of positional embedding, set to 0 to deactivate. - xpos (bool): Apply xpos exponential decay to positional embedding (rope only). - lr (float, optional): learning rate override through the `make_optim_group` API. 
- weight_decay (float, optional): Weight decay override through the `make_optim_group` API. - layer_class (subclass of `StreamingTransformerLayer`): class to use - to initialize the layers, allowing further customization outside of AudioCraft. - checkpointing (str): Checkpointing strategy to reduce memory usage. - No checkpointing if set to 'none'. Per layer checkpointing using PyTorch - if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice, - minimal memory usage, but maximal runtime). Finally, `xformers_default` provides - a policy for opting out some operations of the checkpointing, like - linear layers and attention, providing a middle ground between speed and memory. - device (torch.device, optional): Device on which to initialize. - dtype (torch.dtype, optional): dtype to use. - **kwargs: See `nn.TransformerEncoderLayer`. - """ - def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048, - dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True, - causal: bool = False, past_context: tp.Optional[int] = None, - custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False, - cross_attention: bool = False, layer_scale: tp.Optional[float] = None, - positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1., - xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None, - layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer, - checkpointing: str = 'none', device=None, dtype=None, **kwargs): - super().__init__() - assert d_model % num_heads == 0 - - self.positional_embedding = positional_embedding - self.max_period = max_period - self.positional_scale = positional_scale - self.weight_decay = weight_decay - self.lr = lr - - assert positional_embedding in ['sin', 'rope', 'sin_rope'] - self.rope: tp.Optional[RotaryEmbedding] = None - if self.positional_embedding in ['rope', 'sin_rope']: - assert _is_custom(custom, memory_efficient) - self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period, - xpos=xpos, scale=positional_scale, device=device) - - self.checkpointing = checkpointing - - assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm'] - if self.checkpointing.startswith('xformers'): - _verify_xformers_internal_compat() - - self.layers = nn.ModuleList() - for idx in range(num_layers): - self.layers.append( - layer_class( - d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward, - dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn, - causal=causal, past_context=past_context, custom=custom, - memory_efficient=memory_efficient, attention_as_float32=attention_as_float32, - cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope, - device=device, dtype=dtype, **kwargs)) - - if self.checkpointing != 'none': - for layer in self.layers: - # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the - # backward hook inside of FSDP...
- layer._magma_checkpointed = True # type: ignore - assert layer.layer_drop == 0., "Need further checking" # type: ignore - - def _apply_layer(self, layer, *args, **kwargs): - method = self.checkpointing - if method == 'none': - return layer(*args, **kwargs) - elif method == 'torch': - return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs) - elif method.startswith('xformers'): - from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy - if method == 'xformers_default': - # those operations will be saved, and not recomputed. - # According to Francisco we can get smarter policies but this is a good start. - allow_list = [ - "xformers.efficient_attention_forward_cutlass.default", - "xformers_flash.flash_fwd.default", - "aten.addmm.default", - "aten.mm.default", - ] - elif method == 'xformers_mm': - # those operations will be saved, and not recomputed. - # According to Francisco we can get smarter policies but this is a good start. - allow_list = [ - "aten.addmm.default", - "aten.mm.default", - ] - else: - raise ValueError(f"xformers checkpointing xformers policy {method} is not known.") - policy_fn = _get_default_policy(allow_list) - return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs) - else: - raise ValueError(f"Checkpointing method {method} is unknown.") - - def forward(self, x: torch.Tensor, in_attn_src: torch.Tensor, *args, **kwargs): - B, T, C = x.shape - if in_attn_src is not None: - _, in_attn_t, _ = in_attn_src.shape - - if 'offsets' in self._streaming_state: - offsets = self._streaming_state['offsets'] - else: - offsets = torch.zeros(B, dtype=torch.long, device=x.device) - - if self.positional_embedding in ['sin', 'sin_rope']: - positions = torch.arange(T, device=x.device).view(1, -1, 1) - positions = positions + offsets.view(-1, 1, 1) - pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype) - x = x + self.positional_scale * pos_emb - - for idx, layer in enumerate(self.layers): - if (idx % 4 == 0) and (idx < 36) and (idx != 0): - if in_attn_src is not None: - x[:, -in_attn_t:, :] += in_attn_src - x = self._apply_layer(layer, x, *args, **kwargs) - - if self._is_streaming: - self._streaming_state['offsets'] = offsets + T - - return x - - def make_optim_group(self): - group = {"params": list(self.parameters())} - if self.lr is not None: - group["lr"] = self.lr - if self.weight_decay is not None: - group["weight_decay"] = self.weight_decay - return group - - -# special attention related function - -def _verify_xformers_memory_efficient_compat(): - try: - from xformers.ops import memory_efficient_attention, LowerTriangularMask # noqa - except ImportError: - raise ImportError( - "xformers is not installed. Please install it and try again.\n" - "To install on AWS and Azure, run \n" - "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n" - "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n" - "To install on FAIR Cluster, run \n" - "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n" - "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") - - -def _verify_xformers_internal_compat(): - try: - from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy # noqa - except ImportError: - raise ImportError( - "Francisco's fairinternal xformers is not installed. 
Please install it and try again.\n" - "To install on AWS and Azure, run \n" - "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n" - "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n" - "To install on FAIR Cluster, run \n" - "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n" - "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") - - -def _is_custom(custom: bool, memory_efficient: bool): - return custom or memory_efficient diff --git a/audiocraft/audiocraft/optim/__init__.py b/audiocraft/audiocraft/optim/__init__.py deleted file mode 100644 index f48c17dfafa9a2be46a91ed1fb64f54c5572a730..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/optim/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -"""Optimization stuff. In particular, optimizers (DAdaptAdam), schedulers -and Exponential Moving Average. -""" - -# flake8: noqa -from .cosine_lr_scheduler import CosineLRScheduler -from .dadam import DAdaptAdam -from .inverse_sqrt_lr_scheduler import InverseSquareRootLRScheduler -from .linear_warmup_lr_scheduler import LinearWarmupLRScheduler -from .polynomial_decay_lr_scheduler import PolynomialDecayLRScheduler -from .ema import ModuleDictEMA diff --git a/audiocraft/audiocraft/optim/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/optim/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index ea7df3e581582945cdb2f4c3f929e32dadf8a213..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/optim/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/optim/__pycache__/cosine_lr_scheduler.cpython-311.pyc b/audiocraft/audiocraft/optim/__pycache__/cosine_lr_scheduler.cpython-311.pyc deleted file mode 100644 index d5097cf28e02745c4a6d4b847e74e74cd544af4b..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/optim/__pycache__/cosine_lr_scheduler.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/optim/__pycache__/dadam.cpython-311.pyc b/audiocraft/audiocraft/optim/__pycache__/dadam.cpython-311.pyc deleted file mode 100644 index 7f74a9afebdecda78edf2983f0c37e2dd8398b65..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/optim/__pycache__/dadam.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/optim/__pycache__/ema.cpython-311.pyc b/audiocraft/audiocraft/optim/__pycache__/ema.cpython-311.pyc deleted file mode 100644 index 2eaca706c5f31558985e0f3c095393c7a746df4a..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/optim/__pycache__/ema.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/optim/__pycache__/fsdp.cpython-311.pyc b/audiocraft/audiocraft/optim/__pycache__/fsdp.cpython-311.pyc deleted file mode 100644 index 27f6b9f5076074bcb30c8296fc22cff8c6c09f99..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/optim/__pycache__/fsdp.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/optim/__pycache__/inverse_sqrt_lr_scheduler.cpython-311.pyc b/audiocraft/audiocraft/optim/__pycache__/inverse_sqrt_lr_scheduler.cpython-311.pyc deleted file mode 100644 index 646ae383db386d84dd5370e0eb98633c00d7e63a..0000000000000000000000000000000000000000 Binary files 
a/audiocraft/audiocraft/optim/__pycache__/inverse_sqrt_lr_scheduler.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/optim/__pycache__/linear_warmup_lr_scheduler.cpython-311.pyc b/audiocraft/audiocraft/optim/__pycache__/linear_warmup_lr_scheduler.cpython-311.pyc deleted file mode 100644 index 76489594d3db29881ec02557a42b99a66f19c6c1..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/optim/__pycache__/linear_warmup_lr_scheduler.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/optim/__pycache__/polynomial_decay_lr_scheduler.cpython-311.pyc b/audiocraft/audiocraft/optim/__pycache__/polynomial_decay_lr_scheduler.cpython-311.pyc deleted file mode 100644 index 95055259f9bf9cd6bb2274ec11276d004afd422c..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/optim/__pycache__/polynomial_decay_lr_scheduler.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/optim/cosine_lr_scheduler.py b/audiocraft/audiocraft/optim/cosine_lr_scheduler.py deleted file mode 100644 index 1e4f0bbf28f1ad893a301f1bfac1da8e97370337..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/optim/cosine_lr_scheduler.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math - -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler - - -class CosineLRScheduler(_LRScheduler): - """Cosine LR scheduler. - - Args: - optimizer (Optimizer): Torch optimizer. - warmup_steps (int): Number of warmup steps. - total_steps (int): Total number of steps. - lr_min_ratio (float): Minimum learning rate. - cycle_length (float): Cycle length. - """ - def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int, - lr_min_ratio: float = 0.0, cycle_length: float = 1.0): - self.warmup_steps = warmup_steps - assert self.warmup_steps >= 0 - self.total_steps = total_steps - assert self.total_steps >= 0 - self.lr_min_ratio = lr_min_ratio - self.cycle_length = cycle_length - super().__init__(optimizer) - - def _get_sched_lr(self, lr: float, step: int): - if step < self.warmup_steps: - lr_ratio = step / self.warmup_steps - lr = lr_ratio * lr - elif step <= self.total_steps: - s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps) - lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \ - (1. + math.cos(math.pi * s / self.cycle_length)) - lr = lr_ratio * lr - else: - lr_ratio = self.lr_min_ratio - lr = lr_ratio * lr - return lr - - def get_lr(self): - return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs] diff --git a/audiocraft/audiocraft/optim/dadam.py b/audiocraft/audiocraft/optim/dadam.py deleted file mode 100644 index a84402f744867610180b9576b2ee3302501fd035..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/optim/dadam.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
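For orientation, a minimal usage sketch of the `CosineLRScheduler` shown above, assuming a standard PyTorch optimizer and stepping the scheduler once per training iteration; the model and the loop are illustrative placeholders, not part of this diff:

    import torch
    from audiocraft.optim import CosineLRScheduler  # exported by the package __init__ above

    model = torch.nn.Linear(16, 16)  # placeholder model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = CosineLRScheduler(optimizer, total_steps=10_000, warmup_steps=500, lr_min_ratio=0.1)
    for step in range(10_000):
        # ... usual forward / backward here ...
        optimizer.step()
        scheduler.step()  # lr warms up linearly, then follows a cosine decay towards lr_min_ratio * base_lr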
- -import logging -from typing import TYPE_CHECKING, Any - -import torch -import torch.optim -import torch.distributed as dist - -if TYPE_CHECKING: - from torch.optim.optimizer import _params_t -else: - _params_t = Any - - -logger = logging.getLogger(__name__) - - -def to_real(x): - if torch.is_complex(x): - return x.real - else: - return x - - -class DAdaptAdam(torch.optim.Optimizer): - """Adam with D-Adaptation automatic step-sizes. - Leave LR set to 1 unless you encounter instability. - - Args: - params (iterable): - Iterable of parameters to optimize or dicts defining parameter groups. - lr (float): - Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate. - betas (tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - momentum (float): - Momentum value in the range [0,1) (default: 0.9). - eps (float): - Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). - weight_decay (float): - Weight decay, i.e. a L2 penalty (default: 0). - log_every (int): - Log using print every k steps, default 0 (no logging). - decouple (boolean): - Use AdamW style decoupled weight decay - d0 (float): - Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. - growth_rate (float): - prevent the D estimate from growing faster than this multiplicative rate. - Default is inf, for unrestricted. Values like 1.02 give a kind of learning - rate warmup effect. - fsdp_in_use (bool): - If you're using sharded parameters, this should be set to True. The optimizer - will attempt to auto-detect this, but if you're using an implementation other - than PyTorch's builtin version, the auto-detection won't work. - """ - def __init__(self, params, lr=1.0, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - log_every=0, - decouple=True, - d0=1e-6, - growth_rate=float('inf')): - if not 0.0 < d0: - raise ValueError("Invalid d0 value: {}".format(d0)) - if not 0.0 < lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 < eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - - if decouple: - logger.info("Using decoupled weight decay") - - from .fsdp import is_fsdp_used - fsdp_in_use = is_fsdp_used() - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, - d=d0, - k=0, - gsq_weighted=0.0, - log_every=log_every, - decouple=decouple, - growth_rate=growth_rate, - fsdp_in_use=fsdp_in_use) - - super().__init__(params, defaults) - - @property - def supports_memory_efficient_fp16(self): - return False - - @property - def supports_flat_params(self): - return True - - def step(self, closure=None): - """Performs a single optimization step. - - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - loss = None - if closure is not None: - loss = closure() - - g_sq = 0.0 - sksq_weighted = 0.0 - sk_l1 = 0.0 - - lr = max(group['lr'] for group in self.param_groups) - - group = self.param_groups[0] - gsq_weighted = group['gsq_weighted'] - d = group['d'] - dlr = d*lr - - growth_rate = group['growth_rate'] - decouple = group['decouple'] - fsdp_in_use = group['fsdp_in_use'] - log_every = group['log_every'] - - beta1, beta2 = group['betas'] - - for group in self.param_groups: - group_lr = group['lr'] - decay = group['weight_decay'] - k = group['k'] - eps = group['eps'] - - if group_lr not in [lr, 0.0]: - raise RuntimeError("Setting different lr values in different parameter " - "groups is only supported for values of 0") - - for p in group['params']: - if p.grad is None: - continue - if hasattr(p, "_fsdp_flattened"): - fsdp_in_use = True - grad = p.grad.data - - # Apply weight decay (coupled variant) - if decay != 0 and not decouple: - grad.add_(p.data, alpha=decay) - - state = self.state[p] - - # State initialization - if 'step' not in state: - state['step'] = 0 - state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach() - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach() - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like( - to_real(p.data), memory_format=torch.preserve_format).detach() - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - - grad_grad = to_real(grad * grad.conj()) - - # Adam EMA updates - if group_lr > 0: - exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1-beta1)) - exp_avg_sq.mul_(beta2).add_(grad_grad, alpha=1-beta2) - - denom = exp_avg_sq.sqrt().add_(eps) - - g_sq += grad_grad.div_(denom).sum().item() - - s = state['s'] - s.mul_(beta2).add_(grad, alpha=dlr*(1-beta2)) - sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item() - sk_l1 += s.abs().sum().item() - - ###### - - gsq_weighted = beta2*gsq_weighted + g_sq*(dlr**2)*(1-beta2) - d_hat = d - - # if we have not done any progres, return - # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0) - if sk_l1 == 0: - return loss - - if lr > 0.0: - if fsdp_in_use: - dist_tensor = torch.zeros(3, device='cuda') - dist_tensor[0] = sksq_weighted - dist_tensor[1] = gsq_weighted - dist_tensor[2] = sk_l1 - dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) - global_sksq_weighted = dist_tensor[0] - global_gsq_weighted = dist_tensor[1] - global_sk_l1 = dist_tensor[2] - else: - global_sksq_weighted = sksq_weighted - global_gsq_weighted = gsq_weighted - global_sk_l1 = sk_l1 - - d_hat = (global_sksq_weighted/(1-beta2) - global_gsq_weighted)/global_sk_l1 - d = max(d, min(d_hat, d*growth_rate)) - - if log_every > 0 and k % log_every == 0: - logger.info( - f"(k={k}) dlr: {dlr:1.1e} d_hat: {d_hat:1.1e}, d: {d:1.8}. 
" - f"sksq_weighted={global_sksq_weighted:1.1e} gsq_weighted={global_gsq_weighted:1.1e} " - f"sk_l1={global_sk_l1:1.1e}{' (FSDP)' if fsdp_in_use else ''}") - - for group in self.param_groups: - group['gsq_weighted'] = gsq_weighted - group['d'] = d - - group_lr = group['lr'] - decay = group['weight_decay'] - k = group['k'] - eps = group['eps'] - - for p in group['params']: - if p.grad is None: - continue - grad = p.grad.data - - state = self.state[p] - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - - state['step'] += 1 - - denom = exp_avg_sq.sqrt().add_(eps) - denom = denom.type(p.type()) - - # Apply weight decay (decoupled variant) - if decay != 0 and decouple and group_lr > 0: - p.data.add_(p.data, alpha=-decay * dlr) - - # Take step - p.data.addcdiv_(exp_avg, denom, value=-1) - - group['k'] = k + 1 - - return loss diff --git a/audiocraft/audiocraft/optim/ema.py b/audiocraft/audiocraft/optim/ema.py deleted file mode 100644 index 4337eaff066a8ca124dca3e3e63ee36e417c055c..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/optim/ema.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# ModelEMA implementation is taken from -# https://github.com/facebookresearch/demucs - -from collections import defaultdict -import typing as tp - -import torch -import torch.nn as nn - - -def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set: - names: set = set() - for (name, sub_module) in module.named_modules(): - if name == '': - buffer_names = module._non_persistent_buffers_set - buffer_names = {f"{root}.{buff_name}" if len(root) > 0 else buff_name - for buff_name in buffer_names} - names.update(buffer_names) - else: - sub_name = f"{root}.{name}" if len(root) > 0 else name - sub_buffer_names = _get_all_non_persistent_buffers_set(sub_module, sub_name) - names.update(sub_buffer_names) - return names - - -def _get_named_tensors(module: nn.Module): - non_persistent_buffers_set = _get_all_non_persistent_buffers_set(module) - named_buffers = [(name, buffer) for (name, buffer) in module.named_buffers() - if name not in non_persistent_buffers_set] - named_parameters = list(module.named_parameters()) - return named_parameters + named_buffers - - -class ModuleDictEMA: - """Exponential Moving Average over a nn.ModuleDict. - - You can switch to the EMA weights temporarily. 
- """ - def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999, - unbias: bool = True, device: tp.Union[torch.device, str] = 'cpu'): - self.decay = decay - self.module_dict = module_dict - self.state: dict = defaultdict(dict) - self.count = 0 - self.device = device - self.unbias = unbias - self._init() - - def _init(self): - for module_name, module in self.module_dict.items(): - for key, val in _get_named_tensors(module): - if not val.is_floating_point(): - continue - device = self.device or val.device - if key not in self.state[module_name]: - self.state[module_name][key] = val.detach().to(device, copy=True) - - def step(self): - if self.unbias: - self.count = self.count * self.decay + 1 - w = 1 / self.count - else: - w = 1 - self.decay - for module_name, module in self.module_dict.items(): - for key, val in _get_named_tensors(module): - if not val.is_floating_point(): - continue - device = self.device or val.device - self.state[module_name][key].mul_(1 - w) - self.state[module_name][key].add_(val.detach().to(device), alpha=w) - - def state_dict(self): - return {'state': self.state, 'count': self.count} - - def load_state_dict(self, state): - self.count = state['count'] - for module_name, module in state['state'].items(): - for key, val in module.items(): - self.state[module_name][key].copy_(val) diff --git a/audiocraft/audiocraft/optim/fsdp.py b/audiocraft/audiocraft/optim/fsdp.py deleted file mode 100644 index b3c1a55b6bf1a33092a021c5cefbbb2ae848918a..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/optim/fsdp.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Wrapper around FSDP for more convenient use in the training loops. -""" - -from contextlib import contextmanager -import typing as tp -import dora -import torch - -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ( - MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType) -from torch.distributed._shard.sharded_tensor.api import ShardedTensor - - -def is_fsdp_used() -> bool: - """Return whether we are using FSDP.""" - # A bit of a hack but should work from anywhere. - if dora.is_xp(): - cfg = dora.get_xp().cfg - if hasattr(cfg, 'fsdp'): - return cfg.fsdp.use - return False - - -def is_sharded_tensor(x: tp.Any) -> bool: - return isinstance(x, ShardedTensor) - - -@contextmanager -def switch_to_full_state_dict(models: tp.List[FSDP]): - # Another bug in FSDP makes it that we cannot use the `state_dict_type` API, - # so let's do thing manually. - for model in models: - FSDP.set_state_dict_type( # type: ignore - model, StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True)) - try: - yield - finally: - for model in models: - FSDP.set_state_dict_type(model, StateDictType.LOCAL_STATE_DICT) # type: ignore - - -def wrap_with_fsdp(cfg, model: torch.nn.Module, - block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP: - """Wraps a model with FSDP.""" - # Some of the typing is disabled until this gets integrated - # into the stable version of PyTorch. - from torch.distributed.fsdp.wrap import ModuleWrapPolicy # type: ignore - - # we import this here to prevent circular import. 
- from ..modules.transformer import StreamingTransformerLayer - from ..modules.conditioners import ConditioningProvider - - _fix_post_backward_hook() - - assert cfg.use - sharding_strategy_dict = { - "no_shard": ShardingStrategy.NO_SHARD, - "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP, - "full_shard": ShardingStrategy.FULL_SHARD, - } - - dtype_dict = { - "float32": torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - } - - mixed_precision_config = MixedPrecision( - param_dtype=dtype_dict[cfg.param_dtype], - reduce_dtype=dtype_dict[cfg.reduce_dtype], - buffer_dtype=dtype_dict[cfg.buffer_dtype], - ) - - sharding_strategy_config = sharding_strategy_dict[cfg.sharding_strategy] - # The following is going to require being a bit smart - # when doing LM, because this would flush the weights for every time step - # during generation. One possiblity is to use hybrid sharding: - # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy - assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \ - "Not supported at the moment, requires a bit more work." - - local_rank = dora.distrib.get_distrib_spec().local_rank - assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!" - - auto_wrap_policy = None - if block_classes is None: - block_classes = {StreamingTransformerLayer, ConditioningProvider} - if cfg.per_block: - auto_wrap_policy = ModuleWrapPolicy(block_classes) - wrapped = _FSDPFixStateDict( - model, - sharding_strategy=sharding_strategy_config, - mixed_precision=mixed_precision_config, - device_id=local_rank, - sync_module_states=True, - use_orig_params=True, - auto_wrap_policy=auto_wrap_policy, - ) # type: ignore - FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT) # type: ignore - - # Let the wrapped model know about the wrapping! - # We use __dict__ to avoid it going into the state dict. - # This is a bit dirty, but needed during generation, as otherwise - # the wrapped model would call itself and bypass FSDP. - for module in FSDP.fsdp_modules(wrapped): - original = module._fsdp_wrapped_module - original.__dict__['_fsdp'] = module - return wrapped - - -def purge_fsdp(model: FSDP): - """Purge the FSDP cached shard inside the model. This should - allow setting the best state or switching to the EMA. - """ - from torch.distributed.fsdp._runtime_utils import _reshard # type: ignore - for module in FSDP.fsdp_modules(model): - handles = module._handles - if not handles: - continue - handle = handles[0] - unsharded_flat_param = handle._get_padded_unsharded_flat_param() - storage_size: int = unsharded_flat_param._typed_storage()._size() # type: ignore - if storage_size == 0: - continue - true_list = [True for h in handles] - _reshard(module, handles, true_list) - - -class _FSDPFixStateDict(FSDP): - @staticmethod - def _name_without_fsdp_prefix(name: str) -> str: - from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE # type: ignore - parts = name.split('.') - new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE] - return '.'.join(new_parts) - - def state_dict(self) -> tp.Dict[str, tp.Any]: # type: ignore - state = dict(super().state_dict()) - for key, value in list(state.items()): - if is_sharded_tensor(value): - del state[key] - return state - - def load_state_dict(self, state: tp.Dict[str, tp.Any]): # type: ignore - if self._state_dict_type is StateDictType.FULL_STATE_DICT: - super().load_state_dict(state) - purge_fsdp(self) - return - # Fix FSDP load state dict in all situation. 
- # Use this only with LOCAL_STATE_DICT !!! - current_state = dict(super().state_dict()) - for key, value in state.items(): - key = _FSDPFixStateDict._name_without_fsdp_prefix(key) - if key not in current_state: - # Emulate strict loading manually. - raise RuntimeError(f"Unknown state key {key}") - current_state[key].copy_(value) - - # Purging cached weights from previous forward. - purge_fsdp(self) - - -_hook_fixed = False - - -def _fix_post_backward_hook(): - global _hook_fixed - if _hook_fixed: - return - _hook_fixed = True - - from torch.distributed.fsdp import _runtime_utils - from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState - old_hook = _runtime_utils._post_backward_hook - - def _post_backward_hook(state, handle, *args, **kwargs): - checkpointed = getattr(state._fsdp_wrapped_module, '_audiocraft_checkpointed', False) - if checkpointed: - # there will be one more forward in the backward with checkpointing and that will - # massively confuse FSDP, so we have to make it think everything - # is going according to the plan. - state.training_state = TrainingState.FORWARD_BACKWARD - handle._training_state = HandleTrainingState.BACKWARD_PRE - old_hook(state, handle, *args, **kwargs) - - _runtime_utils._post_backward_hook = _post_backward_hook diff --git a/audiocraft/audiocraft/optim/inverse_sqrt_lr_scheduler.py b/audiocraft/audiocraft/optim/inverse_sqrt_lr_scheduler.py deleted file mode 100644 index 920192e8842c5635bf6f7f76618fa4a6f4b0114a..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/optim/inverse_sqrt_lr_scheduler.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import typing as tp - -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler - - -class InverseSquareRootLRScheduler(_LRScheduler): - """Inverse square root LR scheduler. - - Args: - optimizer (Optimizer): Torch optimizer. - warmup_steps (int): Number of warmup steps. - warmup_init_lr (tp.Optional[float]): Initial learning rate - during warmup phase. When not set, use the provided learning rate. - """ - def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0): - self.warmup_steps = warmup_steps - self.warmup_init_lr = warmup_init_lr - super().__init__(optimizer) - - def _get_sched_lr(self, lr: float, step: int): - if step < self.warmup_steps: - warmup_init_lr = self.warmup_init_lr or 0 - lr_step = (lr - warmup_init_lr) / self.warmup_steps - lr = warmup_init_lr + step * lr_step - else: - decay_factor = lr * self.warmup_steps**0.5 - lr = decay_factor * step**-0.5 - return lr - - def get_lr(self): - return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs] diff --git a/audiocraft/audiocraft/optim/linear_warmup_lr_scheduler.py b/audiocraft/audiocraft/optim/linear_warmup_lr_scheduler.py deleted file mode 100644 index 03274a1ae52b6f20473973b77619f34b2bddd6a1..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/optim/linear_warmup_lr_scheduler.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
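A quick numerical check of the `InverseSquareRootLRScheduler` shown above (the numbers are illustrative, not taken from the diff): after warmup, the schedule reduces to base_lr * sqrt(warmup_steps / step), so with a base learning rate of 1e-3 and warmup_steps=100 the rate is 1e-3 at step 100, 5e-4 at step 400 and 2.5e-4 at step 1600:

    import math

    base_lr, warmup_steps = 1e-3, 100
    decay_factor = base_lr * warmup_steps ** 0.5   # matches the scheduler's post-warmup branch
    for step in (100, 400, 1600):
        print(step, decay_factor * step ** -0.5)   # 1e-3, 5e-4, 2.5e-4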
- -import typing as tp - -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler - - -class LinearWarmupLRScheduler(_LRScheduler): - """Linear warmup LR scheduler. - - Args: - optimizer (Optimizer): Torch optimizer. - warmup_steps (int): Number of warmup steps. - warmup_init_lr (tp.Optional[float]): Initial learning rate - during warmup phase. When not set, use the provided learning rate. - """ - def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0): - self.warmup_steps = warmup_steps - self.warmup_init_lr = warmup_init_lr - super().__init__(optimizer) - - def _get_sched_lr(self, lr: float, step: int): - if step < self.warmup_steps: - warmup_init_lr = self.warmup_init_lr or 0 - lr_step = (lr - warmup_init_lr) / self.warmup_steps - lr = warmup_init_lr + step * lr_step - return lr - - def get_lr(self): - return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs] diff --git a/audiocraft/audiocraft/optim/polynomial_decay_lr_scheduler.py b/audiocraft/audiocraft/optim/polynomial_decay_lr_scheduler.py deleted file mode 100644 index c5ea30b094538269dbb0055ab3163f84d1cf6e90..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/optim/polynomial_decay_lr_scheduler.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler - - -class PolynomialDecayLRScheduler(_LRScheduler): - """Polynomial decay LR scheduler. - - Args: - optimizer (Optimizer): Torch optimizer. - warmup_steps (int): Number of warmup steps. - total_steps (int): Total number of steps. - end_lr (float): Final learning rate to achieve over total number of steps. - zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0. - power (float): Decay exponent.
- """ - def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int, - end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.): - self.warmup_steps = warmup_steps - self.total_steps = total_steps - self.end_lr = end_lr - self.zero_lr_warmup_steps = zero_lr_warmup_steps - self.power = power - super().__init__(optimizer) - - def _get_sched_lr(self, lr: float, step: int): - if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps: - lr = 0 - elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps: - lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps) - lr = lr_ratio * lr - elif step >= self.total_steps: - lr = self.end_lr - else: - total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps - lr_range = lr - self.end_lr - pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps) - lr = lr_range * pct_remaining ** self.power + self.end_lr - return lr - - def get_lr(self): - return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs] diff --git a/audiocraft/audiocraft/py.typed b/audiocraft/audiocraft/py.typed deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/audiocraft/audiocraft/quantization/__init__.py b/audiocraft/audiocraft/quantization/__init__.py deleted file mode 100644 index 1e0c7e429ab96d67be667e23bf7a0ffa389c036b..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/quantization/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-"""RVQ.""" -# flake8: noqa -from .vq import ResidualVectorQuantizer -from .base import BaseQuantizer, DummyQuantizer, QuantizedResult diff --git a/audiocraft/audiocraft/quantization/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/quantization/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 3b500a65692c06df8dff32b951dbb73a634c296e..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/quantization/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/quantization/__pycache__/base.cpython-311.pyc b/audiocraft/audiocraft/quantization/__pycache__/base.cpython-311.pyc deleted file mode 100644 index 2874f6d0c656d2eacf7d4d2115ae612114bfc167..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/quantization/__pycache__/base.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/quantization/__pycache__/core_vq.cpython-311.pyc b/audiocraft/audiocraft/quantization/__pycache__/core_vq.cpython-311.pyc deleted file mode 100644 index 086aaecbbebbd0d2ab05046e77f7e933023f5b76..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/quantization/__pycache__/core_vq.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/quantization/__pycache__/vq.cpython-311.pyc b/audiocraft/audiocraft/quantization/__pycache__/vq.cpython-311.pyc deleted file mode 100644 index dc8eeeb7fc1b12b54d614b49bb309b85e7de640e..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/quantization/__pycache__/vq.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/quantization/base.py b/audiocraft/audiocraft/quantization/base.py deleted file mode 100644 index a77fefb98e62a5bbc6385910261ffdde2ffa5a25..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/quantization/base.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Base class for all quantizers. -""" - -from dataclasses import dataclass, field -import typing as tp - -import torch -from torch import nn - - -@dataclass -class QuantizedResult: - x: torch.Tensor - codes: torch.Tensor - bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. - penalty: tp.Optional[torch.Tensor] = None - metrics: dict = field(default_factory=dict) - - -class BaseQuantizer(nn.Module): - """Base class for quantizers. - """ - - def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult: - """ - Given input tensor x, returns first the quantized (or approximately quantized) - representation along with quantized codes, bandwidth, and any penalty term for the loss. - Finally, this returns a dict of metrics to update logging etc. - Frame rate must be passed so that the bandwidth is properly computed. 
- """ - raise NotImplementedError() - - def encode(self, x: torch.Tensor) -> torch.Tensor: - """Encode a given input tensor with the specified sample rate at the given bandwidth.""" - raise NotImplementedError() - - def decode(self, codes: torch.Tensor) -> torch.Tensor: - """Decode the given codes to the quantized representation.""" - raise NotImplementedError() - - @property - def total_codebooks(self): - """Total number of codebooks.""" - raise NotImplementedError() - - @property - def num_codebooks(self): - """Number of active codebooks.""" - raise NotImplementedError() - - def set_num_codebooks(self, n: int): - """Set the number of active codebooks.""" - raise NotImplementedError() - - -class DummyQuantizer(BaseQuantizer): - """Fake quantizer that actually does not perform any quantization. - """ - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor, frame_rate: int): - q = x.unsqueeze(1) - return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)) - - def encode(self, x: torch.Tensor) -> torch.Tensor: - """Encode a given input tensor with the specified sample rate at the given bandwidth. - In the case of the DummyQuantizer, the codes are actually identical - to the input and resulting quantized representation as no quantization is done. - """ - return x.unsqueeze(1) - - def decode(self, codes: torch.Tensor) -> torch.Tensor: - """Decode the given codes to the quantized representation. - In the case of the DummyQuantizer, the codes are actually identical - to the input and resulting quantized representation as no quantization is done. - """ - return codes.squeeze(1) - - @property - def total_codebooks(self): - """Total number of codebooks.""" - return 1 - - @property - def num_codebooks(self): - """Total number of codebooks.""" - return self.total_codebooks - - def set_num_codebooks(self, n: int): - """Set the number of active codebooks.""" - raise AttributeError("Cannot override the number of codebooks for the dummy quantizer") diff --git a/audiocraft/audiocraft/quantization/core_vq.py b/audiocraft/audiocraft/quantization/core_vq.py deleted file mode 100644 index da02a6ce3a7de15353f0fba9e826052beb67c436..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/quantization/core_vq.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import typing as tp - -from einops import rearrange, repeat -import flashy -import torch -from torch import nn, einsum -import torch.nn.functional as F - - -def exists(val: tp.Optional[tp.Any]) -> bool: - return val is not None - - -def default(val: tp.Any, d: tp.Any) -> tp.Any: - return val if exists(val) else d - - -def l2norm(t): - return F.normalize(t, p=2, dim=-1) - - -def ema_inplace(moving_avg, new, decay: float): - moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) - - -def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): - return (x + epsilon) / (x.sum() + n_categories * epsilon) - - -def uniform_init(*shape: int): - t = torch.empty(shape) - nn.init.kaiming_uniform_(t) - return t - - -def sample_vectors(samples, num: int): - num_samples, device = samples.shape[0], samples.device - - if num_samples >= num: - indices = torch.randperm(num_samples, device=device)[:num] - else: - indices = torch.randint(0, num_samples, (num,), device=device) - - return samples[indices] - - -def kmeans(samples, num_clusters: int, num_iters: int = 10): - dim, dtype = samples.shape[-1], samples.dtype - - means = sample_vectors(samples, num_clusters) - - for _ in range(num_iters): - diffs = rearrange(samples, "n d -> n () d") - rearrange( - means, "c d -> () c d" - ) - dists = -(diffs ** 2).sum(dim=-1) - - buckets = dists.max(dim=-1).indices - bins = torch.bincount(buckets, minlength=num_clusters) - zero_mask = bins == 0 - bins_min_clamped = bins.masked_fill(zero_mask, 1) - - new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) - new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) - new_means = new_means / bins_min_clamped[..., None] - - means = torch.where(zero_mask[..., None], means, new_means) - - return means, bins - - -def orthogonal_loss_fn(t): - # eq (2) from https://arxiv.org/abs/2112.00384 - n = t.shape[0] - normed_codes = l2norm(t) - identity = torch.eye(n, device=t.device) - cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes) - return ((cosine_sim - identity) ** 2).sum() / (n ** 2) - - -class EuclideanCodebook(nn.Module): - """Codebook with Euclidean distance. - - Args: - dim (int): Dimension. - codebook_size (int): Codebook size. - kmeans_init (bool): Whether to use k-means to initialize the codebooks. - If set to true, run the k-means algorithm on the first training batch and use - the learned centroids as initialization. - kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. 
- """ - def __init__( - self, - dim: int, - codebook_size: int, - kmeans_init: int = False, - kmeans_iters: int = 10, - decay: float = 0.8, - epsilon: float = 1e-5, - threshold_ema_dead_code: int = 2, - ): - super().__init__() - self.decay = decay - init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros - embed = init_fn(codebook_size, dim) - - self.codebook_size = codebook_size - - self.kmeans_iters = kmeans_iters - self.epsilon = epsilon - self.threshold_ema_dead_code = threshold_ema_dead_code - - self.register_buffer("inited", torch.Tensor([not kmeans_init])) - self.register_buffer("cluster_size", torch.zeros(codebook_size)) - self.register_buffer("embed", embed) - self.register_buffer("embed_avg", embed.clone()) - - @torch.jit.ignore - def init_embed_(self, data): - if self.inited: - return - - embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) - self.embed.data.copy_(embed) - self.embed_avg.data.copy_(embed.clone()) - self.cluster_size.data.copy_(cluster_size) - self.inited.data.copy_(torch.Tensor([True])) - # Make sure all buffers across workers are in sync after initialization - flashy.distrib.broadcast_tensors(self.buffers()) - - def replace_(self, samples, mask): - modified_codebook = torch.where( - mask[..., None], sample_vectors(samples, self.codebook_size), self.embed - ) - self.embed.data.copy_(modified_codebook) - - def expire_codes_(self, batch_samples): - if self.threshold_ema_dead_code == 0: - return - - expired_codes = self.cluster_size < self.threshold_ema_dead_code - if not torch.any(expired_codes): - return - - batch_samples = rearrange(batch_samples, "... d -> (...) d") - self.replace_(batch_samples, mask=expired_codes) - flashy.distrib.broadcast_tensors(self.buffers()) - - def preprocess(self, x): - x = rearrange(x, "... d -> (...) d") - return x - - def quantize(self, x): - embed = self.embed.t() - dist = -( - x.pow(2).sum(1, keepdim=True) - - 2 * x @ embed - + embed.pow(2).sum(0, keepdim=True) - ) - embed_ind = dist.max(dim=-1).indices - return embed_ind - - def postprocess_emb(self, embed_ind, shape): - return embed_ind.view(*shape[:-1]) - - def dequantize(self, embed_ind): - quantize = F.embedding(embed_ind, self.embed) - return quantize - - def encode(self, x): - shape = x.shape - # pre-process - x = self.preprocess(x) - # quantize - embed_ind = self.quantize(x) - # post-process - embed_ind = self.postprocess_emb(embed_ind, shape) - return embed_ind - - def decode(self, embed_ind): - quantize = self.dequantize(embed_ind) - return quantize - - def forward(self, x): - shape, dtype = x.shape, x.dtype - x = self.preprocess(x) - self.init_embed_(x) - - embed_ind = self.quantize(x) - embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) - embed_ind = self.postprocess_emb(embed_ind, shape) - quantize = self.dequantize(embed_ind) - - if self.training: - # We do the expiry of code at that point as buffers are in sync - # and all the workers will take the same decision. 
- self.expire_codes_(x) - ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) - embed_sum = x.t() @ embed_onehot - ema_inplace(self.embed_avg, embed_sum.t(), self.decay) - cluster_size = ( - laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) - * self.cluster_size.sum() - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) - self.embed.data.copy_(embed_normalized) - - return quantize, embed_ind - - -class VectorQuantization(nn.Module): - """Vector quantization implementation. - Currently supports only euclidean distance. - - Args: - dim (int): Dimension - codebook_size (int): Codebook size - codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - kmeans_init (bool): Whether to use kmeans to initialize the codebooks. - kmeans_iters (int): Number of iterations used for kmeans initialization. - threshold_ema_dead_code (int): - channels_last (bool): Channels are the last dimension in the input tensors. - commitment_weight (float): Weight for commitment loss. - orthogonal_reg_weight (float): Orthogonal regularization weights. - orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. - orthogonal_reg_max_codes (optional int): Maximum number of codes to consider - for orthogonal regularization. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - """ - def __init__( - self, - dim: int, - codebook_size: int, - codebook_dim: tp.Optional[int] = None, - decay: float = 0.8, - epsilon: float = 1e-5, - kmeans_init: bool = False, - kmeans_iters: int = 10, - threshold_ema_dead_code: int = 2, - channels_last: bool = False, - commitment_weight: float = 1., - orthogonal_reg_weight: float = 0.0, - orthogonal_reg_active_codes_only: bool = False, - orthogonal_reg_max_codes: tp.Optional[int] = None, - ): - super().__init__() - _codebook_dim: int = default(codebook_dim, dim) - - requires_projection = _codebook_dim != dim - self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) - self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) - - self.epsilon = epsilon - self.commitment_weight = commitment_weight - - self.orthogonal_reg_weight = orthogonal_reg_weight - self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only - self.orthogonal_reg_max_codes = orthogonal_reg_max_codes - - self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, - kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, - decay=decay, epsilon=epsilon, - threshold_ema_dead_code=threshold_ema_dead_code) - self.codebook_size = codebook_size - - self.channels_last = channels_last - - @property - def codebook(self): - return self._codebook.embed - - @property - def inited(self): - return self._codebook.inited - - def _preprocess(self, x): - if not self.channels_last: - x = rearrange(x, "b d n -> b n d") - return x - - def _postprocess(self, quantize): - if not self.channels_last: - quantize = rearrange(quantize, "b n d -> b d n") - return quantize - - def encode(self, x): - x = self._preprocess(x) - x = self.project_in(x) - embed_in = self._codebook.encode(x) - return embed_in - - def decode(self, embed_ind): 
- quantize = self._codebook.decode(embed_ind) - quantize = self.project_out(quantize) - quantize = self._postprocess(quantize) - return quantize - - def forward(self, x): - device = x.device - x = self._preprocess(x) - - x = self.project_in(x) - quantize, embed_ind = self._codebook(x) - - if self.training: - quantize = x + (quantize - x).detach() - - loss = torch.tensor([0.0], device=device, requires_grad=self.training) - - if self.training: - if self.commitment_weight > 0: - commit_loss = F.mse_loss(quantize.detach(), x) - loss = loss + commit_loss * self.commitment_weight - - if self.orthogonal_reg_weight > 0: - codebook = self.codebook - - if self.orthogonal_reg_active_codes_only: - # only calculate orthogonal loss for the activated codes for this batch - unique_code_ids = torch.unique(embed_ind) - codebook = codebook[unique_code_ids] - - num_codes = codebook.shape[0] - if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: - rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes] - codebook = codebook[rand_ids] - - orthogonal_reg_loss = orthogonal_loss_fn(codebook) - loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight - - quantize = self.project_out(quantize) - quantize = self._postprocess(quantize) - - return quantize, embed_ind, loss - - -class ResidualVectorQuantization(nn.Module): - """Residual vector quantization implementation. - - Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf - """ - def __init__(self, *, num_quantizers, **kwargs): - super().__init__() - self.layers = nn.ModuleList( - [VectorQuantization(**kwargs) for _ in range(num_quantizers)] - ) - - def forward(self, x, n_q: tp.Optional[int] = None): - quantized_out = 0.0 - residual = x - - all_losses = [] - all_indices = [] - - n_q = n_q or len(self.layers) - - for i, layer in enumerate(self.layers[:n_q]): - quantized, indices, loss = layer(residual) - residual = residual - quantized - quantized_out = quantized_out + quantized - all_indices.append(indices) - all_losses.append(loss) - - out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) - return quantized_out, out_indices, out_losses - - def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: - residual = x - all_indices = [] - n_q = n_q or len(self.layers) - for layer in self.layers[:n_q]: - indices = layer.encode(residual) - quantized = layer.decode(indices) - residual = residual - quantized - all_indices.append(indices) - out_indices = torch.stack(all_indices) - return out_indices - - def decode(self, q_indices: torch.Tensor) -> torch.Tensor: - quantized_out = torch.tensor(0.0, device=q_indices.device) - for i, indices in enumerate(q_indices): - layer = self.layers[i] - quantized = layer.decode(indices) - quantized_out = quantized_out + quantized - return quantized_out diff --git a/audiocraft/audiocraft/quantization/vq.py b/audiocraft/audiocraft/quantization/vq.py deleted file mode 100644 index aa57bea59db95ddae35e0657f723ca3a29ee943b..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/quantization/vq.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
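The residual loop implemented above in ResidualVectorQuantization (Algorithm 1 of the SoundStream paper, https://arxiv.org/pdf/2107.03312.pdf) can be restated in a few self-contained lines. The sketch below is illustrative only: toy_rvq and its plain nearest-neighbour codebooks are invented for the example and are not part of this codebase.

import torch

def toy_rvq(x, codebooks):
    """x: [N, D]; codebooks: list of [K, D] tensors. Returns (quantized, indices)."""
    quantized_out = torch.zeros_like(x)
    residual = x
    all_indices = []
    for embed in codebooks:
        dists = torch.cdist(residual, embed)   # [N, K] pairwise distances
        indices = dists.argmin(dim=-1)         # nearest codeword per vector
        quantized = embed[indices]
        residual = residual - quantized        # the next layer quantizes what is left
        quantized_out = quantized_out + quantized
        all_indices.append(indices)
    return quantized_out, torch.stack(all_indices)  # indices shaped [n_q, N]

x = torch.randn(8, 4)
q, idx = toy_rvq(x, [torch.randn(16, 4) for _ in range(3)])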
- -import math -import typing as tp - -import torch - -from .base import BaseQuantizer, QuantizedResult -from .core_vq import ResidualVectorQuantization - - -class ResidualVectorQuantizer(BaseQuantizer): - """Residual Vector Quantizer. - - Args: - dimension (int): Dimension of the codebooks. - n_q (int): Number of residual vector quantizers used. - q_dropout (bool): Random quantizer drop out at train time. - bins (int): Codebook size. - decay (float): Decay for exponential moving average over the codebooks. - kmeans_init (bool): Whether to use kmeans to initialize the codebooks. - kmeans_iters (int): Number of iterations used for kmeans initialization. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - orthogonal_reg_weight (float): Orthogonal regularization weights. - orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. - orthogonal_reg_max_codes (optional int): Maximum number of codes to consider. - for orthogonal regularization. - """ - def __init__( - self, - dimension: int = 256, - n_q: int = 8, - q_dropout: bool = False, - bins: int = 1024, - decay: float = 0.99, - kmeans_init: bool = True, - kmeans_iters: int = 10, - threshold_ema_dead_code: int = 2, - orthogonal_reg_weight: float = 0.0, - orthogonal_reg_active_codes_only: bool = False, - orthogonal_reg_max_codes: tp.Optional[int] = None, - ): - super().__init__() - self.max_n_q = n_q - self.n_q = n_q - self.q_dropout = q_dropout - self.dimension = dimension - self.bins = bins - self.decay = decay - self.kmeans_init = kmeans_init - self.kmeans_iters = kmeans_iters - self.threshold_ema_dead_code = threshold_ema_dead_code - self.orthogonal_reg_weight = orthogonal_reg_weight - self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only - self.orthogonal_reg_max_codes = orthogonal_reg_max_codes - self.vq = ResidualVectorQuantization( - dim=self.dimension, - codebook_size=self.bins, - num_quantizers=self.n_q, - decay=self.decay, - kmeans_init=self.kmeans_init, - kmeans_iters=self.kmeans_iters, - threshold_ema_dead_code=self.threshold_ema_dead_code, - orthogonal_reg_weight=self.orthogonal_reg_weight, - orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only, - orthogonal_reg_max_codes=self.orthogonal_reg_max_codes, - channels_last=False - ) - - def forward(self, x: torch.Tensor, frame_rate: int): - n_q = self.n_q - if self.training and self.q_dropout: - n_q = int(torch.randint(1, self.n_q + 1, (1,)).item()) - bw_per_q = math.log2(self.bins) * frame_rate / 1000 - quantized, codes, commit_loss = self.vq(x, n_q=n_q) - codes = codes.transpose(0, 1) - # codes is [B, K, T], with T frames, K nb of codebooks. - bw = torch.tensor(n_q * bw_per_q).to(x) - return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) - - def encode(self, x: torch.Tensor) -> torch.Tensor: - """Encode a given input tensor with the specified frame rate at the given bandwidth. - The RVQ encode method sets the appropriate number of quantizer to use - and returns indices for each quantizer. - """ - n_q = self.n_q - codes = self.vq.encode(x, n_q=n_q) - codes = codes.transpose(0, 1) - # codes is [B, K, T], with T frames, K nb of codebooks. 
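- # Note: encode always uses all self.n_q active codebooks; the random quantizer
- # dropout above (q_dropout) is only applied in forward() at training time.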
- return codes - - def decode(self, codes: torch.Tensor) -> torch.Tensor: - """Decode the given codes to the quantized representation.""" - # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T]. - codes = codes.transpose(0, 1) - quantized = self.vq.decode(codes) - return quantized - - @property - def total_codebooks(self): - return self.max_n_q - - @property - def num_codebooks(self): - return self.n_q - - def set_num_codebooks(self, n: int): - assert n > 0 and n <= self.max_n_q - self.n_q = n diff --git a/audiocraft/audiocraft/solvers/__init__.py b/audiocraft/audiocraft/solvers/__init__.py deleted file mode 100644 index ae19f3a8c51abf469697d6affa91449d668716ba..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/solvers/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -""" -Solvers. A Solver is a training recipe, combining the dataloaders, models, -optimizer, losses etc into a single convenient object. -""" - -# flake8: noqa -from .audiogen import AudioGenSolver -from .builders import get_solver -from .base import StandardSolver -from .compression import CompressionSolver -from .musicgen import MusicGenSolver -from .diffusion import DiffusionSolver diff --git a/audiocraft/audiocraft/solvers/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/solvers/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 7a729df3a624d38e1f88090914cbfc5b298f0828..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/solvers/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/solvers/__pycache__/audiogen.cpython-311.pyc b/audiocraft/audiocraft/solvers/__pycache__/audiogen.cpython-311.pyc deleted file mode 100644 index 000c01bc8d7cd1a46692aca0e49a926e93f79fed..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/solvers/__pycache__/audiogen.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/solvers/__pycache__/base.cpython-311.pyc b/audiocraft/audiocraft/solvers/__pycache__/base.cpython-311.pyc deleted file mode 100644 index e9da2965e67be0fb9b878aa29b5a940147201bb7..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/solvers/__pycache__/base.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/solvers/__pycache__/builders.cpython-311.pyc b/audiocraft/audiocraft/solvers/__pycache__/builders.cpython-311.pyc deleted file mode 100644 index 2da29792bcd01b33ccac5755c30b392040948904..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/solvers/__pycache__/builders.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/solvers/__pycache__/compression.cpython-311.pyc b/audiocraft/audiocraft/solvers/__pycache__/compression.cpython-311.pyc deleted file mode 100644 index b8fed1c7b2d243a2d3edefa92f6d8071509b8c28..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/solvers/__pycache__/compression.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/solvers/__pycache__/diffusion.cpython-311.pyc b/audiocraft/audiocraft/solvers/__pycache__/diffusion.cpython-311.pyc deleted file mode 100644 index f5ced6280c572c4fa197db693d472e709244aae8..0000000000000000000000000000000000000000 Binary files 
a/audiocraft/audiocraft/solvers/__pycache__/diffusion.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/solvers/__pycache__/musicgen.cpython-311.pyc b/audiocraft/audiocraft/solvers/__pycache__/musicgen.cpython-311.pyc deleted file mode 100644 index e0bf6167640481c3554a3f4740690147d8ad2bc6..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/solvers/__pycache__/musicgen.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/solvers/audiogen.py b/audiocraft/audiocraft/solvers/audiogen.py deleted file mode 100644 index 1568f97fe7b84b90c7ef760ef5606fe0a475545a..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/solvers/audiogen.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from . import builders, musicgen - - -class AudioGenSolver(musicgen.MusicGenSolver): - """Solver for AudioGen re-implementation training task. - - Note that this implementation does not strictly follows - the method proposed in https://arxiv.org/abs/2209.15352 - but is derived from MusicGen's training pipeline. - - More information can be found in the AudioGen model card. - """ - DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND diff --git a/audiocraft/audiocraft/solvers/base.py b/audiocraft/audiocraft/solvers/base.py deleted file mode 100644 index 0432e44a36838c5731711f9d54f81822b21f20bd..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/solvers/base.py +++ /dev/null @@ -1,631 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC, abstractmethod -from contextlib import contextmanager -from pathlib import Path -import typing as tp - -import flashy -import omegaconf -import torch -from torch import nn - -from .. import optim -from ..optim import fsdp -from ..utils import checkpoint -from ..utils.autocast import TorchAutocast -from ..utils.best_state import BestStateDictManager -from ..utils.deadlock import DeadlockDetect -from ..utils.profiler import Profiler -from ..utils.utils import copy_state, dict_from_config, model_hash, with_rank_rng - - -class StandardSolver(ABC, flashy.BaseSolver): - """Standard solver for AudioCraft. - - The standard solver implements a base training loop with the following stages: - train, valid, evaluate and generate that are expected to be all defined for - solvers in AudioCraft. It also provides a nice default management of Dora history replay, - checkpoint management across epoch, and logging configuration. - - AudioCraft solvers must inherit from the StandardSolver and define the methods - associated to each stage as well as the show, build_model and build_dataloaders methods. 
- """ - def __init__(self, cfg: omegaconf.DictConfig): - super().__init__() - self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}") - self.logger.info(f"All XP logs are stored in {self.xp.folder}") - self.cfg = cfg - self.device = cfg.device - self.model: nn.Module - self._continue_best_source_keys = ['best_state', 'fsdp_best_state'] - self._fsdp_modules: tp.List[fsdp.FSDP] = [] - self._ema_sources: nn.ModuleDict = nn.ModuleDict() - self.ema: tp.Optional[optim.ModuleDictEMA] = None - self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict() - self._log_updates = self.cfg.logging.get('log_updates', 10) - if self.cfg.logging.log_tensorboard: - self.init_tensorboard(**self.cfg.get('tensorboard')) - if self.cfg.logging.log_wandb and self: - self.init_wandb(**self.cfg.get('wandb')) - # keep a copy of the best performing state for stateful objects - # used for evaluation and generation stages - dtype_best: tp.Optional[torch.dtype] = None - if self.cfg.fsdp.use: - dtype_best = getattr(torch, self.cfg.fsdp.param_dtype) # type: ignore - assert isinstance(dtype_best, torch.dtype) - elif self.cfg.autocast: - dtype_best = getattr(torch, self.cfg.autocast_dtype) # type: ignore - assert isinstance(dtype_best, torch.dtype) - self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best) - # Hacky support for keeping a copy of the full best state in rank0. - self.fsdp_best_state: tp.Dict[str, tp.Any] = {} - self.register_stateful('best_state', 'fsdp_best_state') # register best_state object to keep it in state_dict - self._new_best_state: bool = False # should save a new checkpoint - # instantiate datasets and appropriate number of updates per epoch - self.build_dataloaders() - if self.cfg.execute_only is None: - assert 'train' in self.dataloaders, "The train dataset split must be provided." - assert 'valid' in self.dataloaders, "The valid dataset split must be provided." - self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0 - if self.cfg.optim.updates_per_epoch: - self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch - self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs - # instantiate model & exponential moving average on the model - self.build_model() - self.logger.info("Model hash: %s", model_hash(self.model)) - assert 'model' in self.stateful.sources, \ - "Please register the model to stateful with self.register_stateful('model') in build_model." - self.profiler = Profiler(self.model, **self.cfg.profiler) - self.initialize_ema() - self.register_stateful('ema') - assert self.ema is None or 'ema' in self.stateful.sources, \ - "Please register the ema to stateful with self.register_stateful('ema') in build_model." - self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock) - # basic statistics on the trained model - model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6 - # one copy of grad, one copy of momentum, one copy of denominator and model weights. - # and 4 bytes for each float! 
- mem_usage = model_size * 4 * 4 / 1000 - self.logger.info("Model size: %.2f M params", model_size) - self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage) - - @property - def autocast(self): - """Convenient autocast (or not) using the solver configuration.""" - return TorchAutocast(enabled=self.cfg.autocast, device_type=self.device, dtype=self.autocast_dtype) - - def _get_state_source(self, name) -> flashy.state.StateDictSource: - # Internal utility to get a state source from the solver - return self.stateful.sources[name] - - @property - def best_metric_name(self) -> tp.Optional[str]: - """Metric name used to identify the best state. This metric should be stored in the metrics - used on the stage for best state identification (most likely, `valid`). If None, then - no best state is saved. - """ - return None - - def register_best_state(self, *args: str): - """Register state sources in `BestStateDictManager` to keep their best states along with their - latest states. The best state will be used at evaluation stages instead of the latest states. - - Shortcut around `BestStateDictManager.register` method. You can pass any number of - attribute, included nested attributes and those will be included into the checkpoints - and automatically restored when `BaseSolver.restore` is called. - """ - for name in args: - state_source = self._get_state_source(name) - assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!" - self.best_state.register(name, state_source) - - def register_ema(self, *args: str): - """Register state sources for exponential moving average. - - The registered sources are used to instantiate a ModuleDictEMA instance. - The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called - and swapped with the original state sources with self.swap_ema_state() method. - - Usage: - self.register_ema('model') - """ - assert self.ema is None, "Cannot register state source to already instantiated EMA." - for name in args: - self._ema_sources[name] = getattr(self, name) - - def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs): - model = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs) - if isinstance(model, fsdp.FSDP): - self._fsdp_modules.append(model) - return model - - def update_best_state_from_stage(self, stage_name: str = 'valid'): - """Update latest best state based on pending metrics of a given stage. This method relies - on the `BestStateDictManager.update` method to update the best state_dict with latest weights - if the registered states happen to match to the best performing setup. - """ - if self.best_metric_name is None: - # when no best metric is defined, the last state is always the best - self._new_best_state = True - self.logger.info("Updating best state with current state.") - else: - assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found." - assert self.best_metric_name in self._pending_metrics[stage_name], \ - f"Best metric not found in {stage_name} metrics. 
Cannot register best state" - current_score = self._pending_metrics[stage_name][self.best_metric_name] - all_best_metric_scores = [ - past_metrics[stage_name][self.best_metric_name] - for past_metrics in self.history - ] - all_best_metric_scores.append(current_score) - best_score = min(all_best_metric_scores) - self._new_best_state = current_score == best_score - if self._new_best_state: - old_best = min(all_best_metric_scores[:-1] + [float('inf')]) - self.logger.info( - f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})") - - if self._new_best_state: - if self.cfg.fsdp.use: - # this will give an empty state dict on all ranks but the rank 0 - # which will have a copy in memory of the full model. - with fsdp.switch_to_full_state_dict(self._fsdp_modules): - for name in self.best_state.states.keys(): - state_source = self._get_state_source(name) - self.best_state.update(name, state_source) - # we save to a different dict. - self.fsdp_best_state.update(self.best_state.state_dict()) - # We cannot efficiently load fsdp_best_state when using FSDP, - # so we have do do a second pass, with the local shards. - for name in self.best_state.states.keys(): - state_source = self._get_state_source(name) - self.best_state.update(name, state_source) - - def _load_new_state_dict(self, state_dict: dict) -> dict: - old_states = {} - for name, new_state in state_dict.items(): - state_source = self._get_state_source(name) - old_states[name] = copy_state(state_source.state_dict()) - state_source.load_state_dict(new_state) - return old_states - - @contextmanager - def swap_best_state(self): - self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}") - old_states = self._load_new_state_dict(self.best_state.state_dict()) - try: - yield - finally: - self.logger.debug("Swapping back from best to original state") - for name, old_state in old_states.items(): - state_source = self._get_state_source(name) - state_source.load_state_dict(old_state) - - @contextmanager - def swap_ema_state(self): - if self.ema is None: - yield - else: - ema_state_dict = self.ema.state_dict()['state'] - self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}") - old_states = self._load_new_state_dict(ema_state_dict) - try: - yield - finally: - self.logger.debug("Swapping back from EMA state to original state") - for name, old_state in old_states.items(): - state_source = self._get_state_source(name) - state_source.load_state_dict(old_state) - - @property - def is_training(self): - return self.current_stage == 'train' - - def log_model_summary(self, model: nn.Module): - """Log model summary, architecture and size of the model.""" - self.logger.info(model) - mb = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20 - self.logger.info("Size: %.1f MB", mb) - - @abstractmethod - def build_model(self): - """Method to implement to initialize model.""" - ... - - def initialize_ema(self): - """Initialize exponential moving average with the registered sources. - EMA object is created if the optim.ema.model.decay value is non-null. 
- """ - from .builders import get_ema - self.ema = get_ema(self._ema_sources, self.cfg.optim.ema) - if self.ema is None: - self.logger.info('No EMA on the model.') - else: - assert self.cfg.optim.ema.updates > 0 - self.logger.info( - f'Initializing EMA on the model with decay = {self.ema.decay}' - f' every {self.cfg.optim.ema.updates} updates' - ) - - @abstractmethod - def build_dataloaders(self): - """Method to implement to initialize dataloaders.""" - ... - - @abstractmethod - def show(self): - """Method to log any information without running the job.""" - ... - - @property - def log_updates(self): - # convenient access to log updates - return self._log_updates - - def checkpoint_path(self, **kwargs): - kwargs.setdefault('use_fsdp', self.cfg.fsdp.use) - return self.folder / checkpoint.checkpoint_name(**kwargs) - - def epoch_checkpoint_path(self, epoch: int, **kwargs): - kwargs.setdefault('use_fsdp', self.cfg.fsdp.use) - return self.folder / checkpoint.checkpoint_name(str(epoch), **kwargs) - - def checkpoint_path_with_name(self, name: str, **kwargs): - kwargs.setdefault('use_fsdp', self.cfg.fsdp.use) - return self.folder / checkpoint.checkpoint_name(name=name, **kwargs) - - def save_checkpoints(self): - """Save checkpoint, optionally keeping a copy for a given epoch.""" - is_sharded = self.cfg.fsdp.use - if not flashy.distrib.is_rank_zero() and not is_sharded: - return - self.logger.info("Model hash: %s", model_hash(self.model)) - state = self.state_dict() - epoch = self.epoch - 1 # pushing metrics will increase the epoch in Flashy, so we do -1 here - - # save minimal state_dict as new checkpoint every X epoch - if self.cfg.checkpoint.save_every: - if epoch % self.cfg.checkpoint.save_every == 0: - minimal_state = state - if self.cfg.checkpoint.keep_every_states is not None and len(self.cfg.checkpoint.keep_every_states) > 0: - minimal_state = { - name: source for name, source in state.items() - if name in self.cfg.checkpoint.keep_every_states - } - epoch_checkpoint_path = self.epoch_checkpoint_path(epoch) - checkpoint.save_checkpoint(minimal_state, epoch_checkpoint_path, is_sharded) - - # save checkpoint as latest checkpoint - if self.cfg.checkpoint.save_last: - last_checkpoint_path = self.checkpoint_path() - checkpoint.save_checkpoint(state, last_checkpoint_path, is_sharded) - - # flush any stale checkpoint to reduce disk footprint - checkpoint.flush_stale_checkpoints(self.checkpoint_path()) - - def load_from_pretrained(self, name: str) -> dict: - raise NotImplementedError("Solver does not provide a way to load pretrained models.") - - def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]: - """Load last checkpoint or the one specified in continue_from. - - Args: - load_best (bool): Whether to load from best state dict or not. - Best state dict is always used when not loading the current xp. - ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`. - Returns: - state (dict, optional): The loaded state dictionary. 
- """ - # load checkpoints from xp folder or cfg.continue_from - is_sharded = self.cfg.fsdp.use - load_from_path: tp.Optional[Path] = None - checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None - - if load_best: - self.logger.info("Trying to load state_dict from best state.") - - state: tp.Optional[dict] = None - rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False) - current_checkpoint_path = self.checkpoint_path() - _pretrained_prefix = '//pretrained/' - continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix) - if rank0_checkpoint_path.exists(): - self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}") - load_from_path = current_checkpoint_path - checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path) - checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP - elif self.cfg.continue_from and not continue_pretrained: - self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}") - # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best - load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False) - if load_from_path is None: - self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from) - raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}') - checkpoint_source = checkpoint.CheckpointSource.OTHER - - if load_from_path is not None: - state = checkpoint.load_checkpoint(load_from_path, is_sharded) - elif continue_pretrained: - self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.") - state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):]) - checkpoint_source = checkpoint.CheckpointSource.PRETRAINED - load_best = True - - # checkpoints are not from the current xp, we only retrieve the best state - if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP: - assert state is not None - self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.") - load_best = True - state = {key: state[key] for key in self._continue_best_source_keys if key in state} - # loaded checkpoints are FSDP checkpoints: we're reading the best state - # from FSDP and we drop the regular best_state - if 'fsdp_best_state' in state and state['fsdp_best_state']: - state.pop('best_state', None) - self.logger.info("... Loaded checkpoint has FSDP best state") - # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support - # then we're initializing FSDP best state with the regular best state - elif self.cfg.fsdp.use: - if 'fsdp_best_state' not in state or not state['fsdp_best_state']: - # we swap non-FSDP checkpoints best_state to FSDP-compatible best state - state['fsdp_best_state'] = state.pop('best_state') - self.logger.info("... Loaded checkpoint does not have FSDP best state. 
Use regular best state") - - if state is not None: - if load_best: - self.logger.info("Ignoring keys when loading best %r", ignore_state_keys) - for key in set(ignore_state_keys): - if key in state: - state.pop(key) - has_best_state = 'best_state' in state or 'fsdp_best_state' in state - assert has_best_state, ("Trying to load best state but neither 'best_state'", - " or 'fsdp_best_state' found in checkpoints.") - self.load_state_dict(state) - - # for FSDP, let's make extra sure nothing bad happened with out of sync - # checkpoints across workers. - epoch = float(self.epoch) - avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch'] - if avg_epoch != epoch: - raise RuntimeError( - f"Inconsistent loading of checkpoints happened, our epoch is {epoch} " - f"but average of epochs is {avg_epoch}, at least one gpu must have a " - "different epoch number.") - - # on load_best, properly reinitialize state_dict, best states and ema - # otherwise we load from the current xp and don't alter anything - if load_best: - self.logger.info("Loading state_dict from best state.") - if not self.cfg.fsdp.use and self.fsdp_best_state: - # loading from an FSDP checkpoint but with FSDP deactivated - self.logger.info("... Loading from FSDP best state dict.") - self.best_state.load_state_dict(self.fsdp_best_state) - - # if load_best, we permanently override the regular state_dict with the best state - if self.cfg.fsdp.use: - self.logger.info("FSDP is used, loading from FSDP best state.") - with fsdp.switch_to_full_state_dict(self._fsdp_modules): - # this might be really fragile but okay for now. - self.load_state_dict(self.fsdp_best_state) - else: - # we permanently swap the stateful objects to their best state - self._load_new_state_dict(self.best_state.state_dict()) - - # the EMA modules should also be instantiated with best state. - # the easiest way to do so is to reinitialize a new EMA with best state loaded. - if self.ema is not None: - self.logger.info("Re-initializing EMA from best state") - self.initialize_ema() - - if self.cfg.fsdp.use: - self.logger.info("Re-initializing best state after using FSDP best state.") - for name in self.best_state.states.keys(): - state_source = self._get_state_source(name) - self.best_state.update(name, state_source) - - return state - - def restore(self, load_best: bool = False, replay_metrics: bool = False, - ignore_state_keys: tp.List[str] = []) -> bool: - """Restore the status of a solver for a given xp. - - Args: - load_best (bool): if `True`, load the best state from the checkpoint. - replay_metrics (bool): if `True`, logs all the metrics from past epochs. - ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`. 
- """ - self.logger.info("Restoring weights and history.") - restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys) - - self.logger.info("Model hash: %s", model_hash(self.model)) - - if replay_metrics and len(self.history) > 0: - self.logger.info("Replaying past metrics...") - for epoch, stages in enumerate(self.history): - for stage_name, metrics in stages.items(): - # We manually log the metrics summary to the result logger - # as we don't want to add them to the pending metrics - self.result_logger._log_summary(stage_name, metrics, step=epoch + 1, step_name='epoch', - formatter=self.get_formatter(stage_name)) - return restored_checkpoints is not None - - def commit(self, save_checkpoints: bool = True): - """Commit metrics to dora and save checkpoints at the end of an epoch.""" - # we override commit to introduce more complex checkpoint saving behaviors - self.history.append(self._pending_metrics) # This will increase self.epoch - if save_checkpoints: - self.save_checkpoints() - self._start_epoch() - if flashy.distrib.is_rank_zero(): - self.xp.link.update_history(self.history) - - def run_epoch(self): - """Run a single epoch with all stages. - - Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards. - Children solvers can extend this method with custom behavior, e.g.: - - def run_epoch(self): - ... # custom code - super().run_epoch() - ... # custom code - """ - self.run_stage('train', self.train) - with torch.no_grad(): - with self.swap_ema_state(): - self.run_stage('valid', self.valid) - # the best state is updated with EMA states if available - self.update_best_state_from_stage('valid') - with self.swap_best_state(): - if self.should_run_stage('evaluate'): - self.run_stage('evaluate', self.evaluate) - if self.should_run_stage('generate'): - self.run_stage('generate', with_rank_rng()(self.generate)) - - def run(self): - """Training loop.""" - assert len(self.state_dict()) > 0 - self.restore(replay_metrics=True) # load checkpoint and replay history - self.log_hyperparams(dict_from_config(self.cfg)) - for epoch in range(self.epoch, self.cfg.optim.epochs + 1): - if self.should_stop_training(): - return - self.run_epoch() - # Commit will send the metrics to Dora and save checkpoints by default. - self.commit() - - def should_stop_training(self) -> bool: - """Check whether we should stop training or not.""" - return self.epoch > self.cfg.optim.epochs - - def should_run_stage(self, stage_name) -> bool: - """Check whether we want to run the specified stages.""" - stage_every = self.cfg[stage_name].get('every', None) - is_last_epoch = self.epoch == self.cfg.optim.epochs - is_epoch_every = (stage_every and self.epoch % stage_every == 0) - return is_last_epoch or is_epoch_every - - @abstractmethod - def run_step(self, idx: int, batch: tp.Any, metrics: dict): - """Perform one training or valid step on a given batch.""" - ... 
- - def common_train_valid(self, dataset_split: str, **kwargs: tp.Any): - """Common logic for train and valid stages.""" - self.model.train(self.is_training) - - loader = self.dataloaders[dataset_split] - # get a different order for distributed training, otherwise this will get ignored - if flashy.distrib.world_size() > 1 \ - and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler): - loader.sampler.set_epoch(self.epoch) - updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader) - if self.cfg.benchmark_no_load: - self.logger.warning("Fake loading for benchmarking: re-using first batch") - batch = next(iter(loader)) - loader = [batch] * updates_per_epoch # type: ignore - lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates) - average = flashy.averager() # epoch wise average - instant_average = flashy.averager() # average between two logging - metrics: dict = {} - - with self.profiler, self.deadlock_detect: # profiler will only run for the first 20 updates. - for idx, batch in enumerate(lp): - self.deadlock_detect.update('batch') - if idx >= updates_per_epoch: - break - metrics = {} - metrics = self.run_step(idx, batch, metrics) - self.deadlock_detect.update('step') - # run EMA step - if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0: - self.logger.debug("EMA model step") - self.ema.step() - self.deadlock_detect.update('ema') - self.profiler.step() - instant_metrics = instant_average(metrics) - if lp.update(**instant_metrics): - instant_average = flashy.averager() # reset averager between two logging - metrics = average(metrics) # epoch wise average - self.deadlock_detect.update('end_batch') - - metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch) - return metrics - - def train(self): - """Train stage.""" - return self.common_train_valid('train') - - def valid(self): - """Valid stage.""" - return self.common_train_valid('valid') - - @abstractmethod - def evaluate(self): - """Evaluate stage.""" - ... - - @abstractmethod - def generate(self): - """Generate stage.""" - ... - - def run_one_stage(self, stage_name: str): - """Run only the specified stage. - This method is useful to only generate samples from a trained experiment - or rerun the validation or evaluation stages. - """ - fn = { - 'generate': with_rank_rng()(self.generate), - 'evaluate': self.evaluate, - 'valid': self.valid, - } - if stage_name not in fn: - raise ValueError(f'Trying to run stage {stage_name} is not supported.') - assert len(self.state_dict()) > 0 - self._start_epoch() - with torch.no_grad(), self.swap_best_state(): - self.run_stage(stage_name, fn[stage_name]) - if not self.cfg.execute_inplace: - self.commit(save_checkpoints=False) - - @staticmethod - def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None, - device: tp.Optional[str] = None, autocast: bool = True, - batch_size: tp.Optional[int] = None, - override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None, - **kwargs): - """Mostly a convenience function around audiocraft.train.get_solver_from_sig, - populating all the proper param, deactivating EMA, FSDP, loading the best state, - basically all you need to get a solver ready to "play" with in single GPU mode - and with minimal memory overhead. - - Args: - sig (str): signature to load. - dtype (str or None): potential dtype, as a string, i.e. 'float16'. - device (str or None): potential device, as a string, i.e. 'cuda'. 
- override_cfg (dict or omegaconf.DictConfig or None): extra configuration overrides merged on top of the defaults set by this helper. - """ - from audiocraft import train - our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}} - our_override_cfg['autocast'] = autocast - if dtype is not None: - our_override_cfg['dtype'] = dtype - if device is not None: - our_override_cfg['device'] = device - if batch_size is not None: - our_override_cfg['dataset'] = {'batch_size': batch_size} - if override_cfg is None: - override_cfg = {} - override_cfg = omegaconf.OmegaConf.merge( - omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) # type: ignore - solver = train.get_solver_from_sig( - sig, override_cfg=override_cfg, - load_best=True, disable_fsdp=True, - ignore_state_keys=['optimizer', 'ema'], **kwargs) - solver.model.eval() - return solver diff --git a/audiocraft/audiocraft/solvers/builders.py b/audiocraft/audiocraft/solvers/builders.py deleted file mode 100644 index 304d8f08d33a70e8be9388c855b2ae43bdf2683b..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/solvers/builders.py +++ /dev/null @@ -1,363 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -All the functions to build the relevant solvers and used objects -from the Hydra config. -""" - -from enum import Enum -import logging -import typing as tp - -import dora -import flashy -import omegaconf -import torch -from torch import nn -from torch.optim import Optimizer -# LRScheduler was renamed in some torch versions -try: - from torch.optim.lr_scheduler import LRScheduler # type: ignore -except ImportError: - from torch.optim.lr_scheduler import _LRScheduler as LRScheduler - -from .base import StandardSolver -from .. import adversarial, data, losses, metrics, optim -from ..utils.utils import dict_from_config, get_loader - - -logger = logging.getLogger(__name__) - - -class DatasetType(Enum): - AUDIO = "audio" - MUSIC = "music" - SOUND = "sound" - - -def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: - """Instantiate solver from config.""" - from .audiogen import AudioGenSolver - from .compression import CompressionSolver - from .musicgen import MusicGenSolver - from .diffusion import DiffusionSolver - klass = { - 'compression': CompressionSolver, - 'musicgen': MusicGenSolver, - 'audiogen': AudioGenSolver, - 'lm': MusicGenSolver, # backward compatibility - 'diffusion': DiffusionSolver, - 'sound_lm': AudioGenSolver, # backward compatibility - }[cfg.solver] - return klass(cfg) # type: ignore - - -def get_optim_parameter_groups(model: nn.Module): - """Create parameter groups for the model using the appropriate method - if defined for each module, to create the different groups.
- - Args: - model (nn.Module): torch model - Returns: - List of parameter groups - """ - seen_params: tp.Set[nn.parameter.Parameter] = set() - other_params = [] - groups = [] - for name, module in model.named_modules(): - if hasattr(module, 'make_optim_group'): - group = module.make_optim_group() - params = set(group['params']) - assert params.isdisjoint(seen_params) - seen_params |= set(params) - groups.append(group) - for param in model.parameters(): - if param not in seen_params: - other_params.append(param) - groups.insert(0, {'params': other_params}) - parameters = groups - return parameters - - -def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]], cfg: omegaconf.DictConfig) -> Optimizer: - """Build torch optimizer from config and set of parameters. - Supported optimizers: Adam, AdamW, DAdaptAdam - - Args: - params (nn.Module or iterable of torch.Tensor): Parameters to optimize. - cfg (DictConfig): Optimization-related configuration. - Returns: - torch.optim.Optimizer. - """ - if 'optimizer' not in cfg: - if getattr(cfg, 'optim', None) is not None: - raise KeyError("Optimizer not found in config. Try instantiating optimizer from cfg.optim?") - else: - raise KeyError("Optimizer not found in config.") - - parameters = get_optim_parameter_groups(params) if isinstance(params, nn.Module) else params - optimizer: torch.optim.Optimizer - if cfg.optimizer == 'adam': - optimizer = torch.optim.Adam(parameters, lr=cfg.lr, **cfg.adam) - elif cfg.optimizer == 'adamw': - optimizer = torch.optim.AdamW(parameters, lr=cfg.lr, **cfg.adam) - elif cfg.optimizer == 'dadam': - optimizer = optim.DAdaptAdam(parameters, lr=cfg.lr, **cfg.adam) - else: - raise ValueError(f"Unsupported Optimizer: {cfg.optimizer}") - return optimizer - - -def get_lr_scheduler(optimizer: torch.optim.Optimizer, - cfg: omegaconf.DictConfig, - total_updates: int) -> tp.Optional[LRScheduler]: - """Build torch learning rate scheduler from config and associated optimizer. - Supported learning rate schedulers: StepLR, ExponentialLR, CosineLRScheduler, PolynomialDecayLRScheduler, InverseSquareRootLRScheduler, LinearWarmupLRScheduler - - Args: - optimizer (torch.optim.Optimizer): Optimizer. - cfg (DictConfig): Schedule-related configuration. - total_updates (int): Total number of updates. - Returns: - The learning rate scheduler (LRScheduler), or None if no scheduler is configured.
- """ - if 'lr_scheduler' not in cfg: - raise KeyError("LR Scheduler not found in config") - - lr_sched: tp.Optional[LRScheduler] = None - if cfg.lr_scheduler == 'step': - lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, **cfg.step) - elif cfg.lr_scheduler == 'exponential': - lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.exponential) - elif cfg.lr_scheduler == 'cosine': - kwargs = dict_from_config(cfg.cosine) - warmup_steps = kwargs.pop('warmup') - lr_sched = optim.CosineLRScheduler( - optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs) - elif cfg.lr_scheduler == 'polynomial_decay': - kwargs = dict_from_config(cfg.polynomial_decay) - warmup_steps = kwargs.pop('warmup') - lr_sched = optim.PolynomialDecayLRScheduler( - optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs) - elif cfg.lr_scheduler == 'inverse_sqrt': - kwargs = dict_from_config(cfg.inverse_sqrt) - warmup_steps = kwargs.pop('warmup') - lr_sched = optim.InverseSquareRootLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs) - elif cfg.lr_scheduler == 'linear_warmup': - kwargs = dict_from_config(cfg.linear_warmup) - warmup_steps = kwargs.pop('warmup') - lr_sched = optim.LinearWarmupLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs) - elif cfg.lr_scheduler is not None: - raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}") - return lr_sched - - -def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp.Optional[optim.ModuleDictEMA]: - """Initialize Exponential Moving Average. - - Args: - module_dict (nn.ModuleDict): ModuleDict for which to compute the EMA. - cfg (omegaconf.DictConfig): Optim EMA configuration. - Returns: - optim.ModuleDictEMA: EMA version of the ModuleDict. - """ - kw: tp.Dict[str, tp.Any] = dict(cfg) - use = kw.pop('use', False) - decay = kw.pop('decay', None) - device = kw.pop('device', None) - if not use: - return None - if len(module_dict) == 0: - raise ValueError("Trying to build EMA but an empty module_dict source is provided!") - ema_module = optim.ModuleDictEMA(module_dict, decay=decay, device=device) - return ema_module - - -def get_loss(loss_name: str, cfg: omegaconf.DictConfig): - """Instantiate loss from configuration.""" - klass = { - 'l1': torch.nn.L1Loss, - 'l2': torch.nn.MSELoss, - 'mel': losses.MelSpectrogramL1Loss, - 'mrstft': losses.MRSTFTLoss, - 'msspec': losses.MultiScaleMelSpectrogramLoss, - 'sisnr': losses.SISNR, - }[loss_name] - kwargs = dict(getattr(cfg, loss_name)) - return klass(**kwargs) - - -def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictConfig) -> losses.Balancer: - """Instantiate loss balancer from configuration for the provided weights.""" - kwargs: tp.Dict[str, tp.Any] = dict_from_config(cfg) - return losses.Balancer(loss_weights, **kwargs) - - -def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module: - """Initialize adversary from config.""" - klass = { - 'msd': adversarial.MultiScaleDiscriminator, - 'mpd': adversarial.MultiPeriodDiscriminator, - 'msstftd': adversarial.MultiScaleSTFTDiscriminator, - }[name] - adv_cfg: tp.Dict[str, tp.Any] = dict(getattr(cfg, name)) - return klass(**adv_cfg) - - -def get_adversarial_losses(cfg) -> nn.ModuleDict: - """Initialize dict of adversarial losses from config.""" - device = cfg.device - adv_cfg = getattr(cfg, 'adversarial') - adversaries = adv_cfg.get('adversaries', []) - adv_loss_name = adv_cfg['adv_loss'] - feat_loss_name = adv_cfg.get('feat_loss') - normalize = adv_cfg.get('normalize', 
True) - feat_loss: tp.Optional[adversarial.FeatureMatchingLoss] = None - if feat_loss_name: - assert feat_loss_name in ['l1', 'l2'], f"Feature loss only support L1 or L2 but {feat_loss_name} found." - loss = get_loss(feat_loss_name, cfg) - feat_loss = adversarial.FeatureMatchingLoss(loss, normalize) - loss = adversarial.get_adv_criterion(adv_loss_name) - loss_real = adversarial.get_real_criterion(adv_loss_name) - loss_fake = adversarial.get_fake_criterion(adv_loss_name) - adv_losses = nn.ModuleDict() - for adv_name in adversaries: - adversary = get_adversary(adv_name, cfg).to(device) - optimizer = get_optimizer(adversary.parameters(), cfg.optim) - adv_loss = adversarial.AdversarialLoss( - adversary, - optimizer, - loss=loss, - loss_real=loss_real, - loss_fake=loss_fake, - loss_feat=feat_loss, - normalize=normalize - ) - adv_losses[adv_name] = adv_loss - return adv_losses - - -def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL: - """Instantiate ViSQOL metric from config.""" - kwargs = dict_from_config(cfg) - return metrics.ViSQOL(**kwargs) - - -def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMetric: - """Instantiate Frechet Audio Distance metric from config.""" - kwargs = dict_from_config(cfg.tf) - xp = dora.get_xp() - kwargs['log_folder'] = xp.folder - return metrics.FrechetAudioDistanceMetric(**kwargs) - - -def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric: - """Instantiate KL-Divergence metric from config.""" - kld_metrics = { - 'passt': metrics.PasstKLDivergenceMetric, - } - klass = kld_metrics[cfg.model] - kwargs = dict_from_config(cfg.get(cfg.model)) - return klass(**kwargs) - - -def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsistencyMetric: - """Instantiate Text Consistency metric from config.""" - text_consistency_metrics = { - 'clap': metrics.CLAPTextConsistencyMetric - } - klass = text_consistency_metrics[cfg.model] - kwargs = dict_from_config(cfg.get(cfg.model)) - return klass(**kwargs) - - -def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.ChromaCosineSimilarityMetric: - """Instantiate Chroma Cosine Similarity metric from config.""" - assert cfg.model == 'chroma_base', "Only support 'chroma_base' method for chroma cosine similarity metric" - kwargs = dict_from_config(cfg.get(cfg.model)) - return metrics.ChromaCosineSimilarityMetric(**kwargs) - - -def get_audio_datasets(cfg: omegaconf.DictConfig, - dataset_type: DatasetType = DatasetType.AUDIO) -> tp.Dict[str, torch.utils.data.DataLoader]: - """Build AudioDataset from configuration. - - Args: - cfg (omegaconf.DictConfig): Configuration. - dataset_type: The type of dataset to create. - Returns: - dict[str, torch.utils.data.DataLoader]: Map of dataloader for each data split. 
- """ - dataloaders: dict = {} - - sample_rate = cfg.sample_rate - channels = cfg.channels - seed = cfg.seed - max_sample_rate = cfg.datasource.max_sample_rate - max_channels = cfg.datasource.max_channels - - assert cfg.dataset is not None, "Could not find dataset definition in config" - - dataset_cfg = dict_from_config(cfg.dataset) - splits_cfg: dict = {} - splits_cfg['train'] = dataset_cfg.pop('train') - splits_cfg['valid'] = dataset_cfg.pop('valid') - splits_cfg['evaluate'] = dataset_cfg.pop('evaluate') - splits_cfg['generate'] = dataset_cfg.pop('generate') - execute_only_stage = cfg.get('execute_only', None) - - for split, path in cfg.datasource.items(): - if not isinstance(path, str): - continue # skipping this as not a path - if execute_only_stage is not None and split != execute_only_stage: - continue - logger.info(f"Loading audio data split {split}: {str(path)}") - assert ( - cfg.sample_rate <= max_sample_rate - ), f"Expecting a max sample rate of {max_sample_rate} for datasource but {sample_rate} found." - assert ( - cfg.channels <= max_channels - ), f"Expecting a max number of channels of {max_channels} for datasource but {channels} found." - - split_cfg = splits_cfg[split] - split_kwargs = {k: v for k, v in split_cfg.items()} - kwargs = {**dataset_cfg, **split_kwargs} # split kwargs overrides default dataset_cfg - kwargs['sample_rate'] = sample_rate - kwargs['channels'] = channels - - if kwargs.get('permutation_on_files') and cfg.optim.updates_per_epoch: - kwargs['num_samples'] = ( - flashy.distrib.world_size() * cfg.dataset.batch_size * cfg.optim.updates_per_epoch) - - num_samples = kwargs['num_samples'] - shuffle = kwargs['shuffle'] - - return_info = kwargs.pop('return_info') - batch_size = kwargs.pop('batch_size', None) - num_workers = kwargs.pop('num_workers') - - if dataset_type == DatasetType.MUSIC: - dataset = data.music_dataset.MusicDataset.from_meta(path, **kwargs) - elif dataset_type == DatasetType.SOUND: - dataset = data.sound_dataset.SoundDataset.from_meta(path, **kwargs) - elif dataset_type == DatasetType.AUDIO: - dataset = data.info_audio_dataset.InfoAudioDataset.from_meta(path, return_info=return_info, **kwargs) - else: - raise ValueError(f"Dataset type is unsupported: {dataset_type}") - - loader = get_loader( - dataset, - num_samples, - batch_size=batch_size, - num_workers=num_workers, - seed=seed, - collate_fn=dataset.collater if return_info else None, - shuffle=shuffle, - ) - dataloaders[split] = loader - - return dataloaders diff --git a/audiocraft/audiocraft/solvers/compression.py b/audiocraft/audiocraft/solvers/compression.py deleted file mode 100644 index b757503472a3bfbf90e1636999e64913848a7474..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/solvers/compression.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import multiprocessing -from pathlib import Path -import typing as tp - -import flashy -import omegaconf -import torch -from torch import nn - -from . import base, builders -from .. import models, quantization -from ..utils import checkpoint -from ..utils.samples.manager import SampleManager -from ..utils.utils import get_pool_executor - - -logger = logging.getLogger(__name__) - - -class CompressionSolver(base.StandardSolver): - """Solver for compression task. 
- - The compression task combines a set of perceptual and objective losses - to train an EncodecModel (composed of an encoder-decoder and a quantizer) - to perform high fidelity audio reconstruction. - """ - def __init__(self, cfg: omegaconf.DictConfig): - super().__init__(cfg) - self.rng: torch.Generator # set at each epoch - self.adv_losses = builders.get_adversarial_losses(self.cfg) - self.aux_losses = nn.ModuleDict() - self.info_losses = nn.ModuleDict() - assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver." - loss_weights = dict() - for loss_name, weight in self.cfg.losses.items(): - if loss_name in ['adv', 'feat']: - for adv_name, _ in self.adv_losses.items(): - loss_weights[f'{loss_name}_{adv_name}'] = weight - elif weight > 0: - self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg) - loss_weights[loss_name] = weight - else: - self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg) - self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer) - self.register_stateful('adv_losses') - - @property - def best_metric_name(self) -> tp.Optional[str]: - # best model is the last for the compression model - return None - - def build_model(self): - """Instantiate model and optimizer.""" - # Model and optimizer - self.model = models.builders.get_compression_model(self.cfg).to(self.device) - self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim) - self.register_stateful('model', 'optimizer') - self.register_best_state('model') - self.register_ema('model') - - def build_dataloaders(self): - """Instantiate audio dataloaders for each stage.""" - self.dataloaders = builders.get_audio_datasets(self.cfg) - - def show(self): - """Show the compression model and employed adversarial loss.""" - self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:") - self.log_model_summary(self.model) - self.logger.info("Adversarial loss:") - self.log_model_summary(self.adv_losses) - self.logger.info("Auxiliary losses:") - self.logger.info(self.aux_losses) - self.logger.info("Info losses:") - self.logger.info(self.info_losses) - - def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): - """Perform one training or valid step on a given batch.""" - x = batch.to(self.device) - y = x.clone() - - qres = self.model(x) - assert isinstance(qres, quantization.QuantizedResult) - y_pred = qres.x - # Log bandwidth in kb/s - metrics['bandwidth'] = qres.bandwidth.mean() - - if self.is_training: - d_losses: dict = {} - if len(self.adv_losses) > 0 and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every: - for adv_name, adversary in self.adv_losses.items(): - disc_loss = adversary.train_adv(y_pred, y) - d_losses[f'd_{adv_name}'] = disc_loss - metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values()))) - metrics.update(d_losses) - - balanced_losses: dict = {} - other_losses: dict = {} - - # penalty from quantization - if qres.penalty is not None and qres.penalty.requires_grad: - other_losses['penalty'] = qres.penalty # penalty term from the quantizer - - # adversarial losses - for adv_name, adversary in self.adv_losses.items(): - adv_loss, feat_loss = adversary(y_pred, y) - balanced_losses[f'adv_{adv_name}'] = adv_loss - balanced_losses[f'feat_{adv_name}'] = feat_loss - - # auxiliary losses - for loss_name, criterion in self.aux_losses.items(): - loss = criterion(y_pred, y) - balanced_losses[loss_name] = loss - - # weighted losses - metrics.update(balanced_losses) - 
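# The losses gathered above are handled in two different ways below: `balanced_losses`
# (adversarial, feature-matching and auxiliary reconstruction terms) are weighted and
# back-propagated through the loss balancer, while `other_losses` (the quantizer
# penalty) bypass the balancer and receive a direct backward pass. The `ratio1` and
# `ratio2` metrics are the gradient norms measured after each of these two backward
# passes, which makes it possible to monitor how much of the total gradient comes from
# the penalty versus the balanced objectives.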
metrics.update(other_losses) - metrics.update(qres.metrics) - - if self.is_training: - # backprop losses that are not handled by balancer - other_loss = torch.tensor(0., device=self.device) - if 'penalty' in other_losses: - other_loss += other_losses['penalty'] - if other_loss.requires_grad: - other_loss.backward(retain_graph=True) - ratio1 = sum(p.grad.data.norm(p=2).pow(2) - for p in self.model.parameters() if p.grad is not None) - assert isinstance(ratio1, torch.Tensor) - metrics['ratio1'] = ratio1.sqrt() - - # balancer losses backward, returns effective training loss - # with effective weights at the current batch. - metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred) - # add metrics corresponding to weight ratios - metrics.update(self.balancer.metrics) - ratio2 = sum(p.grad.data.norm(p=2).pow(2) - for p in self.model.parameters() if p.grad is not None) - assert isinstance(ratio2, torch.Tensor) - metrics['ratio2'] = ratio2.sqrt() - - # optim - flashy.distrib.sync_model(self.model) - if self.cfg.optim.max_norm: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.cfg.optim.max_norm - ) - self.optimizer.step() - self.optimizer.zero_grad() - - # informative losses only - info_losses: dict = {} - with torch.no_grad(): - for loss_name, criterion in self.info_losses.items(): - loss = criterion(y_pred, y) - info_losses[loss_name] = loss - - metrics.update(info_losses) - - # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups - adv_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('adv')] - if len(adv_losses) > 0: - metrics['adv'] = torch.sum(torch.stack(adv_losses)) - feat_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('feat')] - if len(feat_losses) > 0: - metrics['feat'] = torch.sum(torch.stack(feat_losses)) - - return metrics - - def run_epoch(self): - # reset random seed at the beginning of the epoch - self.rng = torch.Generator() - self.rng.manual_seed(1234 + self.epoch) - # run epoch - super().run_epoch() - - def evaluate(self): - """Evaluate stage. 
Runs audio reconstruction evaluation.""" - self.model.eval() - evaluate_stage_name = str(self.current_stage) - - loader = self.dataloaders['evaluate'] - updates = len(loader) - lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates) - average = flashy.averager() - - pendings = [] - ctx = multiprocessing.get_context('spawn') - with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool: - for idx, batch in enumerate(lp): - x = batch.to(self.device) - with torch.no_grad(): - qres = self.model(x) - - y_pred = qres.x.cpu() - y = batch.cpu() # should already be on CPU but just in case - pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg)) - - metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates) - for pending in metrics_lp: - metrics = pending.result() - metrics = average(metrics) - - metrics = flashy.distrib.average_metrics(metrics, len(loader)) - return metrics - - def generate(self): - """Generate stage.""" - self.model.eval() - sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True) - generate_stage_name = str(self.current_stage) - - loader = self.dataloaders['generate'] - updates = len(loader) - lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates) - - for batch in lp: - reference, _ = batch - reference = reference.to(self.device) - with torch.no_grad(): - qres = self.model(reference) - assert isinstance(qres, quantization.QuantizedResult) - - reference = reference.cpu() - estimate = qres.x.cpu() - sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference) - - flashy.distrib.barrier() - - def load_from_pretrained(self, name: str) -> dict: - model = models.CompressionModel.get_pretrained(name) - if isinstance(model, models.DAC): - raise RuntimeError("Cannot fine tune a DAC model.") - elif isinstance(model, models.HFEncodecCompressionModel): - self.logger.warning('Trying to automatically convert a HuggingFace model ' - 'to AudioCraft, this might fail!') - state = model.model.state_dict() - new_state = {} - for k, v in state.items(): - if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k: - # We need to determine if this a convtr or a regular conv. - layer = int(k.split('.')[2]) - if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d): - - k = k.replace('.conv.', '.convtr.') - k = k.replace('encoder.layers.', 'encoder.model.') - k = k.replace('decoder.layers.', 'decoder.model.') - k = k.replace('conv.', 'conv.conv.') - k = k.replace('convtr.', 'convtr.convtr.') - k = k.replace('quantizer.layers.', 'quantizer.vq.layers.') - k = k.replace('.codebook.', '._codebook.') - new_state[k] = v - state = new_state - elif isinstance(model, models.EncodecModel): - state = model.state_dict() - else: - raise RuntimeError(f"Cannot fine tune model type {type(model)}.") - return { - 'best_state': {'model': state} - } - - @staticmethod - def model_from_checkpoint(checkpoint_path: tp.Union[Path, str], - device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel: - """Instantiate a CompressionModel from a given checkpoint path or dora sig. - This method is a convenient endpoint to load a CompressionModel to use in other solvers. - - Args: - checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved. - This also supports pre-trained models by using a path of the form //pretrained/NAME. 
- See `model_from_pretrained` for a list of supported pretrained models. - use_ema (bool): Use EMA variant of the model instead of the actual model. - device (torch.device or str): Device on which the model is loaded. - """ - checkpoint_path = str(checkpoint_path) - if checkpoint_path.startswith('//pretrained/'): - name = checkpoint_path.split('/', 3)[-1] - return models.CompressionModel.get_pretrained(name, device) - logger = logging.getLogger(__name__) - logger.info(f"Loading compression model from checkpoint: {checkpoint_path}") - _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False) - assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}" - state = checkpoint.load_checkpoint(_checkpoint_path) - assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}" - cfg = state['xp.cfg'] - cfg.device = device - compression_model = models.builders.get_compression_model(cfg).to(device) - assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match" - - assert 'best_state' in state and state['best_state'] != {} - assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix." - compression_model.load_state_dict(state['best_state']['model']) - compression_model.eval() - logger.info("Compression model loaded!") - return compression_model - - @staticmethod - def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig, - checkpoint_path: tp.Union[Path, str], - device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel: - """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig. - - Args: - cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode. - checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved. - use_ema (bool): Use EMA variant of the model instead of the actual model. - device (torch.device or str): Device on which the model is loaded. - """ - compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device) - compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg) - return compression_model - - -def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict: - """Audio reconstruction evaluation method that can be conveniently pickled.""" - metrics = {} - if cfg.evaluate.metrics.visqol: - visqol = builders.get_visqol(cfg.metrics.visqol) - metrics['visqol'] = visqol(y_pred, y, cfg.sample_rate) - sisnr = builders.get_loss('sisnr', cfg) - metrics['sisnr'] = sisnr(y_pred, y) - return metrics diff --git a/audiocraft/audiocraft/solvers/diffusion.py b/audiocraft/audiocraft/solvers/diffusion.py deleted file mode 100644 index 93dea2520836f458ab1b8514dca952b51d113ec2..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/solvers/diffusion.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import typing as tp - -import flashy -import julius -import omegaconf -import torch -import torch.nn.functional as F - -from . import builders -from . import base -from .. 
import models -from ..modules.diffusion_schedule import NoiseSchedule -from ..metrics import RelativeVolumeMel -from ..models.builders import get_processor -from ..utils.samples.manager import SampleManager -from ..solvers.compression import CompressionSolver - - -class PerStageMetrics: - """Handle prompting the metrics per stage. - It outputs the metrics per range of diffusion states. - e.g. avg loss when t in [250, 500] - """ - def __init__(self, num_steps: int, num_stages: int = 4): - self.num_steps = num_steps - self.num_stages = num_stages - - def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]): - if type(step) is int: - stage = int((step / self.num_steps) * self.num_stages) - return {f"{name}_{stage}": loss for name, loss in losses.items()} - elif type(step) is torch.Tensor: - stage_tensor = ((step / self.num_steps) * self.num_stages).long() - out: tp.Dict[str, float] = {} - for stage_idx in range(self.num_stages): - mask = (stage_tensor == stage_idx) - N = mask.sum() - stage_out = {} - if N > 0: # pass if no elements in the stage - for name, loss in losses.items(): - stage_loss = (mask * loss).sum() / N - stage_out[f"{name}_{stage_idx}"] = stage_loss - out = {**out, **stage_out} - return out - - -class DataProcess: - """Apply filtering or resampling. - - Args: - initial_sr (int): Initial sample rate. - target_sr (int): Target sample rate. - use_resampling: Whether to use resampling or not. - use_filter (bool): - n_bands (int): Number of bands to consider. - idx_band (int): - device (torch.device or str): - cutoffs (): - boost (bool): - """ - def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False, - use_filter: bool = False, n_bands: int = 4, - idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False): - """Apply filtering or resampling - Args: - initial_sr (int): sample rate of the dataset - target_sr (int): sample rate after resampling - use_resampling (bool): whether or not performs resampling - use_filter (bool): when True filter the data to keep only one frequency band - n_bands (int): Number of bands used - cuts (none or list): The cutoff frequencies of the band filtering - if None then we use mel scale bands. - idx_band (int): index of the frequency band. 0 are lows ... (n_bands - 1) highs - boost (bool): make the data scale match our music dataset. - """ - assert idx_band < n_bands - self.idx_band = idx_band - if use_filter: - if cutoffs is not None: - self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device) - else: - self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device) - self.use_filter = use_filter - self.use_resampling = use_resampling - self.target_sr = target_sr - self.initial_sr = initial_sr - self.boost = boost - - def process_data(self, x, metric=False): - if x is None: - return None - if self.boost: - x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4) - x * 0.22 - if self.use_filter and not metric: - x = self.filter(x)[self.idx_band] - if self.use_resampling: - x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr) - return x - - def inverse_process(self, x): - """Upsampling only.""" - if self.use_resampling: - x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.target_sr) - return x - - -class DiffusionSolver(base.StandardSolver): - """Solver for compression task. - - The diffusion task allows for MultiBand diffusion model training. - - Args: - cfg (DictConfig): Configuration. 
- """ - def __init__(self, cfg: omegaconf.DictConfig): - super().__init__(cfg) - self.cfg = cfg - self.device = cfg.device - self.sample_rate: int = self.cfg.sample_rate - self.codec_model = CompressionSolver.model_from_checkpoint( - cfg.compression_model_checkpoint, device=self.device) - - self.codec_model.set_num_codebooks(cfg.n_q) - assert self.codec_model.sample_rate == self.cfg.sample_rate, ( - f"Codec model sample rate is {self.codec_model.sample_rate} but " - f"Solver sample rate is {self.cfg.sample_rate}." - ) - assert self.codec_model.sample_rate == self.sample_rate, \ - f"Sample rate of solver {self.sample_rate} and codec {self.codec_model.sample_rate} " \ - "don't match." - - self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate) - self.register_stateful('sample_processor') - self.sample_processor.to(self.device) - - self.schedule = NoiseSchedule( - **cfg.schedule, device=self.device, sample_processor=self.sample_processor) - - self.eval_metric: tp.Optional[torch.nn.Module] = None - - self.rvm = RelativeVolumeMel() - self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr, - use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs, - use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands, - idx_band=cfg.filter.idx_band, device=self.device) - - @property - def best_metric_name(self) -> tp.Optional[str]: - if self._current_stage == "evaluate": - return 'rvm' - else: - return 'loss' - - @torch.no_grad() - def get_condition(self, wav: torch.Tensor) -> torch.Tensor: - codes, scale = self.codec_model.encode(wav) - assert scale is None, "Scaled compression models not supported." - emb = self.codec_model.decode_latent(codes) - return emb - - def build_model(self): - """Build model and optimizer as well as optional Exponential Moving Average of the model. 
- """ - # Model and optimizer - self.model = models.builders.get_diffusion_model(self.cfg).to(self.device) - self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim) - self.register_stateful('model', 'optimizer') - self.register_best_state('model') - self.register_ema('model') - - def build_dataloaders(self): - """Build audio dataloaders for each stage.""" - self.dataloaders = builders.get_audio_datasets(self.cfg) - - def show(self): - # TODO - raise NotImplementedError() - - def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): - """Perform one training or valid step on a given batch.""" - x = batch.to(self.device) - loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss - - condition = self.get_condition(x) # [bs, 128, T/hop, n_emb] - sample = self.data_processor.process_data(x) - - input_, target, step = self.schedule.get_training_item(sample, - tensor_step=self.cfg.schedule.variable_step_batch) - out = self.model(input_, step, condition=condition).sample - - base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2)) - reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2)) - loss = base_loss / reference_loss ** self.cfg.loss.norm_power - - if self.is_training: - loss.mean().backward() - flashy.distrib.sync_model(self.model) - self.optimizer.step() - self.optimizer.zero_grad() - metrics = { - 'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(), - } - metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step)) - metrics.update({ - 'std_in': input_.std(), 'std_out': out.std()}) - return metrics - - def run_epoch(self): - # reset random seed at the beginning of the epoch - self.rng = torch.Generator() - self.rng.manual_seed(1234 + self.epoch) - self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage) - # run epoch - super().run_epoch() - - def evaluate(self): - """Evaluate stage. - Runs audio reconstruction evaluation. - """ - self.model.eval() - evaluate_stage_name = f'{self.current_stage}' - loader = self.dataloaders['evaluate'] - updates = len(loader) - lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates) - - metrics = {} - n = 1 - for idx, batch in enumerate(lp): - x = batch.to(self.device) - with torch.no_grad(): - y_pred = self.regenerate(x) - - y_pred = y_pred.cpu() - y = batch.cpu() # should already be on CPU but just in case - rvm = self.rvm(y_pred, y) - lp.update(**rvm) - if len(metrics) == 0: - metrics = rvm - else: - for key in rvm.keys(): - metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1) - metrics = flashy.distrib.average_metrics(metrics) - return metrics - - @torch.no_grad() - def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None): - """Regenerate the given waveform.""" - condition = self.get_condition(wav) - initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav)) # sampling rate changes. 
- result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition, - step_list=step_list) - result = self.data_processor.inverse_process(result) - return result - - def generate(self): - """Generate stage.""" - sample_manager = SampleManager(self.xp) - self.model.eval() - generate_stage_name = f'{self.current_stage}' - - loader = self.dataloaders['generate'] - updates = len(loader) - lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates) - - for batch in lp: - reference, _ = batch - reference = reference.to(self.device) - estimate = self.regenerate(reference) - reference = reference.cpu() - estimate = estimate.cpu() - sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference) - flashy.distrib.barrier() diff --git a/audiocraft/audiocraft/solvers/musicgen.py b/audiocraft/audiocraft/solvers/musicgen.py deleted file mode 100644 index ab2167b7958023274b04deedecc1b0b694dc83c7..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/solvers/musicgen.py +++ /dev/null @@ -1,721 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from pathlib import Path -import time -import typing as tp - -import flashy -import math -import omegaconf -import torch -from torch.nn import functional as F - -from . import base, builders -from .compression import CompressionSolver -from .. import metrics as eval_metrics -from .. import models -from ..data.audio_dataset import AudioDataset -from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo -from ..data.audio_utils import normalize_audio -from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition -from ..utils.cache import CachedBatchWriter, CachedBatchLoader -from ..utils.samples.manager import SampleManager -from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once - - -class MusicGenSolver(base.StandardSolver): - """Solver for MusicGen training task. 
- - Used in: https://arxiv.org/abs/2306.05284 - """ - DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC - - def __init__(self, cfg: omegaconf.DictConfig): - super().__init__(cfg) - # easier access to sampling parameters - self.generation_params = { - 'use_sampling': self.cfg.generate.lm.use_sampling, - 'temp': self.cfg.generate.lm.temp, - 'top_k': self.cfg.generate.lm.top_k, - 'top_p': self.cfg.generate.lm.top_p, - } - self._best_metric_name: tp.Optional[str] = 'ce' - - self._cached_batch_writer = None - self._cached_batch_loader = None - if cfg.cache.path: - if cfg.cache.write: - self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path)) - if self.cfg.cache.write_num_shards: - self.logger.warning("Multiple shard cache, best_metric_name will be set to None.") - self._best_metric_name = None - else: - self._cached_batch_loader = CachedBatchLoader( - Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers, - min_length=self.cfg.optim.updates_per_epoch or 1) - self.dataloaders['original_train'] = self.dataloaders['train'] - self.dataloaders['train'] = self._cached_batch_loader # type: ignore - - @staticmethod - def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None, - device: tp.Optional[str] = None, autocast: bool = True, - batch_size: tp.Optional[int] = None, - override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None, - **kwargs): - """Mostly a convenience function around magma.train.get_solver_from_sig, - populating all the proper param, deactivating EMA, FSDP, loading the best state, - basically all you need to get a solver ready to "play" with in single GPU mode - and with minimal memory overhead. - - Args: - sig (str): signature to load. - dtype (str or None): potential dtype, as a string, i.e. 'float16'. - device (str or None): potential device, as a string, i.e. 'cuda'. - override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'. - """ - from audiocraft import train - our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}} - our_override_cfg['autocast'] = autocast - if dtype is not None: - our_override_cfg['dtype'] = dtype - if device is not None: - our_override_cfg['device'] = device - if batch_size is not None: - our_override_cfg['dataset'] = {'batch_size': batch_size} - if override_cfg is None: - override_cfg = {} - override_cfg = omegaconf.OmegaConf.merge( - omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) # type: ignore - solver = train.get_solver_from_sig( - sig, override_cfg=override_cfg, - load_best=True, disable_fsdp=True, - ignore_state_keys=['optimizer', 'ema'], **kwargs) - solver.model.eval() - return solver - - def get_formatter(self, stage_name: str) -> flashy.Formatter: - return flashy.Formatter({ - 'lr': '.2E', - 'ce': '.3f', - 'ppl': '.3f', - 'grad_norm': '.3E', - }, exclude_keys=['ce_q*', 'ppl_q*']) - - @property - def best_metric_name(self) -> tp.Optional[str]: - return self._best_metric_name - - def build_model(self) -> None: - """Instantiate models and optimizer.""" - # we can potentially not use all quantizers with which the EnCodec model was trained - # (e.g. 
we trained the model with quantizers dropout) - self.compression_model = CompressionSolver.wrapped_model_from_checkpoint( - self.cfg, self.cfg.compression_model_checkpoint, device=self.device) - assert self.compression_model.sample_rate == self.cfg.sample_rate, ( - f"Compression model sample rate is {self.compression_model.sample_rate} but " - f"Solver sample rate is {self.cfg.sample_rate}." - ) - # ensure we have matching configuration between LM and compression model - assert self.cfg.transformer_lm.card == self.compression_model.cardinality, ( - "Cardinalities of the LM and compression model don't match: ", - f"LM cardinality is {self.cfg.transformer_lm.card} vs ", - f"compression model cardinality is {self.compression_model.cardinality}" - ) - #assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, ( - # "Numbers of codebooks of the LM and compression models don't match: ", - # f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ", - # f"compression model numer of codebooks is {self.compression_model.num_codebooks}" - #) - self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d", - self.compression_model.num_codebooks, self.compression_model.cardinality, - self.compression_model.frame_rate) - # instantiate LM model - self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device) - if self.cfg.fsdp.use: - assert not self.cfg.autocast, "Cannot use autocast with fsdp" - self.model = self.wrap_with_fsdp(self.model) - - # freeze some weight - for name, param in self.model.named_parameters(): - param.requires_grad = False - - layer_idxs = [idx for idx in range(0, 48, 4)] # jump freeze - for name, param in self.model.named_parameters(): - for idx in layer_idxs: - if name.startswith(f"transformer.layers.{idx}."): - param.requires_grad = True - if name.startswith("out_norm") or name.startswith("linears"): - param.requires_grad = True - if name.startswith("condition_provider.conditioners.chord") or name.startswith("condition_provider.conditioners.beat"): - param.requires_grad = True - # if name.startswith("condition_provider.conditioners.beat"): - # param.requires_grad = True - # if name.startswith("emb"): - # param.requires_grad = True - - self.register_ema('model') - # initialize optimization - self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim) - self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates) - self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler') - self.register_best_state('model') - self.autocast_dtype = { - 'float16': torch.float16, 'bfloat16': torch.bfloat16 - }[self.cfg.autocast_dtype] - self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None - if self.cfg.fsdp.use: - need_scaler = self.cfg.fsdp.param_dtype == 'float16' - else: - need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16 - if need_scaler: - if self.cfg.fsdp.use: - from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler - self.scaler = ShardedGradScaler() # type: ignore - else: - self.scaler = torch.cuda.amp.GradScaler() - self.register_stateful('scaler') - - def build_dataloaders(self) -> None: - """Instantiate audio dataloaders for each stage.""" - self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE) - - def show(self) -> None: - """Show the compression model and LM model.""" - self.logger.info("Compression model:") - 
self.log_model_summary(self.compression_model) - self.logger.info("LM model:") - self.log_model_summary(self.model) - - def load_state_dict(self, state: dict) -> None: - if 'condition_provider' in state: - model_state = state['model'] - condition_provider_state = state.pop('condition_provider') - prefix = 'condition_provider.' - for key, value in condition_provider_state.items(): - key = prefix + key - assert key not in model_state - model_state[key] = value - super().load_state_dict(state) - - def load_from_pretrained(self, name: str): - # TODO: support native HF versions of MusicGen. - lm_pkg = models.loaders.load_lm_model_ckpt(name) - state: dict = { - 'best_state': { - 'model': lm_pkg['best_state'], - }, - } - return state - - def _compute_cross_entropy( - self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor - ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: - """Compute cross entropy between multi-codebook targets and model's logits. - The cross entropy is computed per codebook to provide codebook-level cross entropy. - Valid timesteps for each of the codebook are pulled from the mask, where invalid - timesteps are set to 0. - - Args: - logits (torch.Tensor): Model's logits of shape [B, K, T, card]. - targets (torch.Tensor): Target codes, of shape [B, K, T]. - mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. - Returns: - ce (torch.Tensor): Cross entropy averaged over the codebooks - ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). - """ - B, K, T = targets.shape - assert logits.shape[:-1] == targets.shape - assert mask.shape == targets.shape - ce = torch.zeros([], device=targets.device) - ce_per_codebook: tp.List[torch.Tensor] = [] - for k in range(K): - logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] - targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] - mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] - ce_targets = targets_k[mask_k] - ce_logits = logits_k[mask_k] - q_ce = F.cross_entropy(ce_logits, ce_targets) - ce += q_ce - ce_per_codebook.append(q_ce.detach()) - # average cross entropy across codebooks - ce = ce / K - return ce, ce_per_codebook - - @torch.no_grad() - def _prepare_tokens_and_attributes( - self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], - check_synchronization_points: bool = False - ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]: - """Prepare input batchs for language model training. - - Args: - batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T] - and corresponding metadata as SegmentWithAttributes (with B items). - check_synchronization_points (bool): Whether to check for synchronization points slowing down training. - Returns: - Condition tensors (dict[str, any]): Preprocessed condition attributes. - Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s], - with B the batch size, K the number of codebooks, T_s the token timesteps. - Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s]. 
- """ - if self._cached_batch_loader is None or self.current_stage != "train": - audio, infos = batch - audio = audio.to(self.device) - audio_tokens = None - assert audio.size(0) == len(infos), ( - f"Mismatch between number of items in audio batch ({audio.size(0)})", - f" and in metadata ({len(infos)})" - ) - else: - audio = None - # In that case the batch will be a tuple coming from the _cached_batch_writer bit below. - infos, = batch # type: ignore - assert all([isinstance(info, AudioInfo) for info in infos]) - assert all([info.audio_tokens is not None for info in infos]) # type: ignore - audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device) # type: ignore - audio_tokens = audio_tokens.long() - for info in infos: - if isinstance(info, MusicInfo): - # Careful here, if you want to use this condition_wav (e.b. chroma conditioning), - # then you must be using the chroma cache! otherwise the code will try - # to use this segment and fail (by that I mean you will see NaN everywhere). - info.self_wav = WavCondition( - torch.full([1, info.channels, info.total_frames], float('NaN')), - length=torch.tensor([info.n_frames]), - sample_rate=[info.sample_rate], - path=[info.meta.path], - seek_time=[info.seek_time]) - dataset = get_dataset_from_loader(self.dataloaders['original_train']) - assert isinstance(dataset, MusicDataset), type(dataset) - if dataset.paraphraser is not None and info.description is not None: - # Hackingly reapplying paraphraser when using cache. - info.description = dataset.paraphraser.sample_paraphrase( - info.meta.path, info.description) - # prepare attributes - attributes = [info.to_condition_attributes() for info in infos] - attributes = self.model.cfg_dropout(attributes) - attributes = self.model.att_dropout(attributes) - tokenized = self.model.condition_provider.tokenize(attributes) - - # Now we should be synchronization free. - if self.device == "cuda" and check_synchronization_points: - torch.cuda.set_sync_debug_mode("warn") - - if audio_tokens is None: - with torch.no_grad(): - audio_tokens, scale = self.compression_model.encode(audio) - assert scale is None, "Scaled compression model not supported with LM." - - with self.autocast: - condition_tensors = self.model.condition_provider(tokenized) - - # create a padding mask to hold valid vs invalid positions - padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device) - # replace encodec tokens from padded audio with special_token_id - if self.cfg.tokens.padding_with_special_token: - audio_tokens = audio_tokens.clone() - padding_mask = padding_mask.clone() - token_sample_rate = self.compression_model.frame_rate - B, K, T_s = audio_tokens.shape - for i in range(B): - n_samples = infos[i].n_frames - audio_sample_rate = infos[i].sample_rate - # take the last token generated from actual audio frames (non-padded audio) - valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate) - audio_tokens[i, :, valid_tokens:] = self.model.special_token_id - padding_mask[i, :, valid_tokens:] = 0 - - if self.device == "cuda" and check_synchronization_points: - torch.cuda.set_sync_debug_mode("default") - - if self._cached_batch_writer is not None and self.current_stage == 'train': - assert self._cached_batch_loader is None - assert audio_tokens is not None - for info, one_audio_tokens in zip(infos, audio_tokens): - assert isinstance(info, AudioInfo) - if isinstance(info, MusicInfo): - assert not info.joint_embed, "joint_embed and cache not supported yet." 
- info.self_wav = None - assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item() - info.audio_tokens = one_audio_tokens.short().cpu() - self._cached_batch_writer.save(infos) - - return condition_tensors, audio_tokens, padding_mask - - def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict: - """Perform one training or valid step on a given batch.""" - check_synchronization_points = idx == 1 and self.device == 'cuda' - - condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes( - batch, check_synchronization_points) - - self.deadlock_detect.update('tokens_and_conditions') - - if check_synchronization_points: - torch.cuda.set_sync_debug_mode('warn') - - with self.autocast: - model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors) # type: ignore - logits = model_output.logits - mask = padding_mask & model_output.mask - ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) - loss = ce - self.deadlock_detect.update('loss') - - if check_synchronization_points: - torch.cuda.set_sync_debug_mode('default') - - if self.is_training: - metrics['lr'] = self.optimizer.param_groups[0]['lr'] - if self.scaler is not None: - loss = self.scaler.scale(loss) - self.deadlock_detect.update('scale') - # apply grad accum - loss = loss / self.cfg.optim.grad_accum_steps - if self.cfg.fsdp.use: - loss.backward() - flashy.distrib.average_tensors(self.model.buffers()) - elif self.cfg.optim.eager_sync: - with flashy.distrib.eager_sync_model(self.model): - loss.backward() - else: - # this should always be slower but can be useful - # for weird use cases like multiple backwards. - loss.backward() - flashy.distrib.sync_model(self.model) - self.deadlock_detect.update('backward') - - if idx % self.cfg.optim.grad_accum_steps == 0: - if self.scaler is not None: - self.scaler.unscale_(self.optimizer) - if self.cfg.optim.max_norm: - if self.cfg.fsdp.use: - metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm) # type: ignore - else: - metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.cfg.optim.max_norm - ) - if self.scaler is None: - self.optimizer.step() - else: - self.scaler.step(self.optimizer) - self.scaler.update() - if self.lr_scheduler: - self.lr_scheduler.step() - self.optimizer.zero_grad() - self.deadlock_detect.update('optim') - if self.scaler is not None: - scale = self.scaler.get_scale() - metrics['grad_scale'] = scale - if not loss.isfinite().all(): - raise RuntimeError("Model probably diverged.") - - metrics['ce'] = ce - metrics['ppl'] = torch.exp(ce) - for k, ce_q in enumerate(ce_per_codebook): - metrics[f'ce_q{k + 1}'] = ce_q - metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q) - - return metrics - - @torch.no_grad() - def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], - gen_duration: float, prompt_duration: tp.Optional[float] = None, - remove_prompt: bool = False, - **generation_params) -> dict: - """Run generate step on a batch of optional audio tensor and corresponding attributes. - - Args: - batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): - use_prompt (bool): Whether to do audio continuation generation with prompt from audio batch. - gen_duration (float): Target audio duration for the generation. - prompt_duration (float, optional): Duration for the audio prompt to use for continuation. 
- remove_prompt (bool, optional): Whether to remove the prompt from the generated audio. - generation_params: Additional generation parameters. - Returns: - gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation - and the prompt along with additional information. - """ - bench_start = time.time() - audio, meta = batch - assert audio.size(0) == len(meta), ( - f"Mismatch between number of items in audio batch ({audio.size(0)})", - f" and in metadata ({len(meta)})" - ) - # prepare attributes - attributes = [x.to_condition_attributes() for x in meta] - # TODO: Add dropout for chroma? - - # prepare audio prompt - if prompt_duration is None: - prompt_audio = None - else: - assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration" - prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate) - prompt_audio = audio[..., :prompt_audio_frames] - - # get audio tokens from compression model - if prompt_audio is None or prompt_audio.nelement() == 0: - num_samples = len(attributes) - prompt_tokens = None - else: - num_samples = None - prompt_audio = prompt_audio.to(self.device) - prompt_tokens, scale = self.compression_model.encode(prompt_audio) - assert scale is None, "Compression model in MusicGen should not require rescaling." - - # generate by sampling from the LM - with self.autocast: - total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate) - gen_tokens = self.model.generate( - prompt_tokens, attributes, max_gen_len=total_gen_len, - num_samples=num_samples, **self.generation_params) - - # generate audio from tokens - assert gen_tokens.dim() == 3 - gen_audio = self.compression_model.decode(gen_tokens, None) - - bench_end = time.time() - gen_outputs = { - 'rtf': (bench_end - bench_start) / gen_duration, - 'ref_audio': audio, - 'gen_audio': gen_audio, - 'gen_tokens': gen_tokens, - 'prompt_audio': prompt_audio, - 'prompt_tokens': prompt_tokens, - } - return gen_outputs - - def generate_audio(self) -> dict: - """Audio generation stage.""" - generate_stage_name = f'{self.current_stage}' - sample_manager = SampleManager(self.xp) - self.logger.info(f"Generating samples in {sample_manager.base_folder}") - loader = self.dataloaders['generate'] - updates = len(loader) - lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates) - - dataset = get_dataset_from_loader(loader) - dataset_duration = dataset.segment_duration - assert dataset_duration is not None - assert isinstance(dataset, AudioDataset) - target_duration = self.cfg.generate.lm.gen_duration - prompt_duration = self.cfg.generate.lm.prompt_duration - if target_duration is None: - target_duration = dataset_duration - if prompt_duration is None: - prompt_duration = dataset_duration / 4 - assert prompt_duration < dataset_duration, ( - f"Specified prompt duration ({prompt_duration}s) is longer", - f" than reference audio duration ({dataset_duration}s)" - ) - - def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]): - hydrated_conditions = [] - for sample in [x.to_condition_attributes() for x in meta]: - cond_dict = {} - for cond_type in sample.__annotations__.keys(): - for cond_key, cond_val in getattr(sample, cond_type).items(): - if cond_key not in self.model.condition_provider.conditioners.keys(): - continue - if is_jsonable(cond_val): - cond_dict[cond_key] = cond_val - elif isinstance(cond_val, WavCondition): - cond_dict[cond_key] = cond_val.path - elif isinstance(cond_val, 
JointEmbedCondition): - cond_dict[cond_key] = cond_val.text # only support text at inference for now - else: - # if we reached this point, it is not clear how to log the condition - # so we just log the type. - cond_dict[cond_key] = str(type(cond_val)) - continue - hydrated_conditions.append(cond_dict) - return hydrated_conditions - - metrics: dict = {} - average = flashy.averager() - for batch in lp: - audio, meta = batch - # metadata for sample manager - hydrated_conditions = get_hydrated_conditions(meta) - sample_generation_params = { - **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()}, - **self.generation_params - } - if self.cfg.generate.lm.unprompted_samples: - if self.cfg.generate.lm.gen_gt_samples: - # get the ground truth instead of generation - self.logger.warn( - "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true") - gen_unprompted_audio = audio - rtf = 1. - else: - gen_unprompted_outputs = self.run_generate_step( - batch, gen_duration=target_duration, prompt_duration=prompt_duration, - **self.generation_params) - gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu() - rtf = gen_unprompted_outputs['rtf'] - sample_manager.add_samples( - gen_unprompted_audio, self.epoch, hydrated_conditions, - ground_truth_wavs=audio, generation_args=sample_generation_params) - - if self.cfg.generate.lm.prompted_samples: - gen_outputs = self.run_generate_step( - batch, gen_duration=target_duration, prompt_duration=prompt_duration, - **self.generation_params) - gen_audio = gen_outputs['gen_audio'].cpu() - prompt_audio = gen_outputs['prompt_audio'].cpu() - sample_manager.add_samples( - gen_audio, self.epoch, hydrated_conditions, - prompt_wavs=prompt_audio, ground_truth_wavs=audio, - generation_args=sample_generation_params) - - metrics['rtf'] = rtf - metrics = average(metrics) - - flashy.distrib.barrier() - return metrics - - def generate(self) -> dict: - """Generate stage.""" - self.model.eval() - with torch.no_grad(): - return self.generate_audio() - - def run_epoch(self): - if self.cfg.cache.write: - if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard: - return - super().run_epoch() - - def train(self): - """Train stage. 
- """ - if self._cached_batch_writer is not None: - self._cached_batch_writer.start_epoch(self.epoch) - if self._cached_batch_loader is None: - dataset = get_dataset_from_loader(self.dataloaders['train']) - assert isinstance(dataset, AudioDataset) - dataset.current_epoch = self.epoch - else: - self._cached_batch_loader.start_epoch(self.epoch) - return super().train() - - def evaluate_audio_generation(self) -> dict: - """Evaluate audio generation with off-the-shelf metrics.""" - evaluate_stage_name = f'{self.current_stage}_generation' - # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation - fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None - kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None - text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None - chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None - should_run_eval = False - eval_chroma_wavs: tp.Optional[torch.Tensor] = None - if self.cfg.evaluate.metrics.fad: - fad = builders.get_fad(self.cfg.metrics.fad).to(self.device) - should_run_eval = True - if self.cfg.evaluate.metrics.kld: - kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device) - should_run_eval = True - if self.cfg.evaluate.metrics.text_consistency: - text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device) - should_run_eval = True - if self.cfg.evaluate.metrics.chroma_cosine: - chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device) - # if we have predefind wavs for chroma we should purge them for computing the cosine metric - has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \ - self.model.condition_provider.conditioners['self_wav'].has_eval_wavs() - if has_predefined_eval_chromas: - warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! 
" - 'Resetting eval chromas to None for evaluation.') - eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs # type: ignore - self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None) # type: ignore - should_run_eval = True - - def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor: - audio_tokens, scale = self.compression_model.encode(audio.to(self.device)) - compressed_audio = self.compression_model.decode(audio_tokens, scale) - return compressed_audio[..., :audio.shape[-1]] - - metrics: dict = {} - if should_run_eval: - loader = self.dataloaders['evaluate'] - updates = len(loader) - lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates) - average = flashy.averager() - dataset = get_dataset_from_loader(loader) - assert isinstance(dataset, AudioDataset) - self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples") - - for idx, batch in enumerate(lp): - audio, meta = batch - assert all([self.cfg.sample_rate == m.sample_rate for m in meta]) - - target_duration = audio.shape[-1] / self.cfg.sample_rate - if self.cfg.evaluate.fixed_generation_duration: - target_duration = self.cfg.evaluate.fixed_generation_duration - - gen_outputs = self.run_generate_step( - batch, gen_duration=target_duration, - **self.generation_params - ) - y_pred = gen_outputs['gen_audio'].detach() - y_pred = y_pred[..., :audio.shape[-1]] - - normalize_kwargs = dict(self.cfg.generate.audio) - normalize_kwargs.pop('format', None) - y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu() - y = audio.cpu() # should already be on CPU but just in case - sizes = torch.tensor([m.n_frames for m in meta]) # actual sizes without padding - sample_rates = torch.tensor([m.sample_rate for m in meta]) # sample rates for audio samples - audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta] - - if fad is not None: - if self.cfg.metrics.fad.use_gt: - y_pred = get_compressed_audio(y).cpu() - fad.update(y_pred, y, sizes, sample_rates, audio_stems) - if kldiv is not None: - if self.cfg.metrics.kld.use_gt: - y_pred = get_compressed_audio(y).cpu() - kldiv.update(y_pred, y, sizes, sample_rates) - if text_consistency is not None: - texts = [m.description for m in meta] - if self.cfg.metrics.text_consistency.use_gt: - y_pred = y - text_consistency.update(y_pred, texts, sizes, sample_rates) - if chroma_cosine is not None: - if self.cfg.metrics.chroma_cosine.use_gt: - y_pred = get_compressed_audio(y).cpu() - chroma_cosine.update(y_pred, y, sizes, sample_rates) - # restore chroma conditioner's eval chroma wavs - if eval_chroma_wavs is not None: - self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs) - - flashy.distrib.barrier() - if fad is not None: - metrics['fad'] = fad.compute() - if kldiv is not None: - kld_metrics = kldiv.compute() - metrics.update(kld_metrics) - if text_consistency is not None: - metrics['text_consistency'] = text_consistency.compute() - if chroma_cosine is not None: - metrics['chroma_cosine'] = chroma_cosine.compute() - metrics = average(metrics) - metrics = flashy.distrib.average_metrics(metrics, len(loader)) - - return metrics - - def evaluate(self) -> dict: - """Evaluate stage.""" - self.model.eval() - with torch.no_grad(): - metrics: dict = {} - if self.cfg.evaluate.metrics.base: - metrics.update(self.common_train_valid('evaluate')) - gen_metrics = self.evaluate_audio_generation() - return {**metrics, **gen_metrics} 
diff --git a/audiocraft/audiocraft/train.py b/audiocraft/audiocraft/train.py deleted file mode 100644 index 22dd117830bb403829d0a60b1b95e120d1e6978b..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/train.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Entry point for dora to launch solvers for running training loops. -See more info on how to use dora: https://github.com/facebookresearch/dora -""" - -import logging -import multiprocessing -import os -import sys -import typing as tp - -from dora import git_save, hydra_main, XP -import flashy -import hydra -import omegaconf - -from .environment import AudioCraftEnvironment -from .utils.cluster import get_slurm_parameters - -logger = logging.getLogger(__name__) - - -def resolve_config_dset_paths(cfg): - """Enable Dora to load manifest from git clone repository.""" - # manifest files for the different splits - for key, value in cfg.datasource.items(): - if isinstance(value, str): - cfg.datasource[key] = git_save.to_absolute_path(value) - - -def get_solver(cfg): - from . import solvers - # Convert batch size to batch size for each GPU - assert cfg.dataset.batch_size % flashy.distrib.world_size() == 0 - cfg.dataset.batch_size //= flashy.distrib.world_size() - for split in ['train', 'valid', 'evaluate', 'generate']: - if hasattr(cfg.dataset, split) and hasattr(cfg.dataset[split], 'batch_size'): - assert cfg.dataset[split].batch_size % flashy.distrib.world_size() == 0 - cfg.dataset[split].batch_size //= flashy.distrib.world_size() - resolve_config_dset_paths(cfg) - solver = solvers.get_solver(cfg) - return solver - - -def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None, - restore: bool = True, load_best: bool = True, - ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True): - """Given a XP, return the Solver object. - - Args: - xp (XP): Dora experiment for which to retrieve the solver. - override_cfg (dict or None): If not None, should be a dict used to - override some values in the config of `xp`. This will not impact - the XP signature or folder. The format is different - than the one used in Dora grids, nested keys should actually be nested dicts, - not flattened, e.g. `{'optim': {'batch_size': 32}}`. - restore (bool): If `True` (the default), restore state from the last checkpoint. - load_best (bool): If `True` (the default), load the best state from the checkpoint. - ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`. - disable_fsdp (bool): if True, disables FSDP entirely. This will - also automatically skip loading the EMA. For solver specific - state sources, like the optimizer, you might want to - use along `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`. - """ - logger.info(f"Loading solver from XP {xp.sig}. " - f"Overrides used: {xp.argv}") - cfg = xp.cfg - if override_cfg is not None: - cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg)) - if disable_fsdp and cfg.fsdp.use: - cfg.fsdp.use = False - assert load_best is True - # ignoring some keys that were FSDP sharded like model, ema, and best_state. - # fsdp_best_state will be used in that case. When using a specific solver, - # one is responsible for adding the relevant keys, e.g. 'optimizer'. 
- # We could make something to automatically register those inside the solver, but that - # seem overkill at this point. - ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state'] - - try: - with xp.enter(): - solver = get_solver(cfg) - if restore: - solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys) - return solver - finally: - hydra.core.global_hydra.GlobalHydra.instance().clear() - - -def get_solver_from_sig(sig: str, *args, **kwargs): - """Return Solver object from Dora signature, i.e. to play with it from a notebook. - See `get_solver_from_xp` for more information. - """ - xp = main.get_xp_from_sig(sig) - return get_solver_from_xp(xp, *args, **kwargs) - - -def init_seed_and_system(cfg): - import numpy as np - import torch - import random - from audiocraft.modules.transformer import set_efficient_attention_backend - - multiprocessing.set_start_method(cfg.mp_start_method) - logger.debug('Setting mp start method to %s', cfg.mp_start_method) - random.seed(cfg.seed) - np.random.seed(cfg.seed) - # torch also initialize cuda seed if available - torch.manual_seed(cfg.seed) - torch.set_num_threads(cfg.num_threads) - os.environ['MKL_NUM_THREADS'] = str(cfg.num_threads) - os.environ['OMP_NUM_THREADS'] = str(cfg.num_threads) - logger.debug('Setting num threads to %d', cfg.num_threads) - set_efficient_attention_backend(cfg.efficient_attention_backend) - logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend) - - -@hydra_main(config_path='../config', config_name='config', version_base='1.1') -def main(cfg): - init_seed_and_system(cfg) - - # Setup logging both to XP specific folder, and to stderr. - log_name = '%s.log.{rank}' % cfg.execute_only if cfg.execute_only else 'solver.log.{rank}' - flashy.setup_logging(level=str(cfg.logging.level).upper(), log_name=log_name) - # Initialize distributed training, no need to specify anything when using Dora. - flashy.distrib.init() - solver = get_solver(cfg) - if cfg.show: - solver.show() - return - - if cfg.execute_only: - assert cfg.execute_inplace or cfg.continue_from is not None, \ - "Please explicitly specify the checkpoint to continue from with continue_from= " + \ - "when running with execute_only or set execute_inplace to True." - solver.restore(replay_metrics=False) # load checkpoint - solver.run_one_stage(cfg.execute_only) - return - - return solver.run() - - -main.dora.dir = AudioCraftEnvironment.get_dora_dir() -main._base_cfg.slurm = get_slurm_parameters(main._base_cfg.slurm) - -if main.dora.shared is not None and not os.access(main.dora.shared, os.R_OK): - print("No read permission on dora.shared folder, ignoring it.", file=sys.stderr) - main.dora.shared = None - -if __name__ == '__main__': - main() diff --git a/audiocraft/audiocraft/utils/__init__.py b/audiocraft/audiocraft/utils/__init__.py deleted file mode 100644 index 75e25a0212f98e4a18d97c86c6cda225636a3215..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-"""Utilities.""" diff --git a/audiocraft/audiocraft/utils/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/utils/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 093b5eb8070af48aea57c7c726a9f5d5f8262e50..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/utils/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/__pycache__/autocast.cpython-311.pyc b/audiocraft/audiocraft/utils/__pycache__/autocast.cpython-311.pyc deleted file mode 100644 index 605b9b3ab2738226b1464ce76bb3bd80a9abc568..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/utils/__pycache__/autocast.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/__pycache__/best_state.cpython-311.pyc b/audiocraft/audiocraft/utils/__pycache__/best_state.cpython-311.pyc deleted file mode 100644 index 1c4626d6833196e4447cc68dea2a8b1c0b2efe20..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/utils/__pycache__/best_state.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/__pycache__/cache.cpython-311.pyc b/audiocraft/audiocraft/utils/__pycache__/cache.cpython-311.pyc deleted file mode 100644 index fc2a3836af340d27cc29eae79176b9283a78b5d4..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/utils/__pycache__/cache.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/__pycache__/checkpoint.cpython-311.pyc b/audiocraft/audiocraft/utils/__pycache__/checkpoint.cpython-311.pyc deleted file mode 100644 index 10d9f0186035f7ecc47864cc57b527c821c50be9..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/utils/__pycache__/checkpoint.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/__pycache__/cluster.cpython-311.pyc b/audiocraft/audiocraft/utils/__pycache__/cluster.cpython-311.pyc deleted file mode 100644 index 3b5fe8e616f43de5539822172f5478582eb8c5e3..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/utils/__pycache__/cluster.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/__pycache__/deadlock.cpython-311.pyc b/audiocraft/audiocraft/utils/__pycache__/deadlock.cpython-311.pyc deleted file mode 100644 index dbc86c87ba5d39c11c0d081c5e65d11df09dde78..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/utils/__pycache__/deadlock.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/__pycache__/export.cpython-311.pyc b/audiocraft/audiocraft/utils/__pycache__/export.cpython-311.pyc deleted file mode 100644 index e5d39b027b8c230a6280515eab89f67cd63b1b85..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/utils/__pycache__/export.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/__pycache__/profiler.cpython-311.pyc b/audiocraft/audiocraft/utils/__pycache__/profiler.cpython-311.pyc deleted file mode 100644 index 0829d206ed84c57aef536602d3e8550b56125c0d..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/utils/__pycache__/profiler.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/__pycache__/utils.cpython-311.pyc b/audiocraft/audiocraft/utils/__pycache__/utils.cpython-311.pyc deleted file mode 100644 index 8f782fb1fd7b3fa4801e2bceb48ea863c160eb85..0000000000000000000000000000000000000000 Binary files 
a/audiocraft/audiocraft/utils/__pycache__/utils.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/autocast.py b/audiocraft/audiocraft/utils/autocast.py deleted file mode 100644 index ed644843bb37cf8a92a20fbd51d6cebaa43b9a08..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/autocast.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch - - -class TorchAutocast: - """TorchAutocast utility class. - Allows you to enable and disable autocast. This is specially useful - when dealing with different architectures and clusters with different - levels of support. - - Args: - enabled (bool): Whether to enable torch.autocast or not. - args: Additional args for torch.autocast. - kwargs: Additional kwargs for torch.autocast - """ - def __init__(self, enabled: bool, *args, **kwargs): - self.autocast = torch.autocast(*args, **kwargs) if enabled else None - - def __enter__(self): - if self.autocast is None: - return - try: - self.autocast.__enter__() - except RuntimeError: - device = self.autocast.device - dtype = self.autocast.fast_dtype - raise RuntimeError( - f"There was an error autocasting with dtype={dtype} device={device}\n" - "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16" - ) - - def __exit__(self, *args, **kwargs): - if self.autocast is None: - return - self.autocast.__exit__(*args, **kwargs) diff --git a/audiocraft/audiocraft/utils/best_state.py b/audiocraft/audiocraft/utils/best_state.py deleted file mode 100644 index f5ad551432ad5cb0f83278b5d2100f9aa287958b..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/best_state.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from collections import defaultdict -import logging -import typing as tp - -import flashy -import torch - -from ..optim import ModuleDictEMA -from .utils import copy_state - - -logger = logging.getLogger(__name__) - - -class BestStateDictManager(flashy.state.StateDictSource): - """BestStateDictManager maintains a copy of best state_dict() for registered sources. - - BestStateDictManager has two main attributes: - states (dict): State dict of the registered StateDictSource. - param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources. - - When registering new sources, the BestStateDictManager will ensure two conflicting sources between - ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about - what to consider for best state. - - Args: - device (torch.device or str): Device on which we keep the copy. - dtype (torch.dtype): Data type for the state parameters. 
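-
-    A minimal usage sketch (assuming a solver that tracks a single module named `model`):
-        best_state = BestStateDictManager(device='cpu')
-        best_state.register('model', model)
-        # ... later, whenever the tracked metric improves:
-        best_state.update('model', model)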
- """ - def __init__(self, device: tp.Union[torch.device, str] = 'cpu', - dtype: tp.Optional[torch.dtype] = None): - self.device = device - self.states: dict = {} - self.param_ids: dict = defaultdict(dict) - self.dtype = dtype - - def _get_parameter_ids(self, state_dict): - return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)} - - def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict): - for registered_name, registered_param_ids in self.param_ids.items(): - if registered_name != name: - overlap = set.intersection(registered_param_ids.keys(), param_ids.keys()) - assert len(overlap) == 0, f"Found {len(overlap)} / {len(param_ids.keys())} overlapping parameters" - f" in {name} and already registered {registered_name}: {' '.join(overlap)}" - - def update(self, name: str, source: flashy.state.StateDictSource): - if name not in self.states: - raise ValueError(f"{name} missing from registered states.") - self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype) - - def register(self, name: str, source: flashy.state.StateDictSource): - if name in self.states: - raise ValueError(f"{name} already present in states.") - # Registering parameter ids for EMA and non-EMA states allows us to check that - # there is no overlap that would create ambiguity about how to handle the best state - param_ids = self._get_parameter_ids(source.state_dict()) - if isinstance(source, ModuleDictEMA): - logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params") - self._validate_no_parameter_ids_overlap(name, param_ids) - self.param_ids[name] = param_ids - else: - logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params") - self._validate_no_parameter_ids_overlap('base', param_ids) - self.param_ids['base'].update(param_ids) - # Register state - self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype) - - def state_dict(self) -> flashy.state.StateDict: - return self.states - - def load_state_dict(self, state: flashy.state.StateDict): - for name, sub_state in state.items(): - for k, v in sub_state.items(): - self.states[name][k].copy_(v) diff --git a/audiocraft/audiocraft/utils/cache.py b/audiocraft/audiocraft/utils/cache.py deleted file mode 100644 index f7f82064e8f43b86af1071cab4d967cca9b5bd86..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/cache.py +++ /dev/null @@ -1,323 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from concurrent.futures import ThreadPoolExecutor -from collections import deque -from functools import partial -from hashlib import sha1 -import logging -from pathlib import Path -import sys -import typing as tp -import zipfile - -import flashy -import torch - - -logger = logging.getLogger(__name__) - - -def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor: - """Utility function for the EmbeddingCache, returning the full embedding without any chunking. - This method can be used in case there is no need in extracting a chunk of the full embedding - read from the cache. - - Args: - full_embed (torch.Tensor): The full embedding. - x (any): Batch object from which the full embedding is derived. - idx (torch.Tensor): Index of object to consider in the batch object. 
- Returns: - full_embed (torch.Tensor): The full embedding - """ - return full_embed.to(device) - - -class EmbeddingCache: - """Cache around embeddings computation for faster execution. - The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API - to retrieve the pre-computed embeddings on full inputs and extract only a given chunk - using a user-provided function. When the cache is warm (all embeddings are pre-computed), - the EmbeddingCache allows for faster training as it removes the need of computing the embeddings. - Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint - and synchronization points in the forward calls. - - Args: - cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk. - device (str or torch.device): Device on which the embedding is returned. - compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute - the embedding from a given object and path. This user provided function can compute the - embedding from the provided object or using the provided path as entry point. The last parameter - specify the index corresponding to the current embedding in the object that can represent batch metadata. - extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract - the desired embedding chunk from the full embedding loaded from the cache. The last parameter - specify the index corresponding to the current embedding in the object that can represent batch metadata. - If not specified, will return the full embedding unmodified. - """ - def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device], - compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor], - extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None): - self.cache_path = Path(cache_path) - self.device = device - self._compute_embed_fn = compute_embed_fn - self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor] - if extract_embed_fn is not None: - self._extract_embed_fn = extract_embed_fn - else: - self._extract_embed_fn = partial(get_full_embed, device=device) - if self.cache_path is not None: - self.cache_path.mkdir(exist_ok=True, parents=True) - logger.info(f"Cache instantiated at: {self.cache_path}") - self.pool = ThreadPoolExecutor(8) - self.pool.__enter__() - self._current_batch_cache: dict = {} - self._memory_cache: dict = {} - - def _get_cache_path(self, path: tp.Union[Path, str]): - """Get cache path for the given file path.""" - sig = sha1(str(path).encode()).hexdigest() - return self.cache_path / sig - - @staticmethod - def _get_full_embed_from_cache(cache: Path): - """Loads full pre-computed embedding from the cache.""" - try: - embed = torch.load(cache, 'cpu') - except Exception as exc: - logger.error("Error loading %s: %r", cache, exc) - embed = None - return embed - - def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor: - """Get embedding from cache, computing and storing it to cache if not already cached. - The EmbeddingCache first tries to load the embedding from the in-memory cache - containing the pre-computed chunks populated through `populate_embed_cache`. - If not found, the full embedding is computed and stored on disk to be later accessed - to populate the in-memory cache, and the desired embedding chunk is extracted and returned. 
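- When an embedding has to be computed here, the full result is also written to disk under
- `cache_path`, so later calls to `populate_embed_cache` (or future runs) can reuse it.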
- - Args: - paths (list[Path or str]): List of paths from where the embeddings can be loaded. - x (any): Object from which the embedding is extracted. - """ - embeds = [] - for idx, path in enumerate(paths): - cache = self._get_cache_path(path) - if cache in self._current_batch_cache: - embed = self._current_batch_cache[cache] - else: - full_embed = self._compute_embed_fn(path, x, idx) - try: - with flashy.utils.write_and_rename(cache, pid=True) as f: - torch.save(full_embed.cpu(), f) - except Exception as exc: - logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc) - else: - logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape) - embed = self._extract_embed_fn(full_embed, x, idx) - embeds.append(embed) - embed = torch.stack(embeds, dim=0) - return embed - - def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None: - """Populate in-memory caches for embeddings reading from the embeddings stored on disk. - The in-memory caches consist in a cache for the full embedding and another cache for the - final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings - and reduce the IO footprint and synchronization points during forward passes. - - Args: - paths (list[Path]): List of paths from where the embeddings can be loaded. - x (any): Object from which the embedding is extracted. - """ - self._current_batch_cache.clear() - if self.cache_path is not None: - futures: list = [] - for path in paths: - assert path is not None, "Path is required for computation from cache" - cache = self._get_cache_path(path) - if cache in self._memory_cache or not cache.exists(): - futures.append(None) - else: - futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache)) - for idx, (path, future) in enumerate(zip(paths, futures)): - assert path is not None - cache = self._get_cache_path(path) - full_embed = None - if future is None: - if cache in self._memory_cache: - full_embed = self._memory_cache[cache] - else: - full_embed = future.result() - if full_embed is not None: - self._memory_cache[cache] = full_embed - full_embed = full_embed.to(self.device) - if full_embed is not None: - embed = self._extract_embed_fn(full_embed, x, idx) - self._current_batch_cache[cache] = embed - - -class CachedBatchWriter: - """Write pre computed caches for mini batches. This can - make loading a lot more efficient depending on your filesystem. - - Args: - cache_folder (Path): folder in which the cached minibatches - will be stored. - - Inside cache folder, the structure is the following: - `epoch_number / update_number.zip` - And the zip file contains one entry per batch item. - - It is possible to use the cache with a batch size smaller than - created with but obviously not larger. Make sure to call the - `start_epoch(epoch)` method for indicating changes of epochs. - - See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py` - for an example of how to warmup the cache. - """ - def __init__(self, cache_folder: Path): - self.cache_folder = cache_folder - self._current_epoch: tp.Optional[int] = None - self._current_index = 0 - - def start_epoch(self, epoch: int): - """Call at the beginning of each epoch. 
- """ - self._current_epoch = epoch - self._current_index = 0 - self._zip_path.parent.mkdir(exist_ok=True, parents=True) - - @staticmethod - def _get_zip_path(cache_folder: Path, epoch: int, index: int): - return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip" - - @property - def _zip_path(self): - assert self._current_epoch is not None - return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index) - - def save(self, *content): - """Save one mini batch. This function is distributed-aware - and will automatically merge all the items from the different - workers. - """ - all_contents = [] - for rank in range(flashy.distrib.world_size()): - their_content = flashy.distrib.broadcast_object(content, src=rank) - all_contents.append(their_content) - - if flashy.distrib.is_rank_zero(): - idx = 0 - with flashy.utils.write_and_rename(self._zip_path) as tmp: - with zipfile.ZipFile(tmp, 'w') as zf: - for content in all_contents: - for vals in zip(*content): - with zf.open(f'{idx}', 'w') as f: # type: ignore - torch.save(vals, f) - idx += 1 - flashy.distrib.barrier() - self._current_index += 1 - - -class CachedBatchLoader: - """Loader for cached mini-batches dumped with `CachedBatchWriter`. - - Args: - cache_folder (Path): folder in which the cached minibatches are stored. - batch_size (int): batch size (per GPU) expected. - num_workers (int): number of workers to use for loading. - min_length (int): minimum expected length for each epoch. If some - mini-batches are missing, and error is raised. - - This is iterable just like a regular DataLoader. - """ - - def __init__(self, cache_folder: Path, batch_size: int, - num_workers: int = 10, min_length: int = 1): - self.cache_folder = cache_folder - self.batch_size = batch_size - self.num_workers = num_workers - self.min_length = min_length - self._current_epoch: tp.Optional[int] = None - self.sampler = None # for compatibility with the regular DataLoader - - def __len__(self): - path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent - return len([p for p in path.iterdir() if p.suffix == ".zip"]) - - def start_epoch(self, epoch: int): - """Call at the beginning of each epoch. 
- """ - self._current_epoch = epoch - - def _zip_path(self, index: int): - assert self._current_epoch is not None - return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index) - - def _load_one(self, index: int): - zip_path = self._zip_path(index) - if not zip_path.exists(): - if index < self.min_length: - raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist") - - return None - mode = "rb" if sys.version_info >= (3, 9) else "r" - try: - with zipfile.ZipFile(zip_path, 'r') as zf: - rank = flashy.distrib.rank() - world_size = flashy.distrib.world_size() - root = zipfile.Path(zf) - items = list(root.iterdir()) - total_batch_size = self.batch_size * world_size - if len(items) < total_batch_size: - raise RuntimeError( - f"The cache can handle a max batch size of {len(items)}, " - f"but {total_batch_size} is needed.") - start = rank * self.batch_size - items = items[start: start + self.batch_size] - assert len(items) == self.batch_size - entries = [] - entries = [torch.load(item.open(mode), 'cpu') for item in items] # type: ignore - transposed = zip(*entries) - out = [] - for part in transposed: - assert len(part) > 0 - if isinstance(part[0], torch.Tensor): - out.append(torch.stack(part)) - else: - out.append(part) - return out - except Exception: - logger.error("Error when reading zip path %s", zip_path) - raise - - def __iter__(self): - """This will yields tuples, exactly as provided to the - `CachedBatchWriter.save` method. - """ - pool = ThreadPoolExecutor(self.num_workers) - next_index = 0 - queue = deque() - - def _get_next(): - nonlocal next_index - r = queue.popleft().result() - if r is None: - return None - else: - queue.append(pool.submit(self._load_one, next_index)) - next_index += 1 - return r - - with pool: - # fill the buffer of fetching jobs. - for _ in range(2 * self.num_workers): - queue.append(pool.submit(self._load_one, next_index)) - next_index += 1 - while True: - batch = _get_next() - if batch is None: - return - yield batch diff --git a/audiocraft/audiocraft/utils/checkpoint.py b/audiocraft/audiocraft/utils/checkpoint.py deleted file mode 100644 index f6f871837e09c5cc7832b85b0d80b84f59e87ca0..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/checkpoint.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from enum import Enum -import logging -from pathlib import Path -import re -import typing as tp - -import flashy -import torch - -from ..environment import AudioCraftEnvironment - - -logger = logging.getLogger(__name__) - - -class CheckpointSource(Enum): - CURRENT_XP = "current_xp" - PRETRAINED = "pretrained" - OTHER = "other" - - -def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str: - """Checkpoint name formatted for all use in AudioCraft codebase and has the following format: - `checkpoint_.th(.)`. By convention, name is expected to be empty for last checkpoint, - 'best' for the best checkpoint or the epoch number. - - Args: - name (str, optional): Name suffix for the checkpoint file stem. - rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided. - use_fsdp (bool): Whether the calling solver relies on FSDP. - Returns: - str: The checkpoint name. 
- """ - suffix = '' - if rank is None: - rank = flashy.distrib.rank() - if rank > 0 and use_fsdp: - suffix = '.' + str(rank) - name_part = '' - if name is not None: - name_part = f'_{name}' - return f'checkpoint{name_part}.th{suffix}' - - -def is_sharded_checkpoint(path: Path) -> bool: - """Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank.""" - return re.search(r'\.th\.\d+$', path.name) is not None - - -def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None, - use_fsdp: bool = False) -> tp.Optional[Path]: - """Resolve a given checkpoint path for a provided dora sig or path. - - Args: - sig_or_path (Path or str): Checkpoint path or dora signature. - name (str, optional): Name suffix for the checkpoint file stem. - rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided. - use_fsdp (bool): Whether the calling solver relies on FSDP. - Returns: - Path, optional: Resolved checkpoint path, if it exists. - """ - from audiocraft import train - xps_root = train.main.dora.dir / 'xps' - sig_or_path = str(sig_or_path) - if sig_or_path.startswith('//sig/'): - sig = sig_or_path[len('//sig/'):] - path = xps_root / sig - else: - path = Path(sig_or_path) - path = AudioCraftEnvironment.resolve_reference_path(path) - - if path.is_dir(): - path = path / checkpoint_name(name, use_fsdp=use_fsdp) - - if path.exists(): - return path - else: - return None - - -def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any: - """Load state from checkpoints at the specified checkpoint path.""" - if is_sharded: - rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False) - if rank0_checkpoint_path.exists(): - check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path) - state = torch.load(checkpoint_path, 'cpu') - logger.info("Checkpoint loaded from %s", checkpoint_path) - return state - - -def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None: - """Save state to disk to the specified checkpoint_path.""" - _safe_save_checkpoint(state, checkpoint_path, is_sharded) - logger.info("Checkpoint saved to %s", checkpoint_path) - - -def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None: - """Flush checkpoints to only keep last N checkpoints.""" - if keep_last is None or keep_last <= 0: - return - checkpoint_dir = checkpoint_path.parent - suffix = '' - if flashy.distrib.rank() > 0: - suffix = f'.{flashy.distrib.rank()}' - checkpoint_files_with_epoch = [] - for path in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'): - epoch_part = path.name.split('.', 1)[0].split('_', 1)[1] - if epoch_part.isdigit(): - checkpoint_files_with_epoch.append((path, int(epoch_part))) - checkpoint_files = [path for path, _ in list(sorted(checkpoint_files_with_epoch, key=lambda t: t[1]))] - total_to_flush = max(0, len(checkpoint_files) - keep_last) - files_to_flush = checkpoint_files[:total_to_flush] - for path in files_to_flush: - logger.debug("Removing checkpoint: %s", str(path)) - path.unlink(missing_ok=True) - - -def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None: - """Check sharded checkpoint state, ensuring the checkpoints are not corrupted.""" - # Finish the work of a previous run that got interrupted while dumping. 
- old_path = Path(str(checkpoint_path) + '.old') - if old_path.exists(): - raise RuntimeError( - f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.") - token = Path(str(rank0_checkpoint_path) + '.tmp.done') - tmp_path = Path(str(checkpoint_path) + '.tmp') - if token.exists(): - if tmp_path.exists(): - tmp_path.rename(checkpoint_path) - flashy.distrib.barrier() - if flashy.distrib.is_rank_zero() and token.exists(): - token.unlink() - - -def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None: - """Save checkpoints in a safe manner even with when sharded checkpoints across nodes.""" - def _barrier_if_sharded(): - if is_sharded: - flashy.distrib.barrier() - - if flashy.distrib.is_rank_zero(): - token = Path(str(checkpoint_path) + '.tmp.done') - if token.exists(): - token.unlink() - _barrier_if_sharded() - with flashy.utils.write_and_rename(checkpoint_path) as f: - torch.save(state, f) - _barrier_if_sharded() - if flashy.distrib.is_rank_zero(): - token.touch() - _barrier_if_sharded() - _barrier_if_sharded() - if flashy.distrib.rank() == 0: - token.unlink() diff --git a/audiocraft/audiocraft/utils/cluster.py b/audiocraft/audiocraft/utils/cluster.py deleted file mode 100644 index 3380d031739d473fb859c76b9c25350f47fa77e8..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/cluster.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Utility functions for SLURM configuration and cluster settings. -""" - -from enum import Enum -import os -import socket -import typing as tp - -import omegaconf - - -class ClusterType(Enum): - AWS = "aws" - FAIR = "fair" - RSC = "rsc" - LOCAL_DARWIN = "darwin" - DEFAULT = "default" # used for any other cluster. - - -def _guess_cluster_type() -> ClusterType: - uname = os.uname() - fqdn = socket.getfqdn() - if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn): - return ClusterType.AWS - - if fqdn.endswith(".fair"): - return ClusterType.FAIR - - if fqdn.endswith(".facebook.com"): - return ClusterType.RSC - - if uname.sysname == "Darwin": - return ClusterType.LOCAL_DARWIN - - return ClusterType.DEFAULT - - -def get_cluster_type( - cluster_type: tp.Optional[ClusterType] = None, -) -> tp.Optional[ClusterType]: - if cluster_type is None: - return _guess_cluster_type() - - return cluster_type - - -def get_slurm_parameters( - cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None -) -> omegaconf.DictConfig: - """Update SLURM parameters in configuration based on cluster type. - If the cluster type is not specify, it infers it automatically. 
- """ - from ..environment import AudioCraftEnvironment - cluster_type = get_cluster_type(cluster_type) - # apply cluster-specific adjustments - if cluster_type == ClusterType.AWS: - cfg["mem_per_gpu"] = None - cfg["constraint"] = None - cfg["setup"] = [] - elif cluster_type == ClusterType.RSC: - cfg["mem_per_gpu"] = None - cfg["setup"] = [] - cfg["constraint"] = None - cfg["partition"] = "learn" - slurm_exclude = AudioCraftEnvironment.get_slurm_exclude() - if slurm_exclude is not None: - cfg["exclude"] = slurm_exclude - return cfg diff --git a/audiocraft/audiocraft/utils/deadlock.py b/audiocraft/audiocraft/utils/deadlock.py deleted file mode 100644 index 8abd1bbeea5909e664cf816c020bd7c37effdb66..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/deadlock.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import os -from queue import Queue, Empty -import signal -import sys -import threading -import traceback - -logger = logging.getLogger(__name__) - - -class DeadlockDetect: - def __init__(self, use: bool = False, timeout: float = 120.): - self.use = use - self.timeout = timeout - self._queue: Queue = Queue() - - def update(self, stage: str): - if self.use: - self._queue.put(stage) - - def __enter__(self): - if self.use: - self._thread = threading.Thread(target=self._detector_thread) - self._thread.start() - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.use: - self._queue.put(None) - self._thread.join() - - def _detector_thread(self): - logger.debug("Deadlock detector started") - last_stage = "init" - while True: - try: - stage = self._queue.get(timeout=self.timeout) - except Empty: - break - if stage is None: - logger.debug("Exiting deadlock detector thread") - return - else: - last_stage = stage - logger.error("Deadlock detector timed out, last stage was %s", last_stage) - for th in threading.enumerate(): - print(th, file=sys.stderr) - traceback.print_stack(sys._current_frames()[th.ident]) - print(file=sys.stderr) - sys.stdout.flush() - sys.stderr.flush() - os.kill(os.getpid(), signal.SIGKILL) diff --git a/audiocraft/audiocraft/utils/export.py b/audiocraft/audiocraft/utils/export.py deleted file mode 100644 index 28b214017d9ac23934b67e8254a96131cefa6501..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/export.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Utility to export a training checkpoint to a lightweight release checkpoint. -""" - -from pathlib import Path -import typing as tp - -from omegaconf import OmegaConf -import torch - -from audiocraft import __version__ - - -def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): - """Export only the best state from the given EnCodec checkpoint. This - should be used if you trained your own EnCodec model. 
- """ - pkg = torch.load(checkpoint_path, 'cpu') - new_pkg = { - 'best_state': pkg['best_state']['model'], - 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), - 'version': __version__, - 'exported': True, - } - Path(out_file).parent.mkdir(exist_ok=True, parents=True) - torch.save(new_pkg, out_file) - return out_file - - -def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]): - """Export a compression model (potentially EnCodec) from a pretrained model. - This is required for packaging the audio tokenizer along a MusicGen or AudioGen model. - Do not include the //pretrained/ prefix. For instance if you trained a model - with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`. - - In that case, this will not actually include a copy of the model, simply the reference - to the model used. - """ - if Path(pretrained_encodec).exists(): - pkg = torch.load(pretrained_encodec) - assert 'best_state' in pkg - assert 'xp.cfg' in pkg - assert 'version' in pkg - assert 'exported' in pkg - else: - pkg = { - 'pretrained': pretrained_encodec, - 'exported': True, - 'version': __version__, - } - Path(out_file).parent.mkdir(exist_ok=True, parents=True) - torch.save(pkg, out_file) - - -def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): - """Export only the best state from the given MusicGen or AudioGen checkpoint. - """ - pkg = torch.load(checkpoint_path, 'cpu') - if pkg['fsdp_best_state']: - best_state = pkg['fsdp_best_state']['model'] - else: - assert pkg['best_state'] - best_state = pkg['best_state']['model'] - new_pkg = { - 'best_state': best_state, - 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), - 'version': __version__, - 'exported': True, - } - - Path(out_file).parent.mkdir(exist_ok=True, parents=True) - torch.save(new_pkg, out_file) - return out_file diff --git a/audiocraft/audiocraft/utils/export_legacy.py b/audiocraft/audiocraft/utils/export_legacy.py deleted file mode 100644 index 52f145f3148c3e9fdba436273bc45480fbae6481..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/export_legacy.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -Legacy functions used at the time of the first release, kept for referencd. -""" - -from pathlib import Path -import typing as tp - -from omegaconf import OmegaConf, DictConfig -import torch - - -def _clean_lm_cfg(cfg: DictConfig): - OmegaConf.set_struct(cfg, False) - # This used to be set automatically in the LM solver, need a more robust solution - # for the future. - cfg['transformer_lm']['card'] = 2048 - cfg['transformer_lm']['n_q'] = 4 - # Experimental params no longer supported. 
- bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', - 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] - for name in bad_params: - del cfg['transformer_lm'][name] - OmegaConf.set_struct(cfg, True) - return cfg - - -def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): - sig = Path(checkpoint_path).parent.name - assert len(sig) == 8, "Not a valid Dora signature" - pkg = torch.load(checkpoint_path, 'cpu') - new_pkg = { - 'best_state': pkg['ema']['state']['model'], - 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), - } - out_file = Path(out_folder) / f'{sig}.th' - torch.save(new_pkg, out_file) - return out_file - - -def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): - sig = Path(checkpoint_path).parent.name - assert len(sig) == 8, "Not a valid Dora signature" - pkg = torch.load(checkpoint_path, 'cpu') - new_pkg = { - 'best_state': pkg['fsdp_best_state']['model'], - 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])) - } - out_file = Path(out_folder) / f'{sig}.th' - torch.save(new_pkg, out_file) - return out_file diff --git a/audiocraft/audiocraft/utils/notebook.py b/audiocraft/audiocraft/utils/notebook.py deleted file mode 100644 index 019b9d19e5bef976bedddf428fd25da42a8a9726..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/notebook.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -try: - import IPython.display as ipd # type: ignore -except ImportError: - # Note in a notebook... - pass - - -import torch - - -def display_audio(samples: torch.Tensor, sample_rate: int): - """Renders an audio player for the given audio samples. - - Args: - samples (torch.Tensor): a Tensor of decoded audio samples - with shapes [B, C, T] or [C, T] - sample_rate (int): sample rate audio should be displayed with. - """ - assert samples.dim() == 2 or samples.dim() == 3 - - samples = samples.detach().cpu() - if samples.dim() == 2: - samples = samples[None, ...] - - for audio in samples: - ipd.display(ipd.Audio(audio, rate=sample_rate)) diff --git a/audiocraft/audiocraft/utils/profiler.py b/audiocraft/audiocraft/utils/profiler.py deleted file mode 100644 index b45b6d15910b50305c7b212c089ffad3c25b324d..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/profiler.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import typing as tp - -import dora -import torch - - -logger = logging.getLogger(__name__) - - -class Profiler: - """Context manager wrapper for xformers profiler. 
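-
-    Args:
-        module (torch.nn.Module): Module to profile.
-        enabled (bool): If True, activates the xformers profiler and saves the results
-            under `profiler_data` in the current XP folder.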
- """ - def __init__(self, module: torch.nn.Module, enabled: bool = False): - self.profiler: tp.Optional[tp.Any] = None - if enabled: - from xformers.profiler import profile - output_dir = dora.get_xp().folder / 'profiler_data' - logger.info("Profiling activated, results with be saved to %s", output_dir) - self.profiler = profile(output_dir=output_dir, module=module) - - def step(self): - if self.profiler is not None: - self.profiler.step() # type: ignore - - def __enter__(self): - if self.profiler is not None: - return self.profiler.__enter__() # type: ignore - - def __exit__(self, exc_type, exc_value, exc_tb): - if self.profiler is not None: - return self.profiler.__exit__(exc_type, exc_value, exc_tb) # type: ignore diff --git a/audiocraft/audiocraft/utils/samples/__init__.py b/audiocraft/audiocraft/utils/samples/__init__.py deleted file mode 100644 index 0952fcc3f57e34b3747962e9ebd6fc57aeea63fa..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/samples/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. diff --git a/audiocraft/audiocraft/utils/samples/__pycache__/__init__.cpython-311.pyc b/audiocraft/audiocraft/utils/samples/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 19d854c4561dbe27b2b24cdd4768e5bcb3f59685..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/utils/samples/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/samples/__pycache__/manager.cpython-311.pyc b/audiocraft/audiocraft/utils/samples/__pycache__/manager.cpython-311.pyc deleted file mode 100644 index ddc98a809dfb0203679bb4440fdec7e8b12541af..0000000000000000000000000000000000000000 Binary files a/audiocraft/audiocraft/utils/samples/__pycache__/manager.cpython-311.pyc and /dev/null differ diff --git a/audiocraft/audiocraft/utils/samples/manager.py b/audiocraft/audiocraft/utils/samples/manager.py deleted file mode 100644 index bf0fb21b2d2867c03f7cce6f27d9524fdb89b51d..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/samples/manager.py +++ /dev/null @@ -1,386 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -""" -API that can manage the storage and retrieval of generated samples produced by experiments. 
- -It offers the following benefits: -* Samples are stored in a consistent way across epoch -* Metadata about the samples can be stored and retrieved -* Can retrieve audio -* Identifiers are reliable and deterministic for prompted and conditioned samples -* Can request the samples for multiple XPs, grouped by sample identifier -* For no-input samples (not prompt and no conditions), samples across XPs are matched - by sorting their identifiers -""" - -from concurrent.futures import ThreadPoolExecutor -from dataclasses import asdict, dataclass -from functools import lru_cache -import hashlib -import json -import logging -from pathlib import Path -import re -import typing as tp -import unicodedata -import uuid - -import dora -import torch - -from ...data.audio import audio_read, audio_write - - -logger = logging.getLogger(__name__) - - -@dataclass -class ReferenceSample: - id: str - path: str - duration: float - - -@dataclass -class Sample: - id: str - path: str - epoch: int - duration: float - conditioning: tp.Optional[tp.Dict[str, tp.Any]] - prompt: tp.Optional[ReferenceSample] - reference: tp.Optional[ReferenceSample] - generation_args: tp.Optional[tp.Dict[str, tp.Any]] - - def __hash__(self): - return hash(self.id) - - def audio(self) -> tp.Tuple[torch.Tensor, int]: - return audio_read(self.path) - - def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]: - return audio_read(self.prompt.path) if self.prompt is not None else None - - def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]: - return audio_read(self.reference.path) if self.reference is not None else None - - -class SampleManager: - """Audio samples IO handling within a given dora xp. - - The sample manager handles the dumping and loading logic for generated and - references samples across epochs for a given xp, providing a simple API to - store, retrieve and compare audio samples. - - Args: - xp (dora.XP): Dora experiment object. The XP contains information on the XP folder - where all outputs are stored and the configuration of the experiment, - which is useful to retrieve audio-related parameters. - map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples - instead of generating a dedicated hash id. This is useful to allow easier comparison - with ground truth sample from the files directly without having to read the JSON metadata - to do the mapping (at the cost of potentially dumping duplicate prompts/references - depending on the task). 
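-
-    A minimal usage sketch (assuming a dora XP `xp` whose config defines `generate.path`,
-    `generate.audio` and `sample_rate`, and a list of per-sample condition dicts `attributes`):
-        manager = SampleManager(xp)
-        manager.add_samples(gen_wavs, epoch=50, conditioning=attributes)
-        samples = manager.get_samples(epoch=50, exclude_unconditioned=True)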
- """ - def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False): - self.xp = xp - self.base_folder: Path = xp.folder / xp.cfg.generate.path - self.reference_folder = self.base_folder / 'reference' - self.map_reference_to_sample_id = map_reference_to_sample_id - self.samples: tp.List[Sample] = [] - self._load_samples() - - @property - def latest_epoch(self): - """Latest epoch across all samples.""" - return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0 - - def _load_samples(self): - """Scan the sample folder and load existing samples.""" - jsons = self.base_folder.glob('**/*.json') - with ThreadPoolExecutor(6) as pool: - self.samples = list(pool.map(self._load_sample, jsons)) - - @staticmethod - @lru_cache(2**26) - def _load_sample(json_file: Path) -> Sample: - with open(json_file, 'r') as f: - data: tp.Dict[str, tp.Any] = json.load(f) - # fetch prompt data - prompt_data = data.get('prompt') - prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'], - duration=prompt_data['duration']) if prompt_data else None - # fetch reference data - reference_data = data.get('reference') - reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'], - duration=reference_data['duration']) if reference_data else None - # build sample object - return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'], - prompt=prompt, conditioning=data.get('conditioning'), reference=reference, - generation_args=data.get('generation_args')) - - def _init_hash(self): - return hashlib.sha1() - - def _get_tensor_id(self, tensor: torch.Tensor) -> str: - hash_id = self._init_hash() - hash_id.update(tensor.numpy().data) - return hash_id.hexdigest() - - def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor], - conditions: tp.Optional[tp.Dict[str, str]]) -> str: - """Computes an id for a sample given its input data. - This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input. - Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned. - - Args: - index (int): Batch index, Helpful to differentiate samples from the same batch. - prompt_wav (torch.Tensor): Prompt used during generation. - conditions (dict[str, str]): Conditioning used during generation. - """ - # For totally unconditioned generations we will just use a random UUID. - # The function get_samples_for_xps will do a simple ordered match with a custom key. - if prompt_wav is None and not conditions: - return f"noinput_{uuid.uuid4().hex}" - - # Human readable portion - hr_label = "" - # Create a deterministic id using hashing - hash_id = self._init_hash() - hash_id.update(f"{index}".encode()) - if prompt_wav is not None: - hash_id.update(prompt_wav.numpy().data) - hr_label += "_prompted" - else: - hr_label += "_unprompted" - if conditions: - encoded_json = json.dumps(conditions, sort_keys=True).encode() - hash_id.update(encoded_json) - cond_str = "-".join([f"{key}={slugify(value)}" - for key, value in sorted(conditions.items())]) - cond_str = cond_str[:100] # some raw text might be too long to be a valid filename - cond_str = cond_str if len(cond_str) > 0 else "unconditioned" - hr_label += f"_{cond_str}" - else: - hr_label += "_unconditioned" - - return hash_id.hexdigest() + hr_label - - def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path: - """Stores the audio with the given stem path using the XP's configuration. 
- - Args: - wav (torch.Tensor): Audio to store. - stem_path (Path): Path in sample output directory with file stem to use. - overwrite (bool): When False (default), skips storing an existing audio file. - Returns: - Path: The path at which the audio is stored. - """ - existing_paths = [ - path for path in stem_path.parent.glob(stem_path.stem + '.*') - if path.suffix != '.json' - ] - exists = len(existing_paths) > 0 - if exists and overwrite: - logger.warning(f"Overwriting existing audio file with stem path {stem_path}") - elif exists: - return existing_paths[0] - - audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio) - return audio_path - - def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0, - conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None, - ground_truth_wav: tp.Optional[torch.Tensor] = None, - generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample: - """Adds a single sample. - The sample is stored in the XP's sample output directory, under a corresponding epoch folder. - Each sample is assigned an id which is computed using the input data. In addition to the - sample itself, a json file containing associated metadata is stored next to it. - - Args: - sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape]. - epoch (int): current training epoch. - index (int): helpful to differentiate samples from the same batch. - conditions (dict[str, str], optional): conditioning used during generation. - prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape]. - ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from. - Tensor of shape [channels, shape]. - generation_args (dict[str, any], optional): dictionary of other arguments used during generation. - Returns: - Sample: The saved sample. 
- """ - sample_id = self._get_sample_id(index, prompt_wav, conditions) - reuse_id = self.map_reference_to_sample_id - prompt, ground_truth = None, None - if prompt_wav is not None: - prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True)) - prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate - prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id) - prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration) - if ground_truth_wav is not None: - ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True)) - ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate - ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id) - ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration) - sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True) - duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate - sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args) - self.samples.append(sample) - with open(sample_path.with_suffix('.json'), 'w') as f: - json.dump(asdict(sample), f, indent=2) - return sample - - def add_samples(self, samples_wavs: torch.Tensor, epoch: int, - conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None, - prompt_wavs: tp.Optional[torch.Tensor] = None, - ground_truth_wavs: tp.Optional[torch.Tensor] = None, - generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]: - """Adds a batch of samples. - The samples are stored in the XP's sample output directory, under a corresponding - epoch folder. Each sample is assigned an id which is computed using the input data and their batch index. - In addition to the sample itself, a json file containing associated metadata is stored next to it. - - Args: - sample_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape]. - epoch (int): Current training epoch. - conditioning (list of dict[str, str], optional): List of conditions used during generation, - one per sample in the batch. - prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape - [batch_size, channels, shape]. - ground_truth_wav (torch.Tensor, optional): Reference audio where prompts were extracted from. - Tensor of shape [batch_size, channels, shape]. - generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation. - Returns: - samples (list of Sample): The saved audio samples with prompts, ground truth and metadata. - """ - samples = [] - for idx, wav in enumerate(samples_wavs): - prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None - gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None - conditions = conditioning[idx] if conditioning is not None else None - samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args)) - return samples - - def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False, - exclude_unprompted: bool = False, exclude_conditioned: bool = False, - exclude_unconditioned: bool = False) -> tp.Set[Sample]: - """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain. 
- Please note that existing samples are loaded during the manager's initialization, and added samples through this - manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager - is the only way detect them. - - Args: - epoch (int): If provided, only return samples corresponding to this epoch. - max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch. - exclude_prompted (bool): If True, does not include samples that used a prompt. - exclude_unprompted (bool): If True, does not include samples that did not use a prompt. - exclude_conditioned (bool): If True, excludes samples that used conditioning. - exclude_unconditioned (bool): If True, excludes samples that did not use conditioning. - Returns: - Samples (set of Sample): The retrieved samples matching the provided filters. - """ - if max_epoch >= 0: - samples_epoch = max(sample.epoch for sample in self.samples if sample.epoch <= max_epoch) - else: - samples_epoch = self.latest_epoch if epoch < 0 else epoch - samples = { - sample - for sample in self.samples - if ( - (sample.epoch == samples_epoch) and - (not exclude_prompted or sample.prompt is None) and - (not exclude_unprompted or sample.prompt is not None) and - (not exclude_conditioned or not sample.conditioning) and - (not exclude_unconditioned or sample.conditioning) - ) - } - return samples - - -def slugify(value: tp.Any, allow_unicode: bool = False): - """Process string for safer file naming. - - Taken from https://github.com/django/django/blob/master/django/utils/text.py - - Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated - dashes to single dashes. Remove characters that aren't alphanumerics, - underscores, or hyphens. Convert to lowercase. Also strip leading and - trailing whitespace, dashes, and underscores. - """ - value = str(value) - if allow_unicode: - value = unicodedata.normalize("NFKC", value) - else: - value = ( - unicodedata.normalize("NFKD", value) - .encode("ascii", "ignore") - .decode("ascii") - ) - value = re.sub(r"[^\w\s-]", "", value.lower()) - return re.sub(r"[-\s]+", "-", value).strip("-_") - - -def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]: - # Create a dictionary of stable id -> sample per XP - stable_samples_per_xp = [{ - sample.id: sample for sample in samples - if sample.prompt is not None or sample.conditioning - } for samples in samples_per_xp] - # Set of all stable ids - stable_ids = {id for samples in stable_samples_per_xp for id in samples.keys()} - # Dictionary of stable id -> list of samples. If an XP does not have it, assign None - stable_samples = {id: [xp.get(id) for xp in stable_samples_per_xp] for id in stable_ids} - # Filter out ids that contain None values (we only want matched samples after all) - # cast is necessary to avoid mypy linter errors. 
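-    # With e.g. two XPs, the result looks like {'<sha1>_prompted_<conditions>': [sample_from_xp0, sample_from_xp1]}:
-    # only ids produced by every XP are kept.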
- return {id: tp.cast(tp.List[Sample], samples) for id, samples in stable_samples.items() if None not in samples} - - -def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]: - # For unstable ids, we use a sorted list since we'll match them in order - unstable_samples_per_xp = [[ - sample for sample in sorted(samples, key=lambda x: x.id) - if sample.prompt is None and not sample.conditioning - ] for samples in samples_per_xp] - # Trim samples per xp so all samples can have a match - min_len = min([len(samples) for samples in unstable_samples_per_xp]) - unstable_samples_per_xp = [samples[:min_len] for samples in unstable_samples_per_xp] - # Dictionary of index -> list of matched samples - return { - f'noinput_{i}': [samples[i] for samples in unstable_samples_per_xp] for i in range(min_len) - } - - -def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]: - """Gets a dictionary of matched samples across the given XPs. - Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id - will always match the number of XPs provided and will correspond to each XP in the same order given. - In other words, only samples that can be match across all provided XPs will be returned - in order to satisfy this rule. - - There are two types of ids that can be returned: stable and unstable. - * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs - (prompts/conditioning). This is why we can match them across XPs. - * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples - that used non-deterministic, random ids. This is the case for samples that did not use prompts or - conditioning for their generation. This function will sort these samples by their id and match them - by their index. - - Args: - xps: a list of XPs to match samples from. - start_epoch (int): If provided, only return samples corresponding to this epoch or newer. - end_epoch (int): If provided, only return samples corresponding to this epoch or older. - exclude_prompted (bool): If True, does not include samples that used a prompt. - exclude_unprompted (bool): If True, does not include samples that did not use a prompt. - exclude_conditioned (bool): If True, excludes samples that used conditioning. - exclude_unconditioned (bool): If True, excludes samples that did not use conditioning. - """ - managers = [SampleManager(xp) for xp in xps] - samples_per_xp = [manager.get_samples(**kwargs) for manager in managers] - stable_samples = _match_stable_samples(samples_per_xp) - unstable_samples = _match_unstable_samples(samples_per_xp) - return dict(stable_samples, **unstable_samples) diff --git a/audiocraft/audiocraft/utils/utils.py b/audiocraft/audiocraft/utils/utils.py deleted file mode 100644 index 3135d70e949a058095ef84dd87b49384546c465c..0000000000000000000000000000000000000000 --- a/audiocraft/audiocraft/utils/utils.py +++ /dev/null @@ -1,298 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -from concurrent.futures import ProcessPoolExecutor -from contextlib import contextmanager -from functools import wraps, lru_cache -import hashlib -import json -import logging -from pathlib import Path -import typing as tp - -import flashy -import flashy.distrib -import omegaconf -import torch -from torch.nn.utils.rnn import pad_sequence - - -logger = logging.getLogger(__name__) - - -def model_hash(model: torch.nn.Module) -> str: - """Return a model hash. This should allow us to track regressions in model init - from the logs of past experiments. - """ - hasher = hashlib.sha1() - for p in model.parameters(): - hasher.update(p.data.cpu().numpy().tobytes()) - return hasher.hexdigest() - - -def dict_from_config(cfg: omegaconf.DictConfig) -> dict: - """Convenience function to map an omegaconf configuration to a dictionary. - - Args: - cfg (omegaconf.DictConfig): Original configuration to map to dict. - Returns: - dict: Config as dictionary object. - """ - dct = omegaconf.OmegaConf.to_container(cfg, resolve=True) - assert isinstance(dct, dict) - return dct - - -def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset: - if max_samples >= len(dataset): - return dataset - - generator = torch.Generator().manual_seed(seed) - perm = torch.randperm(len(dataset), generator=generator) - return torch.utils.data.Subset(dataset, perm[:max_samples].tolist()) - - -def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int, - num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader: - """Convenience function to load dataset into a dataloader with optional subset sampling. - - Args: - dataset: Dataset to load. - num_samples (Optional[int]): Number of samples to limit subset size. - batch_size (int): Batch size. - num_workers (int): Number of workers for data loading. - seed (int): Random seed. - """ - if num_samples is not None: - dataset = random_subset(dataset, num_samples, seed) - - dataloader = flashy.distrib.loader( - dataset, - batch_size=batch_size, - num_workers=num_workers, - **kwargs - ) - return dataloader - - -def get_dataset_from_loader(dataloader): - dataset = dataloader.dataset - if isinstance(dataset, torch.utils.data.Subset): - return dataset.dataset - else: - return dataset - - -def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): - """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. - - Args: - input (torch.Tensor): The input tensor containing probabilities. - num_samples (int): Number of samples to draw. - replacement (bool): Whether to draw with replacement or not. - Keywords args: - generator (torch.Generator): A pseudorandom number generator for sampling. - Returns: - torch.Tensor: Last dimension contains num_samples indices - sampled from the multinomial probability distribution - located in the last dimension of tensor input. - """ - input_ = input.reshape(-1, input.shape[-1]) - output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) - output = output_.reshape(*list(input.shape[:-1]), -1) - return output - - -def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: - """Sample next token from top K values along the last dimension of the input probs tensor. - - Args: - probs (torch.Tensor): Input probabilities with token candidates on the last dimension. - k (int): The k in “top-k”. - Returns: - torch.Tensor: Sampled tokens. 
- """ - top_k_value, _ = torch.topk(probs, k, dim=-1) - min_value_top_k = top_k_value[..., [-1]] - probs *= (probs >= min_value_top_k).float() - probs.div_(probs.sum(dim=-1, keepdim=True)) - next_token = multinomial(probs, num_samples=1) - return next_token - - -def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: - """Sample next token from top P probabilities along the last dimension of the input probs tensor. - - Args: - probs (torch.Tensor): Input probabilities with token candidates on the last dimension. - p (int): The p in “top-p”. - Returns: - torch.Tensor: Sampled tokens. - """ - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort *= (~mask).float() - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token - - -class DummyPoolExecutor: - """Dummy pool executor to use when we actually have only 1 worker. - (e.g. instead of ProcessPoolExecutor). - """ - class DummyResult: - def __init__(self, func, *args, **kwargs): - self.func = func - self.args = args - self.kwargs = kwargs - - def result(self): - return self.func(*self.args, **self.kwargs) - - def __init__(self, workers, mp_context=None): - pass - - def submit(self, func, *args, **kwargs): - return DummyPoolExecutor.DummyResult(func, *args, **kwargs) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - return - - -def get_pool_executor(num_workers: int, mp_context=None): - return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1) - - -def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor: - """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences). - For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]] - - Args: - lengths (torch.Tensor): tensor with lengths - max_len (int): can set the max length manually. Defaults to None. - Returns: - torch.Tensor: mask with 0s where there is pad tokens else 1s - """ - assert len(lengths.shape) == 1, "Length shape should be 1 dimensional." - final_length = lengths.max().item() if not max_len else max_len - final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor - return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None] - - -def hash_trick(word: str, vocab_size: int) -> int: - """Hash trick to pair each word with an index - - Args: - word (str): word we wish to convert to an index - vocab_size (int): size of the vocabulary - Returns: - int: index of the word in the embedding LUT - """ - hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16) - return hash % vocab_size - - -def with_rank_rng(base_seed: int = 1234): - """Decorator for a function so that the function will use a Random Number Generator - whose state depend on the GPU rank. The original RNG state is restored upon returning. - - Args: - base_seed (int): Random seed. 
- """ - def _decorator(fun: tp.Callable): - @wraps(fun) - def _decorated(*args, **kwargs): - state = torch.get_rng_state() - seed = base_seed ^ flashy.distrib.rank() - torch.manual_seed(seed) - logger.debug('Rank dependent seed set to %d', seed) - try: - return fun(*args, **kwargs) - finally: - torch.set_rng_state(state) - logger.debug('RNG state restored.') - return _decorated - return _decorator - - -def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]: - """Get a list of tensors and collate them to a single tensor. according to the following logic: - - `dim` specifies the time dimension which will be stacked and padded. - - The output will contain 1 new dimension (dimension index 0) which will be the size of - of the original list. - - Args: - tensors (tp.List[torch.Tensor]): List of tensors to collate. - dim (int): Dimension which will be stacked and padded. - Returns: - tp.Tuple[torch.Tensor, torch.Tensor]: - torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension - (dimension index 0) which will be the size of the original list. - torch.Tensor: Tensor containing length of original tensor sizes (without padding). - """ - tensors = [x.transpose(0, dim) for x in tensors] - lens = torch.LongTensor([len(x) for x in tensors]) - padded_tensors = pad_sequence(tensors) - padded_tensors = padded_tensors.transpose(0, 1) - padded_tensors = padded_tensors.transpose(1, dim + 1) - return padded_tensors, lens - - -# TODO: Move to flashy? -def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu', - dtype: tp.Optional[torch.dtype] = None) -> tp.Any: - if isinstance(state, torch.Tensor): - if dtype is None or not state.is_floating_point(): - dtype = state.dtype - return state.detach().to(device=device, dtype=dtype, copy=True) - elif isinstance(state, dict): - return {k: copy_state(v, device, dtype) for k, v in state.items()} - elif isinstance(state, list): - return [copy_state(v, device, dtype) for v in state] - - -# TODO: Move to flashy? -@contextmanager -def swap_state(model, state, **kwargs): - old_state = copy_state(model.state_dict()) - model.load_state_dict(state, **kwargs) - try: - yield - finally: - model.load_state_dict(old_state) - - -@lru_cache(None) -def warn_once(logger, msg): - """Warn about a given message only once.""" - logger.warning(msg) - - -def is_jsonable(x: tp.Any): - """Check if an object can be serialized into a json:""" - try: - json.dumps(x) - return True - except (TypeError, OverflowError): - return False - - -def load_clap_state_dict(clap_model, path: tp.Union[str, Path]): - """Wrapper around state dict loading of CLAP model - addressing compatibility issues between CLAP and AudioCraft - HuggingFace transformer version. 
- See: https://github.com/LAION-AI/CLAP/issues/118 - """ - from clap_module.factory import load_state_dict # type: ignore - pkg = load_state_dict(path) - pkg.pop('text_branch.embeddings.position_ids', None) - clap_model.model.load_state_dict(pkg) diff --git a/audiocraft/config/conditioner/chord2music_inattn.yaml b/audiocraft/config/conditioner/chord2music_inattn.yaml deleted file mode 100644 index 2569e12f5843d78269f610a70a712f7ce636dd8e..0000000000000000000000000000000000000000 --- a/audiocraft/config/conditioner/chord2music_inattn.yaml +++ /dev/null @@ -1,45 +0,0 @@ -# @package __global__ - -classifier_free_guidance: - training_dropout: 0.2 - inference_coef: 3.0 - -attribute_dropout: - args: - active_on_eval: false - text: {} - chord: - chord: 0.5 - beat: - beat: 0.5 - -fuser: - cross_attention_pos_emb: false - cross_attention_pos_emb_scale: 1 - in_attn : true - sum: [chord, beat] - prepend: [chord, description] - cross: [] - input_interpolate: [] - -conditioners: - description: - model: t5 - t5: - name: t5-base - finetune: false - word_dropout: 0.2 - normalize_text: false - chord: - model: chord - chord: - name: chord - beat: - model: beat - beat: - name: beat -dataset: - train: - merge_text_p: 0.25 - drop_desc_p: 0.5 - drop_other_p: 0.5 diff --git a/audiocraft/config/conditioner/chroma2music.yaml b/audiocraft/config/conditioner/chroma2music.yaml deleted file mode 100644 index 91d37e758ef183678cff3f7a880b6bab2e36b03c..0000000000000000000000000000000000000000 --- a/audiocraft/config/conditioner/chroma2music.yaml +++ /dev/null @@ -1,46 +0,0 @@ -# @package __global__ - -classifier_free_guidance: - training_dropout: 0.2 - inference_coef: 3.0 - -attribute_dropout: - args: - active_on_eval: false - text: {} - wav: - self_wav: 0.5 - -fuser: - cross_attention_pos_emb: false - cross_attention_pos_emb_scale: 1 - sum: [] - prepend: [self_wav, description] - cross: [] - input_interpolate: [] - -conditioners: - self_wav: - model: chroma_stem - chroma_stem: - sample_rate: ${sample_rate} - n_chroma: 12 - radix2_exp: 14 - argmax: true - match_len_on_eval: false - eval_wavs: null - n_eval_wavs: 100 - cache_path: null - description: - model: t5 - t5: - name: t5-base - finetune: false - word_dropout: 0.2 - normalize_text: false - -dataset: - train: - merge_text_p: 0.25 - drop_desc_p: 0.5 - drop_other_p: 0.5 diff --git a/audiocraft/config/conditioner/chroma_text2music.yaml b/audiocraft/config/conditioner/chroma_text2music.yaml deleted file mode 100644 index 3a2b685ab82c14a8bfa1e603b9d1f69af29fbd0b..0000000000000000000000000000000000000000 --- a/audiocraft/config/conditioner/chroma_text2music.yaml +++ /dev/null @@ -1,46 +0,0 @@ -# @package __global__ - -classifier_free_guidance: - training_dropout: 0.2 - inference_coef: 3.0 - -attribute_dropout: - args: - active_on_eval: false - text: {} - wav: - self_wav: 0.5 - -fuser: - cross_attention_pos_emb: false - cross_attention_pos_emb_scale: 1 - sum: [] - prepend: [self_wav] - cross: [description] - input_interpolate: [] - -conditioners: - self_wav: - model: chroma_stem - chroma_stem: - sample_rate: ${sample_rate} - n_chroma: 12 - radix2_exp: 14 - argmax: true - match_len_on_eval: false - eval_wavs: null - n_eval_wavs: 100 - cache_path: null - description: - model: t5 - t5: - name: t5-base - finetune: false - word_dropout: 0.2 - normalize_text: false - -dataset: - train: - merge_text_p: 0.25 - drop_desc_p: 0.5 - drop_other_p: 0.5 diff --git a/audiocraft/config/conditioner/clapemb2music.yaml b/audiocraft/config/conditioner/clapemb2music.yaml deleted file mode 
100644 index 8500a826e7379b4a8baaf67570e233f7bac7e5da..0000000000000000000000000000000000000000 --- a/audiocraft/config/conditioner/clapemb2music.yaml +++ /dev/null @@ -1,44 +0,0 @@ -# @package __global__ - -classifier_free_guidance: - training_dropout: 0.3 - inference_coef: 3.0 - -attribute_dropout: - text: {} - wav: {} - -fuser: - cross_attention_pos_emb: false - cross_attention_pos_emb_scale: 1 - sum: [] - prepend: [] - cross: [description] - input_interpolate: [] - -conditioners: - description: - model: clap - clap: - checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt - model_arch: 'HTSAT-base' - enable_fusion: false - sample_rate: 44100 - max_audio_length: 10 - audio_stride: 1 - dim: 512 - attribute: description - normalize: true - quantize: true # use RVQ quantization - n_q: 12 - bins: 1024 - kmeans_iters: 50 - text_p: 0. # probability of using text embed at train time - cache_path: null - -dataset: - joint_embed_attributes: [description] - train: - merge_text_p: 0.25 - drop_desc_p: 0.5 - drop_other_p: 0.5 diff --git a/audiocraft/config/conditioner/none.yaml b/audiocraft/config/conditioner/none.yaml deleted file mode 100644 index c8e33156281e2af7616307da5c05b8094ee012e0..0000000000000000000000000000000000000000 --- a/audiocraft/config/conditioner/none.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# @package __global__ - -# No conditioning - -classifier_free_guidance: - training_dropout: 0 - inference_coef: 1 - -attribute_dropout: - text: {} - wav: {} - -fuser: - sum: [] - concat: [] - prepend: [] - cross: [] - input_interpolate: [] - -conditioners: null diff --git a/audiocraft/config/conditioner/text2music.yaml b/audiocraft/config/conditioner/text2music.yaml deleted file mode 100644 index 2d0fe6cfa3fb33bcdb4f9fd16bd5ab4034c68b7b..0000000000000000000000000000000000000000 --- a/audiocraft/config/conditioner/text2music.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# @package __global__ - -classifier_free_guidance: - training_dropout: 0.3 - inference_coef: 3.0 - -attribute_dropout: {} - -fuser: - cross_attention_pos_emb: false - cross_attention_pos_emb_scale: 1 - sum: [] - prepend: [] - cross: [description] - input_interpolate: [] - -conditioners: - description: - model: t5 - t5: - name: t5-base - finetune: false - word_dropout: 0.3 - normalize_text: false - -dataset: - train: - merge_text_p: 0.25 - drop_desc_p: 0.5 - drop_other_p: 0.5 diff --git a/audiocraft/config/conditioner/text2sound.yaml b/audiocraft/config/conditioner/text2sound.yaml deleted file mode 100644 index 555d4b7c3cecf0ec06c8cb25440b2f426c098ad2..0000000000000000000000000000000000000000 --- a/audiocraft/config/conditioner/text2sound.yaml +++ /dev/null @@ -1,24 +0,0 @@ -# @package __global__ - -classifier_free_guidance: - training_dropout: 0.1 - inference_coef: 3.0 - -attribute_dropout: {} - -fuser: - cross_attention_pos_emb: false - cross_attention_pos_emb_scale: 1 - sum: [] - prepend: [] - cross: [description] - input_interpolate: [] - -conditioners: - description: - model: t5 - t5: - name: t5-large - finetune: false - word_dropout: 0. - normalize_text: false diff --git a/audiocraft/config/config.yaml b/audiocraft/config/config.yaml deleted file mode 100644 index 6b0b7866eafac173fe7b056ad5920be1df57a947..0000000000000000000000000000000000000000 --- a/audiocraft/config/config.yaml +++ /dev/null @@ -1,75 +0,0 @@ -# WARNING: This is the base configuration file shared across ALL solvers in AudioCraft -# Please don't update this file directly. Instead use distinct configuration files -# to override the below configuration. 
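All of the conditioner presets above share the same classifier-free guidance knobs: training_dropout is the probability of dropping the conditioning during training (so the model also learns an unconditional distribution), and inference_coef is the guidance strength applied at sampling time. The snippet below is a generic sketch of how such a coefficient is usually applied to a pair of conditional/unconditional logits; it illustrates the formula only and is not the solver's actual generation code (cfg_logits and the toy vocabulary size are made up here).

import torch

def cfg_logits(cond: torch.Tensor, uncond: torch.Tensor, coef: float = 3.0) -> torch.Tensor:
    # Classifier-free guidance: extrapolate away from the unconditional prediction.
    # coef plays the role of classifier_free_guidance.inference_coef.
    return uncond + coef * (cond - uncond)

cond = torch.randn(2, 1024)      # logits with conditioning (batch of 2, toy vocab of 1024)
uncond = torch.randn(2, 1024)    # logits with the conditioning dropped
probs = torch.softmax(cfg_logits(cond, uncond, coef=3.0), dim=-1)

With coef=1.0 the guidance is a no-op; larger values trade diversity for closer adherence to the text/chroma/chord conditioning, which is why these presets default to 3.0.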
-defaults: - - _self_ - - dset: default - - solver: default - -device: cuda -dtype: float32 -autocast: false -autocast_dtype: bfloat16 -seed: 2036 -show: false # just show the model and its size and exit -continue_from: # continue from a given sig or path -execute_only: # can be set to generate/evaluate/valid to run that stage -execute_inplace: false # don't enforce continue_from to be set - # to enable inplace execution of the stage. This assume - # that you know what you are doing and execute stage - # preserving the original xp sig. -benchmark_no_load: false # if set to true, will repeat the same batch instead of loading them - -efficient_attention_backend: torch # can be torch or xformers. -num_threads: 1 # called with torch.set_num_thread. -mp_start_method: forkserver # multiprocessing method (spawn, fork or fork_server). - - -label: # use this if you want twice the same exp, with a name. - -# logging parameters -logging: - level: INFO - log_updates: 10 - log_tensorboard: false - log_wandb: false -tensorboard: - with_media_logging: false - name: # optional name for the experiment - sub_dir: # optional sub directory to store tensorboard data -wandb: - with_media_logging: true - project: # project name - name: # optional name for the experiment - group: # optional group - -# SLURM launcher configuration. -slurm: - gpus: 4 # convenience parameter, number of GPUs to use. - mem_per_gpu: 40 # in GB, total mem is automatically scaled with `gpus`. - time: 3600 - constraint: - partition: - comment: - setup: [] - exclude: '' - -# dora parameters -dora: - # Output folder for all artifacts of an experiment. - dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs - # The following entries will be ignored by dora when computing the unique XP signature. - # Note that slurm.* and dora.* are automatically ignored. - exclude: [ - 'device', 'wandb.*', 'tensorboard.*', 'logging.*', - 'dataset.num_workers', 'eval.num_workers', 'special.*', - 'metrics.visqol.bin', 'metrics.fad.bin', - 'execute_only', 'execute_best', 'generate.every', - 'optim.eager_sync', 'profiler.*', 'deadlock.*', - 'efficient_attention_backend', 'num_threads', 'mp_start_method', - ] - use_rendezvous: false - # for grids, always run from a clean repo, allowing reliable runs and storing - # the exact commit. Your repo must be absolutely pristine clean. - # Local `dora run` are not impacted for easier debugging. - git_save: true diff --git a/audiocraft/config/dset/audio/audiocaps_16khz.yaml b/audiocraft/config/dset/audio/audiocaps_16khz.yaml deleted file mode 100644 index 14f5d6a4fcbf4426b7987d4427ca2d98d17d6c5b..0000000000000000000000000000000000000000 --- a/audiocraft/config/dset/audio/audiocaps_16khz.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# @package __global__ - -# AudioCaps dataset -datasource: - max_sample_rate: 16000 - max_channels: 1 - - train: null # only evaluation set - valid: null # only evaluation set - evaluate: egs/audiocaps/audiocaps_16khz - generate: egs/audiocaps/audiocaps_16khz # identical to evaluate diff --git a/audiocraft/config/dset/audio/default.yaml b/audiocraft/config/dset/audio/default.yaml deleted file mode 100644 index 80be23e999c6366cc89ebcf55af6b958c0e45158..0000000000000000000000000000000000000000 --- a/audiocraft/config/dset/audio/default.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# @package __global__ - -datasource: - max_sample_rate: ??? - max_channels: ??? - - train: ??? - valid: ??? - evaluate: ??? 
- generate: null diff --git a/audiocraft/config/dset/audio/example.yaml b/audiocraft/config/dset/audio/example.yaml deleted file mode 100644 index d559d6d79a1cc05a82bb09f267c446258ef9ca55..0000000000000000000000000000000000000000 --- a/audiocraft/config/dset/audio/example.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# @package __global__ - -datasource: - max_sample_rate: 44100 - max_channels: 2 - - train: egs/example - valid: egs/example - evaluate: egs/example - generate: egs/example diff --git a/audiocraft/config/dset/audio/train.yaml b/audiocraft/config/dset/audio/train.yaml deleted file mode 100644 index df915cd6ee51ae2af4f413e68e6570a7a73ef770..0000000000000000000000000000000000000000 --- a/audiocraft/config/dset/audio/train.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# @package __global__ - -datasource: - max_sample_rate: 44100 - max_channels: 2 - - train: egs/YT_backing_tracks_0615 - valid: egs/YT_backing_tracks_0615 - evaluate: egs/YT_backing_tracks_0615 - generate: egs/YT_backing_tracks_0615 \ No newline at end of file diff --git a/audiocraft/config/dset/audio/train_backing.yaml b/audiocraft/config/dset/audio/train_backing.yaml deleted file mode 100644 index 8da9fa930eba5a27c9955a33dd27e88a4a8f76e6..0000000000000000000000000000000000000000 --- a/audiocraft/config/dset/audio/train_backing.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# @package __global__ - -datasource: - max_sample_rate: 48000 - max_channels: 2 - - train: egs/5_genre_backing - valid: egs/musdb_valid - evaluate: egs/musdb_valid - generate: egs/musdb_valid \ No newline at end of file diff --git a/audiocraft/config/dset/default.yaml b/audiocraft/config/dset/default.yaml deleted file mode 100644 index b5d730130e090b38a42984a8a87e1eea01cbf031..0000000000000000000000000000000000000000 --- a/audiocraft/config/dset/default.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# @package __global__ - -# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft -# Please don't update this file directly. Instead use distinct configuration files -# to override the below configuration. -datasource: - train: ??? - valid: ??? - evaluate: ??? - generate: ??? 
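The ??? values in the two base dset files above are OmegaConf's marker for mandatory fields: composing a solver with only the default dset will fail as soon as datasource.train (or the sample-rate/channel caps) is read, which is the intended guard against launching a run without a concrete dataset definition. A small sketch of that behaviour, assuming only the omegaconf package:

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

base = OmegaConf.create({'datasource': {
    'max_sample_rate': '???', 'max_channels': '???',
    'train': '???', 'valid': '???', 'evaluate': '???', 'generate': '???'}})
assert OmegaConf.is_missing(base.datasource, 'train')

try:
    _ = base.datasource.train          # reading a mandatory '???' value raises
except MissingMandatoryValue:
    pass

# A concrete dset file (e.g. dset/audio/example.yaml) simply overrides the fields:
example = OmegaConf.create({'datasource': {
    'max_sample_rate': 44100, 'max_channels': 2,
    'train': 'egs/example', 'valid': 'egs/example',
    'evaluate': 'egs/example', 'generate': 'egs/example'}})
print(OmegaConf.merge(base, example).datasource.train)   # egs/example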
diff --git a/audiocraft/config/dset/internal/music_10k_32khz.yaml b/audiocraft/config/dset/internal/music_10k_32khz.yaml deleted file mode 100644 index 036628abfeaa89279790547bbb5b3ee9dd69cea3..0000000000000000000000000000000000000000 --- a/audiocraft/config/dset/internal/music_10k_32khz.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# @package __global__ - -# high quality music dataset with no artist overlap between splits -datasource: - max_sample_rate: 32000 - max_channels: 1 - - train: egs/music/music_10k_32khz/train - valid: egs/music/music_10k_32khz/valid - evaluate: egs/music/music_10k_32khz/test - generate: egs/music/music_10k_32khz/test # identical to evaluate diff --git a/audiocraft/config/dset/internal/music_400k_32khz.yaml b/audiocraft/config/dset/internal/music_400k_32khz.yaml deleted file mode 100644 index 7786880ab9c0464a0423d906c18d62bdf7194463..0000000000000000000000000000000000000000 --- a/audiocraft/config/dset/internal/music_400k_32khz.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# @package __global__ - -datasource: - max_sample_rate: 32000 - max_channels: 1 - - train: egs/music/music_400k_32khz/train - valid: egs/music/music_400k_32khz/valid - evaluate: egs/music/music_400k_32khz/test - generate: egs/music/music_400k_32khz/test # identical to evaluate diff --git a/audiocraft/config/dset/internal/sounds_16khz.yaml b/audiocraft/config/dset/internal/sounds_16khz.yaml deleted file mode 100644 index 4f3401a1b44ce300e22f3f64ef9c54d5c013c153..0000000000000000000000000000000000000000 --- a/audiocraft/config/dset/internal/sounds_16khz.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# @package __global__ - -# environmental sounds dataset compiling all datasets -# with applied filters on tags -datasource: - max_sample_rate: 16000 - max_channels: 1 - - train: egs/sound/sounds_16khz/train - valid: egs/sound/sounds_16khz/valid - evaluate: egs/sound/sounds_16khz/test - generate: egs/sound/sounds_16khz/test # identical to evaluate diff --git a/audiocraft/config/model/encodec/default.yaml b/audiocraft/config/model/encodec/default.yaml deleted file mode 100644 index ec62c6c8ef9a686890bdca8b8f27a2f1c232205d..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/encodec/default.yaml +++ /dev/null @@ -1,54 +0,0 @@ -# @package __global__ - -compression_model: encodec - -encodec: - autoencoder: seanet - quantizer: rvq - sample_rate: ${sample_rate} - channels: ${channels} - causal: false - renormalize: false - -seanet: - dimension: 128 - channels: ${channels} - causal: ${encodec.causal} - n_filters: 32 - n_residual_layers: 1 - ratios: [8, 5, 4, 2] - activation: ELU - activation_params: {"alpha": 1.} - norm: weight_norm - norm_params: {} - kernel_size: 7 - residual_kernel_size: 3 - last_kernel_size: 7 - dilation_base: 2 - pad_mode: constant - true_skip: true - compress: 2 - lstm: 2 - disable_norm_outer_blocks: 0 - # Specific encoder or decoder params. - # You can also override any param for the encoder or decoder only - # by using Hydra `+param=` syntax, i.e.` - # `+seanet.decoder.n_filters=64`. 
- decoder: - trim_right_ratio: 1.0 - final_activation: null - final_activation_params: null - encoder: {} - -rvq: - n_q: 8 - q_dropout: false - bins: 1024 - decay: 0.99 - kmeans_init: true - kmeans_iters: 50 - threshold_ema_dead_code: 2 - orthogonal_reg_weight: 0.0 - orthogonal_reg_active_codes_only: false - -no_quant: {} diff --git a/audiocraft/config/model/encodec/encodec_base_causal.yaml b/audiocraft/config/model/encodec/encodec_base_causal.yaml deleted file mode 100644 index 3ca555bcdc69433f172915400bb71c3b63e68681..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/encodec/encodec_base_causal.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# @package __global__ - -defaults: - - encodec/default - -encodec: - causal: true - -rvq: - n_q: 32 - q_dropout: true diff --git a/audiocraft/config/model/encodec/encodec_large_nq4_s320.yaml b/audiocraft/config/model/encodec/encodec_large_nq4_s320.yaml deleted file mode 100644 index 5f2d77590afd8a81185358c705a6e42853e257c3..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/encodec/encodec_large_nq4_s320.yaml +++ /dev/null @@ -1,13 +0,0 @@ -# @package __global__ - -defaults: - - encodec/default - -seanet: - # default ratios are [8, 5, 4, 2] - n_filters: 64 - -rvq: - bins: 2048 - n_q: 4 - q_dropout: false diff --git a/audiocraft/config/model/encodec/encodec_large_nq4_s640.yaml b/audiocraft/config/model/encodec/encodec_large_nq4_s640.yaml deleted file mode 100644 index 3fcb7e87f4f700554164b0a58e9927b2f96a2c5a..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/encodec/encodec_large_nq4_s640.yaml +++ /dev/null @@ -1,13 +0,0 @@ -# @package __global__ - -defaults: - - encodec/default - -seanet: - ratios: [8, 5, 4, 4] - n_filters: 64 - -rvq: - bins: 2048 - n_q: 4 - q_dropout: false diff --git a/audiocraft/config/model/lm/audiogen_lm.yaml b/audiocraft/config/model/lm/audiogen_lm.yaml deleted file mode 100644 index 696f74620af193c12208ce66fdb93a37f8ea9d80..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/lm/audiogen_lm.yaml +++ /dev/null @@ -1,36 +0,0 @@ -# @package __global__ - -defaults: - - lm/default - - override /conditioner: text2sound - - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly - -lm_model: transformer_lm - -codebooks_pattern: - modeling: delay - delay: - delays: [0, 1, 2, 3] - flatten_first: 0 - empty_initial: 0 - unroll: - flattening: [0, 1, 2, 3] - delays: [0, 0, 0, 0] - music_lm: - group_by: 2 - valle: - delays: [0, 0, 0] - -transformer_lm: - n_q: 4 - card: 2048 - memory_efficient: true - bias_proj: false - bias_ff: false - bias_attn: false - norm_first: true - layer_scale: null - weight_init: gaussian - depthwise_init: current - zero_bias_init: true - attention_as_float32: false diff --git a/audiocraft/config/model/lm/default.yaml b/audiocraft/config/model/lm/default.yaml deleted file mode 100644 index 2d256ad14ef69d25d62c19b73599937c8546e79b..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/lm/default.yaml +++ /dev/null @@ -1,47 +0,0 @@ -# @package __global__ -defaults: - - _self_ - - /model/lm/model_scale: base # prefer this group to set model scale instead of transformer_lm keys directly - -lm_model: transformer_lm - -codebooks_pattern: - modeling: parallel - -transformer_lm: - dim: 512 - num_heads: 8 - num_layers: 8 - hidden_scale: 4 - n_q: 8 # number of streams to model - card: 1024 - dropout: 0. 
- emb_lr: null - activation: gelu - norm_first: false # use pre-norm instead of post-norm - bias_ff: true # use bias for the feedforward - bias_attn: true # use bias for the attention - bias_proj: true # use bias for the output projections - past_context: null - causal: true - custom: false # use custom MHA implementation - memory_efficient: false # use flash attention - attention_as_float32: false # use float32 for the attention part, - # recommended at the moment when memory_efficient is True. - layer_scale: null - positional_embedding: sin # positional embedding strategy (sin, rope, or sin_rope). - xpos: false # apply xpos decay (rope only). - checkpointing: none # layer checkpointing method, can be none, torch, xformers_default. - # torch is the slowest but uses the least memory, - # xformers_default is somewhere in between. - weight_init: null # weight initialization (null, gaussian or uniform) - depthwise_init: null # perform depthwise initialization (null, current, global) - zero_bias_init: false # initialize bias to zero if bias in linears and - # if a weight_init method is used. - norm: layer_norm # normalization method to use in transformer. - cross_attention: false - qk_layer_norm: false - qk_layer_norm_cross: false - attention_dropout: null - kv_repeat: 1 - two_step_cfg: false # whether to do true 2 steps CFG, potentially resolving some padding issues or not... diff --git a/audiocraft/config/model/lm/model_scale/base.yaml b/audiocraft/config/model/lm/model_scale/base.yaml deleted file mode 100644 index 3da88d2305e4c380435de1a3eecfe311ecfc82f9..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/lm/model_scale/base.yaml +++ /dev/null @@ -1,3 +0,0 @@ -# @package __global__ - -# overrides nothing because default is already transformer base (~ 60M params) diff --git a/audiocraft/config/model/lm/model_scale/large.yaml b/audiocraft/config/model/lm/model_scale/large.yaml deleted file mode 100644 index d355bfb93618003ac8994bc093eb7bc96ac60114..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/lm/model_scale/large.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# @package _global_ - -# gpt2 inspired, even bigger (~3.3B params) -transformer_lm: - dim: 2048 - num_heads: 32 - num_layers: 48 diff --git a/audiocraft/config/model/lm/model_scale/medium.yaml b/audiocraft/config/model/lm/model_scale/medium.yaml deleted file mode 100644 index c825d1ff6c3b8cc9ae4959a898e14b40409d95e8..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/lm/model_scale/medium.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# @package _global_ - -# gpt2 like (~1.5B params) -transformer_lm: - dim: 1536 - num_heads: 24 - num_layers: 48 diff --git a/audiocraft/config/model/lm/model_scale/medium_small.yaml b/audiocraft/config/model/lm/model_scale/medium_small.yaml deleted file mode 100644 index 8debdc58182e340dc19ec6fc1c345d15de9d0e46..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/lm/model_scale/medium_small.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _global_ - -# ???M - -transformer_lm: - dim: 1280 - num_heads: 20 - num_layers: 36 diff --git a/audiocraft/config/model/lm/model_scale/small.yaml b/audiocraft/config/model/lm/model_scale/small.yaml deleted file mode 100644 index 88d89cb5ac1b183fb3a9092834cea83aa16c70a8..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/lm/model_scale/small.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _global_ - -# 300M Param. 
- -transformer_lm: - dim: 1024 - num_heads: 16 - num_layers: 24 diff --git a/audiocraft/config/model/lm/model_scale/xsmall.yaml b/audiocraft/config/model/lm/model_scale/xsmall.yaml deleted file mode 100644 index e98d4370d4fe7497f12aeb58f092a88797d1afa1..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/lm/model_scale/xsmall.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _global_ -# just used for debugging or when we just want to populate the cache -# and do not care about training. - -transformer_lm: - dim: 64 - num_heads: 2 - num_layers: 2 diff --git a/audiocraft/config/model/lm/musicgen_lm.yaml b/audiocraft/config/model/lm/musicgen_lm.yaml deleted file mode 100644 index 5bc87a628789a34e381e2aa8ba5ef6ed780669d7..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/lm/musicgen_lm.yaml +++ /dev/null @@ -1,36 +0,0 @@ -# @package __global__ - -defaults: - - lm/default - - override /conditioner: text2music - - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly - -lm_model: transformer_lm - -codebooks_pattern: - modeling: delay - delay: - delays: [0, 1, 2, 3] - flatten_first: 0 - empty_initial: 0 - unroll: - flattening: [0, 1, 2, 3] - delays: [0, 0, 0, 0] - music_lm: - group_by: 2 - valle: - delays: [0, 0, 0] - -transformer_lm: - n_q: 4 - card: 2048 - memory_efficient: true - bias_proj: false - bias_ff: false - bias_attn: false - norm_first: true - layer_scale: null - weight_init: gaussian - depthwise_init: current - zero_bias_init: true - attention_as_float32: false diff --git a/audiocraft/config/model/none.yaml b/audiocraft/config/model/none.yaml deleted file mode 100644 index 1d4169f468d462c794ee6ed25017c3d78ae45d06..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/none.yaml +++ /dev/null @@ -1,4 +0,0 @@ -# @package __global__ - -# This file exist so that model is recognized as a config group -# by Hydra, and Dora. A bit weird we might need a better fix someday. diff --git a/audiocraft/config/model/score/basic.yaml b/audiocraft/config/model/score/basic.yaml deleted file mode 100644 index 75fbc3783942602beaddaa38d0aca977aeee2dda..0000000000000000000000000000000000000000 --- a/audiocraft/config/model/score/basic.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# @package _global_ - -diffusion_unet: - hidden: 48 - depth: 4 - res_blocks: 1 - norm_groups: 4 - kernel: 8 - stride: 4 - growth: 4 - max_channels: 10_000 - dropout: 0. - emb_all_layers: true - bilstm: false - codec_dim: null - transformer: false - cross_attention: false \ No newline at end of file diff --git a/audiocraft/config/solver/audiogen/audiogen_base_16khz.yaml b/audiocraft/config/solver/audiogen/audiogen_base_16khz.yaml deleted file mode 100644 index dd6aee785c74db19ce9d6f488e68e6eeb471c026..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/audiogen/audiogen_base_16khz.yaml +++ /dev/null @@ -1,70 +0,0 @@ -# @package __global__ - -# This is the training loop solver -# for the base AudioGen model (text-to-sound) -# on monophonic audio sampled at 16 kHz -# using a similar EnCodec+LM setup to MusicGen -defaults: - - audiogen/default - - /model: lm/audiogen_lm - - override /dset: audio/default - - _self_ - -autocast: true -autocast_dtype: float16 - -# EnCodec large trained on mono-channel music audio sampled at 16khz -# with a total stride of 320 leading to 50 frames/s. 
-# rvq.n_q=4, rvq.bins=2048, no quantization dropout -# (transformer_lm card and n_q must be compatible) -compression_model_checkpoint: //reference/bd44a852/checkpoint.th - -channels: 1 -sample_rate: 16000 - -deadlock: - use: true # deadlock detection - -dataset: - batch_size: 128 # matching AudioGen paper setup (256 * mix_p=0.5 = 128) - num_workers: 10 - segment_duration: 10 - min_segment_ratio: 1.0 - sample_on_weight: false # Uniform sampling all the way - sample_on_duration: false # Uniform sampling all the way - external_metadata_source: null - # sample mixing augmentation at train time - train: - batch_size: 256 # matching AudioGen paper setup - aug_p: 0.5 # perform audio mixing 50% of the time - mix_p: 0.5 # proportion of batch items mixed together - # important: note that this will reduce the - # actual batch size used at train time - # which will be equal to mix_p * batch_size - mix_snr_low: -5 - mix_snr_high: 5 - mix_min_overlap: 0.5 - -generate: - lm: - use_sampling: true - top_k: 250 - top_p: 0.0 - -optim: - epochs: 100 - optimizer: adamw - lr: 5e-4 - ema: - use: true - updates: 10 - device: cuda - -logging: - log_tensorboard: true - -schedule: - lr_scheduler: inverse_sqrt - inverse_sqrt: - warmup: 3000 - warmup_init_lr: 0.0 diff --git a/audiocraft/config/solver/audiogen/debug.yaml b/audiocraft/config/solver/audiogen/debug.yaml deleted file mode 100644 index fbda8281c6d552d9445e04fee498641a26549aa5..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/audiogen/debug.yaml +++ /dev/null @@ -1,52 +0,0 @@ -# @package __global__ - -# This is a minimal debugging configuration -# for MusicGen training solver -defaults: - - audiogen/default - - /model: lm/audiogen_lm - - override /model/lm/model_scale: xsmall - - override /dset: audio/example - - _self_ - -autocast: false -compression_model_checkpoint: null - -codebooks_pattern: - modeling: parallel - -channels: 1 -sample_rate: 16000 - -deadlock: - use: false # deadlock detection - -dataset: - batch_size: 4 - segment_duration: 5 - sample_on_weight: false # Uniform sampling all the way - sample_on_duration: false # Uniform sampling all the way - -generate: - audio: - strategy: peak - lm: - use_sampling: false - top_k: 0 - top_p: 0.0 - -checkpoint: - save_every: 0 - keep_last: 0 - -optim: - epochs: 2 - updates_per_epoch: 10 - optimizer: adamw - lr: 1e-4 - -logging: - log_tensorboard: true - -schedule: - lr_scheduler: null diff --git a/audiocraft/config/solver/audiogen/default.yaml b/audiocraft/config/solver/audiogen/default.yaml deleted file mode 100644 index afee63c65e0dd7350e3e89d2133bbca221d17631..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/audiogen/default.yaml +++ /dev/null @@ -1,40 +0,0 @@ -# @package __global__ - -defaults: - - /solver/musicgen/default - - _self_ - - /solver/audiogen/evaluation: none - - override /dset: audio/default - -# See config/solver/musicgen/default.yaml for a list of possible values. -# We only keep the most important here. - -autocast: true -autocast_dtype: float16 - -solver: audiogen -sample_rate: ??? -channels: ??? -compression_model_checkpoint: ??? - -tokens: - padding_with_special_token: false - -dataset: - batch_size: 128 - segment_duration: 10 - min_segment_ratio: 1.0 # lower values such as 0.5 result in generations with a lot of silence. 
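The train block above describes the waveform mixing augmentation used for AudioGen training: with probability aug_p two batch items are summed at a random SNR drawn from [mix_snr_low, mix_snr_high] with at least mix_min_overlap overlap, which is also why the effective batch size drops to mix_p * batch_size. Below is a minimal sketch of mixing two clips at a target SNR; mix_at_snr is an illustrative helper, not the dataset's actual augmentation code, which also handles overlap placement and metadata merging.

import torch

def mix_at_snr(sig: torch.Tensor, other: torch.Tensor, snr_db: float, eps: float = 1e-8) -> torch.Tensor:
    # Scale `other` so the resulting signal-to-"noise" ratio is snr_db, then sum.
    # Both tensors are [channels, time] and assumed already aligned/overlapping.
    sig_pow = sig.pow(2).mean()
    other_pow = other.pow(2).mean()
    target_pow = sig_pow / (10 ** (snr_db / 10))
    gain = torch.sqrt(target_pow / (other_pow + eps))
    return sig + gain * other

a = torch.randn(1, 16000 * 10)                      # 10 s mono at 16 kHz
b = torch.randn(1, 16000 * 10)
snr = torch.empty(1).uniform_(-5, 5).item()         # mix_snr_low .. mix_snr_high
mixed = mix_at_snr(a, b, snr)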
- -optim: - epochs: 100 - updates_per_epoch: 2000 - lr: 1e-4 - optimizer: adamw - max_norm: 1.0 - adam: - betas: [0.9, 0.95] - weight_decay: 0.1 - eps: 1e-8 - -schedule: - lr_scheduler: null diff --git a/audiocraft/config/solver/audiogen/evaluation/none.yaml b/audiocraft/config/solver/audiogen/evaluation/none.yaml deleted file mode 100644 index 1e739995ed6488700527529862a7a24f1afdcc7a..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/audiogen/evaluation/none.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# @package __global__ - -dataset: - evaluate: - num_samples: 10000 diff --git a/audiocraft/config/solver/audiogen/evaluation/objective_eval.yaml b/audiocraft/config/solver/audiogen/evaluation/objective_eval.yaml deleted file mode 100644 index 32fcc10033f3c3ff317216fe2876c65c6834e59b..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/audiogen/evaluation/objective_eval.yaml +++ /dev/null @@ -1,29 +0,0 @@ -# @package __global__ - -# Setup for execute only on audiocaps for audio generation -# evaluation with objective metrics -# execute_only=evaluate - -dataset: - max_audio_duration: null - # ensure the proper values are broadcasted here for evaluate - evaluate: - min_audio_duration: 1. # some metrics requires a minimum audio length - max_audio_duration: null # all samples from audiocaps should be ~10s - num_samples: null - segment_duration: null - generate: - min_audio_duration: 1. - max_audio_duration: null - num_samples: 500 - -evaluate: - metrics: - fad: true - kld: true - text_consistency: true - -metrics: - kld: - passt: - pretrained_length: 10 # similarly to reported results in AudioGen paper diff --git a/audiocraft/config/solver/compression/debug.yaml b/audiocraft/config/solver/compression/debug.yaml deleted file mode 100644 index 54dac175278d4ff509b0e44905d6b6195441f2c6..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/compression/debug.yaml +++ /dev/null @@ -1,55 +0,0 @@ -# @package __global__ - -defaults: - - compression/default - - /model: encodec/encodec_base_causal - - override /dset: audio/example - - _self_ - -channels: 1 -sample_rate: 16000 - -# debug config uses just L1 -losses: - adv: 0. - feat: 0. - l1: 1. - mel: 0. - msspec: 0. -# no balancer -balancer: - balance_grads: false - ema_decay: 1. - total_norm: 1. - per_batch_item: false -# no adversaries -adversarial: - adversaries: [] - adv_loss: hinge - feat_loss: l1 - -# faster model for local dev -seanet: - dimension: 16 - n_filters: 4 - -# very small dataset -dataset: - batch_size: 8 - num_workers: 10 - num_samples: 100 - segment_duration: 1 - evaluate: - batch_size: 32 - generate: - batch_size: 1 - num_samples: 5 - segment_duration: 10 - -# limited training -evaluate: - every: 5 -generate: - every: 5 -optim: - epochs: 50 diff --git a/audiocraft/config/solver/compression/default.yaml b/audiocraft/config/solver/compression/default.yaml deleted file mode 100644 index 41c812ba9ff8afe7ee10302ad5b9f05b745877d9..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/compression/default.yaml +++ /dev/null @@ -1,160 +0,0 @@ -# @package __global__ - -defaults: - - ../default - - override /dset: audio/default - - _self_ - -solver: compression -sample_rate: ??? -channels: ??? - -# loss balancing -losses: - adv: 4. - feat: 4. - l1: 0.1 - mel: 0. - msspec: 2. - sisnr: 0. -balancer: - balance_grads: true - ema_decay: 0.999 - per_batch_item: true - total_norm: 1. 
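The losses weights and the balancer block above work together: rather than weighting raw loss values, the balancer rescales each loss's gradient with respect to the decoder output so that the weights control relative gradient norms (with an EMA over norms and a global total_norm), keeping very differently scaled terms (l1 at 0.1 versus adv at 4) from dominating the update. Below is a much-simplified sketch of that idea with no EMA and no adversarial terms; balanced_backward is an illustrative helper, not the solver's actual balancer.

import torch

def balanced_backward(losses: dict, weights: dict, output: torch.Tensor, total_norm: float = 1.0):
    # Compute each loss's gradient w.r.t. the model output, normalize it, and
    # recombine so the configured weights set the relative gradient norms.
    grads, norms = {}, {}
    for name, loss in losses.items():
        (g,) = torch.autograd.grad(loss, [output], retain_graph=True)
        grads[name], norms[name] = g, g.norm()
    total_weight = sum(weights[name] for name in losses)
    combined = torch.zeros_like(output)
    for name, g in grads.items():
        combined += total_norm * weights[name] / total_weight * g / (norms[name] + 1e-12)
    output.backward(combined)   # single backward pass through the rest of the model

# Toy usage with two of the weights above on a fake decoder output.
x = torch.randn(4, 1, 1000, requires_grad=True)
output = torch.tanh(x)                        # stands in for the decoder output
losses = {'l1': output.abs().mean(), 'msspec': (output ** 2).mean()}
balanced_backward(losses, {'l1': 0.1, 'msspec': 2.0}, output)
print(x.grad.norm())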
- -adversarial: - every: 1 - adversaries: [msstftd] - adv_loss: hinge - feat_loss: l1 - -# losses hyperparameters -l1: {} -l2: {} -mrstft: - factor_sc: .5 - factor_mag: .5 - normalized: false -mel: - sample_rate: ${sample_rate} - n_fft: 1024 - hop_length: 256 - win_length: 1024 - n_mels: 64 - f_min: 64 - f_max: null - normalized: false - floor_level: 1e-5 -sisnr: - sample_rate: ${sample_rate} - segment: 5. -msspec: - sample_rate: ${sample_rate} - range_start: 6 - range_end: 11 - n_mels: 64 - f_min: 64 - f_max: null - normalized: true - alphas: false - floor_level: 1e-5 - -# metrics -metrics: - visqol: - mode: audio - bin: null # path to visqol install - model: tcdaudio14_aacvopus_coresv_svrnsim_n.68_g.01_c1.model # visqol v3 - -# adversaries hyperparameters -msstftd: - in_channels: 1 - out_channels: 1 - filters: 32 - norm: weight_norm - n_ffts: [1024, 2048, 512, 256, 128] - hop_lengths: [256, 512, 128, 64, 32] - win_lengths: [1024, 2048, 512, 256, 128] - activation: LeakyReLU - activation_params: {negative_slope: 0.3} -msd: - in_channels: 1 - out_channels: 1 - scale_norms: [spectral_norm, weight_norm, weight_norm] - kernel_sizes: [5, 3] - filters: 16 - max_filters: 1024 - downsample_scales: [4, 4, 4, 4] - inner_kernel_sizes: null - groups: [4, 4, 4, 4] - strides: null - paddings: null - activation: LeakyReLU - activation_params: {negative_slope: 0.3} -mpd: - in_channels: 1 - out_channels: 1 - periods: [2, 3, 5, 7, 11] - n_layers: 5 - kernel_size: 5 - stride: 3 - filters: 8 - filter_scales: 4 - max_filters: 1024 - activation: LeakyReLU - activation_params: {negative_slope: 0.3} - norm: weight_norm - -# data hyperparameters -dataset: - batch_size: 64 - num_workers: 10 - segment_duration: 1 - train: - num_samples: 500000 - valid: - num_samples: 10000 - evaluate: - batch_size: 32 - num_samples: 10000 - generate: - batch_size: 32 - num_samples: 50 - segment_duration: 10 - -# solver hyperparameters -evaluate: - every: 25 - num_workers: 5 - metrics: - visqol: false - sisnr: true -generate: - every: 25 - num_workers: 5 - audio: - sample_rate: ${sample_rate} - -# checkpointing schedule -checkpoint: - save_last: true - save_every: 25 - keep_last: 10 - keep_every_states: null - -# optimization hyperparameters -optim: - epochs: 200 - updates_per_epoch: 2000 - lr: 3e-4 - max_norm: 0. - optimizer: adam - adam: - betas: [0.5, 0.9] - weight_decay: 0. 
- ema: - use: true # whether to use EMA or not - updates: 1 # update at every step - device: ${device} # device for EMA, can be put on GPU if more frequent updates - decay: 0.99 # EMA decay value, if null, no EMA is used diff --git a/audiocraft/config/solver/compression/encodec_audiogen_16khz.yaml b/audiocraft/config/solver/compression/encodec_audiogen_16khz.yaml deleted file mode 100644 index 654deaa01ba9cace3f7144cc91921791c081b32a..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/compression/encodec_audiogen_16khz.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# @package __global__ - -defaults: - - compression/default - - /model: encodec/encodec_large_nq4_s320 - - override /dset: audio/default - - _self_ - -channels: 1 -sample_rate: 16000 diff --git a/audiocraft/config/solver/compression/encodec_base_24khz.yaml b/audiocraft/config/solver/compression/encodec_base_24khz.yaml deleted file mode 100644 index 018ad1cd61af84b616ad3088f055e8eaa36729eb..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/compression/encodec_base_24khz.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# @package __global__ - -defaults: - - compression/default - - /model: encodec/encodec_base_causal - - override /dset: audio/default - - _self_ - -channels: 1 -sample_rate: 24000 diff --git a/audiocraft/config/solver/compression/encodec_musicgen_32khz.yaml b/audiocraft/config/solver/compression/encodec_musicgen_32khz.yaml deleted file mode 100644 index eca4b90fb221372dace164fe59bb15822207a980..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/compression/encodec_musicgen_32khz.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# @package __global__ - -defaults: - - compression/default - - /model: encodec/encodec_large_nq4_s640 - - override /dset: audio/default - - _self_ - -channels: 1 -sample_rate: 32000 diff --git a/audiocraft/config/solver/default.yaml b/audiocraft/config/solver/default.yaml deleted file mode 100644 index 2981c54c7c56e234c27f1bbeeb6ebdf23c64e0ff..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/default.yaml +++ /dev/null @@ -1,109 +0,0 @@ -# @package __global__ - -# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft -# Please don't update this file directly. Instead use distinct configuration files -# to override the below configuration. -solver: ??? - -fsdp: - use: false # should we use FSDP. - param_dtype: float16 # equivalent to autocast_dtype for FSDP. - reduce_dtype: float32 # gradient averaging dtype, float32 will give max stability. - buffer_dtype: float32 # dtype used for buffers, we don't have much buffers, so let's leave it. - sharding_strategy: shard_grad_op # can be shard_grad_op or full_shard. - # full_shard will use less memory but slower ?? - per_block: true # If True, uses nested FSDP. - -profiler: - enabled: false - -deadlock: - use: false - timeout: 600 - -dataset: - batch_size: ??? - num_workers: 10 - segment_duration: null - num_samples: null - return_info: false - shuffle: false - sample_on_duration: true - sample_on_weight: true - min_segment_ratio: 0.5 - train: - num_samples: null - shuffle: true - shuffle_seed: 0 # if you want to sample the data differently. 
- permutation_on_files: false - valid: - num_samples: null - evaluate: - num_samples: null - generate: - num_samples: null - return_info: true - -checkpoint: - save_last: true - save_every: null - keep_last: null - keep_every_states: null - -generate: - every: null - path: 'samples' - audio: - format: 'mp3' - strategy: 'clip' - sample_rate: null - lm: - use_sampling: false - temp: 1.0 - top_k: 0 - top_p: 0.0 -evaluate: - every: null - num_workers: 5 - truncate_audio: null - fixed_generation_duration: null # in secs - metrics: - base: true # run default evaluation (e.g. like train/valid stage) - -optim: - epochs: ??? - updates_per_epoch: null - lr: ??? - optimizer: ??? - adam: - betas: [0.9, 0.999] - weight_decay: 0. - ema: - use: false # whether to use EMA or not - updates: ${optim.updates_per_epoch} # frequency of updates of the EMA - device: cpu # device for EMA, can be put on GPU if more frequent updates - decay: 0.99 # EMA decay value, if null, no EMA is used - grad_accum_steps: 1 - -schedule: - lr_scheduler: null - step: - step_size: null - gamma: null - exponential: - lr_decay: null - cosine: - warmup: null - lr_min_ratio: 0.0 - cycle_length: 1.0 - polynomial_decay: - warmup: null - zero_lr_warmup_steps: 0 - end_lr: 0.0 - power: 1 - inverse_sqrt: - warmup: null - warmup_init_lr: 0.0 - linear_warmup: - warmup: null - warmup_init_lr: 0.0 diff --git a/audiocraft/config/solver/musicgen/debug.yaml b/audiocraft/config/solver/musicgen/debug.yaml deleted file mode 100644 index ec658f9d2fb0262cc8eab19d0cf333963c646a98..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/musicgen/debug.yaml +++ /dev/null @@ -1,55 +0,0 @@ -# @package __global__ - -# This is a minimal debugging configuration -# for MusicGen training solver -defaults: - - musicgen/default - - /model: lm/musicgen_lm - - override /model/lm/model_scale: xsmall - - override /dset: audio/example - - _self_ - -autocast: false -compression_model_checkpoint: //pretrained/debug_compression_model -transformer_lm: - n_q: 4 - card: 400 - -codebooks_pattern: - modeling: parallel - -channels: 1 -sample_rate: 32000 - -deadlock: - use: false # deadlock detection - -dataset: - batch_size: 4 - segment_duration: 5 - sample_on_weight: false # Uniform sampling all the way - sample_on_duration: false # Uniform sampling all the way - -generate: - audio: - strategy: peak - lm: - use_sampling: false - top_k: 0 - top_p: 0.0 - -checkpoint: - save_every: 0 - keep_last: 0 - -optim: - epochs: 2 - updates_per_epoch: 10 - optimizer: adamw - lr: 1e-4 - -logging: - log_tensorboard: true - -schedule: - lr_scheduler: null diff --git a/audiocraft/config/solver/musicgen/default.yaml b/audiocraft/config/solver/musicgen/default.yaml deleted file mode 100644 index 16dc85d1a8b64b03eb4d4dcad1ae71e39f23455f..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/musicgen/default.yaml +++ /dev/null @@ -1,120 +0,0 @@ -# @package __global__ - -defaults: - - /solver/default - - /conditioner: none - - _self_ - - /solver/musicgen/evaluation: none - - override /dset: audio/default - -autocast: true -autocast_dtype: float16 - -solver: musicgen -sample_rate: ??? -channels: ??? -compression_model_checkpoint: ??? - -tokens: - padding_with_special_token: false - -cache: - path: - write: false - write_shard: 0 - write_num_shards: 1 - - -dataset: - batch_size: 128 - num_workers: 10 - segment_duration: 30 - min_segment_ratio: 0.8 # lower values such as 0.5 result in generations with a lot of silence. 
- return_info: true - train: - num_samples: 1000000 # need a randomly large number here for AudioDataset - valid: - num_samples: 10000 - generate: - num_samples: 5 - -metrics: - fad: - use_gt: false - model: tf - tf: - bin: null # path to local frechet_audio_distance code - model_path: //reference/fad/vggish_model.ckpt - kld: - use_gt: false - model: passt - passt: - pretrained_length: 20 - text_consistency: - use_gt: false - model: clap - clap: - model_path: //reference/clap/music_audioset_epoch_15_esc_90.14.pt - model_arch: 'HTSAT-base' - enable_fusion: false - chroma_cosine: - use_gt: false - model: chroma_base - chroma_base: - sample_rate: ${sample_rate} - n_chroma: 12 - radix2_exp: 14 - argmax: true - -generate: - every: 25 - num_workers: 4 - path: samples - audio: - format: wav - strategy: loudness - sample_rate: ${sample_rate} - loudness_headroom_db: 14 - lm: - prompted_samples: true - unprompted_samples: true - gen_gt_samples: false - prompt_duration: null # if not set, will use dataset.generate.segment_duration / 4 - gen_duration: null # if not set, will use dataset.generate.segment_duration - remove_prompts: false - # generation params - use_sampling: false - temp: 1.0 - top_k: 0 - top_p: 0.0 -evaluate: - every: 25 - num_workers: 4 - metrics: - base: false - fad: false - kld: false - text_consistency: false - chroma_cosine: false - -checkpoint: - save_last: true - save_every: 25 - keep_last: 10 - keep_every_states: null - -optim: - epochs: 200 - updates_per_epoch: 2000 - lr: 1e-4 - optimizer: adamw - max_norm: 1.0 - eager_sync: true - adam: - betas: [0.9, 0.95] - weight_decay: 0.1 - eps: 1e-8 - grad_accum_steps: 1 - -schedule: - lr_scheduler: null diff --git a/audiocraft/config/solver/musicgen/dummy_train.yaml b/audiocraft/config/solver/musicgen/dummy_train.yaml deleted file mode 100644 index 40aa99997fb49ca606e88049ddc93882bd599ea0..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/musicgen/dummy_train.yaml +++ /dev/null @@ -1,65 +0,0 @@ -# @package __global__ - -# This is the training loop solver -# for the base MusicGen model (text-to-music) -defaults: - - musicgen/default - - /model: lm/musicgen_lm - - override /dset: audio/train_backing - - _self_ - -autocast: true -autocast_dtype: float16 - -# EnCodec large trained on mono-channel music audio sampled at 32khz -# with a total stride of 640 leading to 50 frames/s. 
-# rvq.n_q=4, rvq.bins=2048, no quantization dropout -# (transformer_lm card and n_q must be compatible) -compression_model_checkpoint: //pretrained/facebook/encodec_32khz - -channels: 1 -sample_rate: 32000 - -deadlock: - use: true # deadlock detection - -dataset: - batch_size: 8 # 1 GPU(A100) - num_workers: 8 - segment_duration: 30 - sample_on_weight: false # Uniform sampling all the way - sample_on_duration: false # Uniform sampling all the way - valid: - num_samples: 4 - -generate: - lm: - use_sampling: true - top_k: 250 - top_p: 0.0 - -checkpoint: - save_last: true - save_every: 25 - keep_every_states: null - -optim: - epochs: 1 - updates_per_epoch: 1 - optimizer: dadam - lr: 1e-32 - max_norm: 1.0 - ema: - use: false - updates: 10 - device: cuda - -logging: - log_tensorboard: false - -schedule: - lr_scheduler: cosine - cosine: - warmup: 0 - lr_min_ratio: 0.0 - cycle_length: 1.0 \ No newline at end of file diff --git a/audiocraft/config/solver/musicgen/evaluation/none.yaml b/audiocraft/config/solver/musicgen/evaluation/none.yaml deleted file mode 100644 index 1e739995ed6488700527529862a7a24f1afdcc7a..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/musicgen/evaluation/none.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# @package __global__ - -dataset: - evaluate: - num_samples: 10000 diff --git a/audiocraft/config/solver/musicgen/evaluation/objective_eval.yaml b/audiocraft/config/solver/musicgen/evaluation/objective_eval.yaml deleted file mode 100644 index 4881e9d86cddf36b306a75fb498253e1e12ec5be..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/musicgen/evaluation/objective_eval.yaml +++ /dev/null @@ -1,24 +0,0 @@ -# @package __global__ - -# Setup for execute only on musiccaps for audio generation -# evaluation with objective metrics -# execute_only=evaluate - -dataset: - max_audio_duration: null - # ensure the proper values are broadcasted here for evaluate - evaluate: - min_audio_duration: 1. # some metrics requires a minimum audio length - max_audio_duration: null # all samples from musiccaps should be < 20s - num_samples: null - segment_duration: null - generate: - min_audio_duration: 1. - max_audio_duration: null - num_samples: 500 - -evaluate: - metrics: - fad: true - kld: true - text_consistency: true diff --git a/audiocraft/config/solver/musicgen/multigpu_finetune.yaml b/audiocraft/config/solver/musicgen/multigpu_finetune.yaml deleted file mode 100644 index fa1ee8a373cffb9290879275d9a7d29beb6a7cd1..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/musicgen/multigpu_finetune.yaml +++ /dev/null @@ -1,63 +0,0 @@ -# @package __global__ - -# This is the training loop solver -# for the base MusicGen model (text-to-music) -defaults: - - musicgen/default - - /model: lm/musicgen_lm - - _self_ - -autocast: true -autocast_dtype: float16 - -# EnCodec large trained on mono-channel music audio sampled at 32khz -# with a total stride of 640 leading to 50 frames/s. 
-# rvq.n_q=4, rvq.bins=2048, no quantization dropout -# (transformer_lm card and n_q must be compatible) -compression_model_checkpoint: //pretrained/facebook/encodec_32khz - -channels: 1 -sample_rate: 32000 - -deadlock: - use: true # deadlock detection - -dataset: - batch_size: 8 # 4 GPUs(3090) - num_workers: 8 - segment_duration: 30 - sample_on_weight: false # Uniform sampling all the way - sample_on_duration: false # Uniform sampling all the way - valid: - num_samples: 4 - -generate: - lm: - use_sampling: true - top_k: 250 - top_p: 0.0 - -checkpoint: - save_last: true - save_every: 25 - keep_every_states: null - -optim: - epochs: 100 - optimizer: dadam - lr: 1.0 - max_norm: 1.0 - ema: - use: false - updates: 10 - device: cuda - -logging: - log_tensorboard: true - -schedule: - lr_scheduler: cosine - cosine: - warmup: 5 - lr_min_ratio: 0.0 - cycle_length: 1.0 \ No newline at end of file diff --git a/audiocraft/config/solver/musicgen/musicgen_base_32khz.yaml b/audiocraft/config/solver/musicgen/musicgen_base_32khz.yaml deleted file mode 100644 index b32c9c898a70718f91af862caa79f5553a5107e1..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/musicgen/musicgen_base_32khz.yaml +++ /dev/null @@ -1,55 +0,0 @@ -# @package __global__ - -# This is the training loop solver -# for the base MusicGen model (text-to-music) -# on monophonic audio sampled at 32 kHz -defaults: - - musicgen/default - - /model: lm/musicgen_lm - - override /dset: audio/default - - _self_ - -autocast: true -autocast_dtype: float16 - -# EnCodec large trained on mono-channel music audio sampled at 32khz -# with a total stride of 640 leading to 50 frames/s. -# rvq.n_q=4, rvq.bins=2048, no quantization dropout -# (transformer_lm card and n_q must be compatible) -compression_model_checkpoint: //pretrained/facebook/encodec_32khz - -channels: 1 -sample_rate: 32000 - -deadlock: - use: true # deadlock detection - -dataset: - batch_size: 192 # 32 GPUs - sample_on_weight: false # Uniform sampling all the way - sample_on_duration: false # Uniform sampling all the way - -generate: - lm: - use_sampling: true - top_k: 250 - top_p: 0.0 - -optim: - epochs: 500 - optimizer: dadam - lr: 1 - ema: - use: true - updates: 10 - device: cuda - -logging: - log_tensorboard: true - -schedule: - lr_scheduler: cosine - cosine: - warmup: 4000 - lr_min_ratio: 0.0 - cycle_length: 1.0 diff --git a/audiocraft/config/solver/musicgen/single_finetune.yaml b/audiocraft/config/solver/musicgen/single_finetune.yaml deleted file mode 100644 index 902dee3ebddb34d7ee5b6cc9b60caff4b3b9b0c6..0000000000000000000000000000000000000000 --- a/audiocraft/config/solver/musicgen/single_finetune.yaml +++ /dev/null @@ -1,63 +0,0 @@ -# @package __global__ - -# This is the training loop solver -# for the base MusicGen model (text-to-music) -defaults: - - musicgen/default - - /model: lm/musicgen_lm - - _self_ - -autocast: true -autocast_dtype: float16 - -# EnCodec large trained on mono-channel music audio sampled at 32khz -# with a total stride of 640 leading to 50 frames/s. 
-# rvq.n_q=4, rvq.bins=2048, no quantization dropout -# (transformer_lm card and n_q must be compatible) -compression_model_checkpoint: //pretrained/facebook/encodec_32khz - -channels: 1 -sample_rate: 32000 - -deadlock: - use: true # deadlock detection - -dataset: - batch_size: 2 # 1 GPU(3090) - num_workers: 2 - segment_duration: 30 - sample_on_weight: false # Uniform sampling all the way - sample_on_duration: false # Uniform sampling all the way - valid: - num_samples: 4 - -generate: - lm: - use_sampling: true - top_k: 250 - top_p: 0.0 - -checkpoint: - save_last: true - save_every: 25 - keep_every_states: null - -optim: - epochs: 100 - optimizer: dadam - lr: 1.0 - max_norm: 1.0 - ema: - use: false - updates: 10 - device: cuda - -logging: - log_tensorboard: true - -schedule: - lr_scheduler: cosine - cosine: - warmup: 5 - lr_min_ratio: 0.0 - cycle_length: 1.0 \ No newline at end of file diff --git a/audiocraft/config/teams/default.yaml b/audiocraft/config/teams/default.yaml deleted file mode 100644 index 3e684c27a0bf23876323e64d766eb74913f685b8..0000000000000000000000000000000000000000 --- a/audiocraft/config/teams/default.yaml +++ /dev/null @@ -1,12 +0,0 @@ -default: - dora_dir: ./training_weights - partitions: - global: debug - team: debug - reference_dir: ./ -darwin: # if we detect we are on a Mac, then most likely we are doing unit testing etc. - dora_dir: ./training_weights - partitions: - global: debug - team: debug - reference_dir: ./ diff --git a/audiocraft/config/teams/labs.yaml b/audiocraft/config/teams/labs.yaml deleted file mode 100644 index da350a94bc5758531ced5d9e4332624fe86f3d57..0000000000000000000000000000000000000000 --- a/audiocraft/config/teams/labs.yaml +++ /dev/null @@ -1,28 +0,0 @@ -aws: - dora_dir: /fsx-audio-craft-llm/${oc.env:USER}/experiments/audiocraft/outputs - partitions: - global: learnlab - team: learnlab - reference_dir: /fsx-audio-craft-llm/shared/audiocraft/reference - dataset_mappers: - "^/checkpoint/[a-z]+": "/fsx-audio-craft-llm" -fair: - dora_dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs - partitions: - global: learnlab - team: learnlab - reference_dir: /large_experiments/audiocraft/reference - dataset_mappers: - "^/datasets01/datasets01": "/datasets01" -darwin: - dora_dir: /tmp/audiocraft_${oc.env:USER} - partitions: - global: debug - team: debug - reference_dir: /tmp -rsc: - dora_dir: /checkpoint/audiocraft/${oc.env:USER}/experiments/audiocraft/outputs - partitions: - global: learn - team: learn - reference_dir: /checkpoint/audiocraft/shared/reference diff --git a/audiocraft/dataset/example/clip/sample_1/beats.npy b/audiocraft/dataset/example/clip/sample_1/beats.npy deleted file mode 100644 index 0194428ecdf0fed5be17e112e6e4c4f9ac7a7cd7..0000000000000000000000000000000000000000 --- a/audiocraft/dataset/example/clip/sample_1/beats.npy +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:774a64a3e1bc8f704bebb961ab9ef43cdf20e07a6470149230a80691f0d6b1eb -size 784 diff --git a/audiocraft/dataset/example/clip/sample_1/chord.lab b/audiocraft/dataset/example/clip/sample_1/chord.lab deleted file mode 100644 index 390e4f55a1a9ff3b582901b0c9fefed27155d06f..0000000000000000000000000000000000000000 --- a/audiocraft/dataset/example/clip/sample_1/chord.lab +++ /dev/null @@ -1,22 +0,0 @@ -0.000 1.389 G -1.389 2.963 E:min7 -2.963 4.352 C -4.352 5.833 D -5.833 7.315 G -7.315 8.796 E:min7 -8.796 10.185 C -10.185 11.574 D -11.574 13.056 G -13.056 14.630 E:min7 -14.630 16.111 C -16.111 17.315 D -17.315 18.981 G -18.981 
20.463 E:min7 -20.463 21.852 C -21.852 22.870 D -22.870 24.815 G -24.815 26.204 E:min7 -26.204 26.296 E:min -26.296 27.778 C -27.778 29.167 D -29.167 30.000 G diff --git a/audiocraft/dataset/example/clip/sample_1/no_vocal.wav b/audiocraft/dataset/example/clip/sample_1/no_vocal.wav deleted file mode 100644 index 9e738015b2202fbf01283b003509a4fcf51c30d5..0000000000000000000000000000000000000000 --- a/audiocraft/dataset/example/clip/sample_1/no_vocal.wav +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:64d92035567b0a88cedcdfaf828b7c4f38b11d9900e8acbe4f8f236f2edfc27f -size 5292044 diff --git a/audiocraft/dataset/example/clip/sample_1/tags.json b/audiocraft/dataset/example/clip/sample_1/tags.json deleted file mode 100644 index c55ad53a4afa32fbccb396f762c782b195dd2252..0000000000000000000000000000000000000000 --- a/audiocraft/dataset/example/clip/sample_1/tags.json +++ /dev/null @@ -1 +0,0 @@ -{"key": "", "artist": "", "sample_rate": 44100, "file_extension": "wav", "description": "chill song with guitar and drum", "keywords": "", "duration": 30.0, "bpm": "", "genre": "", "title": "", "name": "", "instrument": "Mix", "moods": [], "path": "dataset/example/sample_1/no_vocal.wav"} \ No newline at end of file diff --git a/audiocraft/dataset/example/clip/sample_2/beats.npy b/audiocraft/dataset/example/clip/sample_2/beats.npy deleted file mode 100644 index 8d21b2c8af07deb00ffe4c282a3ffb96fd38b10f..0000000000000000000000000000000000000000 --- a/audiocraft/dataset/example/clip/sample_2/beats.npy +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:26673d903483bfef082b9b84035b5db0b96b5b60887cae04841bec177ece54f5 -size 1136 diff --git a/audiocraft/dataset/example/clip/sample_2/chord.lab b/audiocraft/dataset/example/clip/sample_2/chord.lab deleted file mode 100644 index ab82b72148475c8cd51e459126d9706626598ffb..0000000000000000000000000000000000000000 --- a/audiocraft/dataset/example/clip/sample_2/chord.lab +++ /dev/null @@ -1,49 +0,0 @@ -0.000 0.648 E:min -0.648 0.741 E -0.741 1.204 F#:min -1.204 1.296 D -1.296 1.389 E:min -1.389 1.759 G -1.759 2.685 E:min -2.685 3.611 D -3.611 4.722 E:min -4.722 4.907 B:min -4.907 5.185 E:min -5.185 5.556 G -5.556 7.130 E:min -7.130 7.407 G -7.407 8.426 E:min -8.426 8.796 F#:min7 -8.796 8.981 E:min -8.981 9.352 G -9.352 10.185 E:min -10.185 10.833 F#:min7 -10.833 11.111 E:min -11.111 11.296 G -11.296 12.130 E:min -12.130 12.778 F#:min7 -12.778 13.056 E:min -13.056 13.148 G -13.148 14.167 E:min -14.167 14.537 F#:min7 -14.537 16.204 E:min -16.204 16.389 F#:min7 -16.389 19.074 E:min -19.074 19.259 A -19.259 20.000 A:min -20.000 20.370 N -20.370 21.111 G -21.111 21.852 E:min -21.852 22.315 F#:min7 -22.315 22.407 D -22.407 22.963 G -22.963 24.907 D -24.907 25.741 E:min -25.741 26.204 F#:min7 -26.204 26.296 E:min -26.296 26.759 G -26.759 27.593 E:min -27.593 28.148 F#:min7 -28.148 28.611 G -28.611 29.537 E:min -29.537 30.000 F#:min7 diff --git a/audiocraft/dataset/example/clip/sample_2/no_vocal.wav b/audiocraft/dataset/example/clip/sample_2/no_vocal.wav deleted file mode 100644 index 1352673b88c7544ecf413edc3d9bc659747e821c..0000000000000000000000000000000000000000 --- a/audiocraft/dataset/example/clip/sample_2/no_vocal.wav +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:180a41036fe7245cb34eb6a8de5cf630b93367d4c18d55d1b98b8e76fd2d81a9 -size 5292044 diff --git a/audiocraft/dataset/example/clip/sample_2/tags.json b/audiocraft/dataset/example/clip/sample_2/tags.json deleted 
file mode 100644 index ca1d22127ae18971bc15c4a555b4e5ed7fa204aa..0000000000000000000000000000000000000000 --- a/audiocraft/dataset/example/clip/sample_2/tags.json +++ /dev/null @@ -1 +0,0 @@ -{"key": "", "artist": "", "sample_rate": 44100, "file_extension": "wav", "description": "cool song from BKS", "keywords": "", "duration": 30.0, "bpm": "", "genre": "", "title": "", "name": "", "instrument": "Mix", "moods": [], "path": "dataset/example/sample_2/no_vocal.wav"} \ No newline at end of file diff --git a/audiocraft/egs/.DS_Store b/audiocraft/egs/.DS_Store deleted file mode 100644 index 57a55533b24c0913b16270b1e0331e8066b90fde..0000000000000000000000000000000000000000 Binary files a/audiocraft/egs/.DS_Store and /dev/null differ diff --git a/audiocraft/egs/example/data.jsonl b/audiocraft/egs/example/data.jsonl deleted file mode 100644 index b00f36b76ff0e9d8281513d85a278489b14cb08e..0000000000000000000000000000000000000000 --- a/audiocraft/egs/example/data.jsonl +++ /dev/null @@ -1,2 +0,0 @@ -{"path": "dataset/example/clip/sample_1/no_vocal.wav", "duration": 30.0, "sample_rate": 44100, "bpm": "", "amplitude": null, "weight": null, "info_path": null} -{"path": "dataset/example/clip/sample_2/no_vocal.wav", "duration": 30.0, "sample_rate": 44100, "bpm": "", "amplitude": null, "weight": null, "info_path": null} diff --git a/audiocraft/export_weight.py b/audiocraft/export_weight.py deleted file mode 100644 index 7f89e113e90946758e8c4f5975e64c6ad400e5a9..0000000000000000000000000000000000000000 --- a/audiocraft/export_weight.py +++ /dev/null @@ -1,12 +0,0 @@ -from audiocraft.utils import export -from audiocraft import train -import os -from pathlib import Path - -sig = "your_training_signature" -output_dir = "./ckpt/output_weight_dir" - - -folder = f"./audiocraft_default/xps/{sig}" -export.export_lm(Path(folder) / 'checkpoint.th', os.path.join(output_dir, 'state_dict.bin')) -export.export_pretrained_compression_model('facebook/encodec_32khz', os.path.join(output_dir, 'compression_state_dict.bin')) \ No newline at end of file diff --git a/audiocraft/generate_chord_beat.py b/audiocraft/generate_chord_beat.py deleted file mode 100644 index e34c879a1589ff394196e96f8e96bb049979add3..0000000000000000000000000000000000000000 --- a/audiocraft/generate_chord_beat.py +++ /dev/null @@ -1,49 +0,0 @@ -from audiocraft.data.audio import audio_write -import audiocraft.models -import numpy as np -import pandas as pd -import os -import torch - -# set hparams -output_dir = 'example_1' ### change this output directory - - -duration = 30 -num_samples = 5 -bs = 1 - - -# load your model -musicgen = audiocraft.models.MusicGen.get_pretrained('./ckpt/musicongen') ### change this path -musicgen.set_generation_params(duration=duration, extend_stride=duration//2, top_k = 250) - - -chords = ['C G A:min F', - 'A:min F C G', - 'C F G F', - 'C A:min F G', - 'D:min G C A:min', - ] - -descriptions = ["A laid-back blues shuffle with a relaxed tempo, warm guitar tones, and a comfortable groove, perfect for a slow dance or a night in. 
Instruments: electric guitar, bass, drums."] * num_samples - -bpms = [120] * num_samples - -meters = [4] * num_samples - -wav = [] -for i in range(num_samples//bs): - print(f"starting {i} batch...") - temp = musicgen.generate_with_chords_and_beats(descriptions[i*bs:(i+1)*bs], - chords[i*bs:(i+1)*bs], - bpms[i*bs:(i+1)*bs], - meters[i*bs:(i+1)*bs] - ) - wav.extend(temp.cpu()) - -# save and display generated audio -for idx, one_wav in enumerate(wav): - - sav_path = os.path.join('./output_samples', output_dir, chords[idx] + "|" + descriptions[idx]).replace(" ", "_") - audio_write(sav_path, one_wav.cpu(), musicgen.sample_rate, strategy='loudness', loudness_compressor=True) \ No newline at end of file
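One practical note on the final loop of generate_chord_beat.py above: the save path embeds the full chord progression and description, so file names become very long and contain '|' characters. A drop-in variant of that loop, offered as an editor's suggestion rather than part of the original repository, reuses the script's existing variables and imports and instead writes indexed audio files with a small sidecar text file holding the prompt:

out_dir = os.path.join('./output_samples', output_dir)
os.makedirs(out_dir, exist_ok=True)
for idx, one_wav in enumerate(wav):
    sav_path = os.path.join(out_dir, f'sample_{idx:03d}')       # short, filesystem-friendly stem
    audio_write(sav_path, one_wav.cpu(), musicgen.sample_rate,
                strategy='loudness', loudness_compressor=True)  # same call as the original script
    with open(sav_path + '.txt', 'w') as f:                     # keep the prompt next to the audio
        f.write(f'chords: {chords[idx]}\n'
                f'description: {descriptions[idx]}\n'
                f'bpm: {bpms[idx]}  meter: {meters[idx]}\n')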