# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import inspect
import random
import typing as tp
from functools import partial

import julius
import omegaconf
import torch
from julius import fft_conv1d, resample_frac

from ..data.audio_utils import get_aac, get_mp3

if tp.TYPE_CHECKING:
    from ..models.encodec import CompressionModel

def select_audio_effects(
    audio_effects: tp.Dict,
    weights: tp.Optional[tp.Dict] = None,
    mode: str = "all",
    max_length: tp.Optional[int] = None,
):
    """Samples a subset of audio effects methods from the `AudioEffects` class.

    This function allows you to select a subset of audio effects
    based on the chosen selection mode and optional weights.

    Args:
        audio_effects (dict): A dictionary of available audio augmentations, usually
            obtained from the output of the 'get_audio_effects' function.
        weights (dict): A dictionary mapping augmentation names to their corresponding
            probabilities of being selected. This argument is used when 'mode' is set
            to "weighted". If 'weights' is None, all augmentations have equal
            probability of being selected.
        mode (str): The selection mode, which can be one of the following:
            - "all": Select all available augmentations.
            - "weighted": Select augmentations based on their probabilities in the
              'weights' dictionary.
        max_length (int): The maximum number of augmentations to select. If 'max_length'
            is None, no limit is applied.

    Returns:
        dict: A subset of the 'audio_effects' dictionary containing the selected audio
        augmentations.

    Note:
        - In "all" mode, all available augmentations are selected.
        - In "weighted" mode, augmentations are selected with a probability
          proportional to their weights specified in the 'weights' dictionary.
        - If 'max_length' is set, the function limits the number of selected
          augmentations.
        - If no augmentations are selected or 'audio_effects' is empty, the function
          defaults to including an "identity" augmentation.
        - The "identity" augmentation means that no audio effect is applied.
    """
if mode == "all": # original code | |
out = audio_effects | |
elif mode == "weighted": | |
# Probability proportionnal to weights | |
assert weights is not None | |
out = { | |
name: value | |
for name, value in audio_effects.items() | |
if random.random() < weights.get(name, 1.0) | |
} | |
else: | |
raise ValueError(f"Unknown mode {mode}") | |
if max_length is not None: | |
# Help having a deterministic limit of the gpu memory usage | |
random_keys = random.sample(list(out.keys()), max_length) | |
out = {key: out[key] for key in random_keys} | |
if len(out) == 0: # Check not to return empty dict | |
out = {"identity": AudioEffects.identity} | |
return out | |
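
# Example (illustrative sketch, not part of this module's API): selecting a
# weighted subset of effects. The effect names and weights below are assumed
# for illustration.
#
#   effects = get_audio_effects(cfg)  # all effects, parameterized from cfg
#   subset = select_audio_effects(
#       effects,
#       weights={"mp3_compression": 0.8, "random_noise": 0.5},
#       mode="weighted",
#       max_length=3,
#   )
#   # `subset` maps names to callables, e.g. subset["random_noise"](wav)
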
def get_audio_effects(cfg: omegaconf.DictConfig):
    """Automatically pull the list of all effects available in this class,
    parameterized from the cfg.

    Returns:
        dict: A dict mapping names to pointers to all methods in this class.
    """
    assert hasattr(cfg, "audio_effects")
    cfg_audio_effects = dict(cfg["audio_effects"])
    return {
        name: partial(value, **cfg_audio_effects.get(name, {}))
        for name, value in inspect.getmembers(AudioEffects)
        if inspect.isfunction(value)
    }
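
# Example (hedged sketch): the expected layout of `cfg.audio_effects`, where
# each key is an `AudioEffects` method name mapped to keyword overrides for it.
# The concrete values below are assumptions for illustration only.
#
#   cfg = omegaconf.OmegaConf.create(
#       {
#           "audio_effects": {
#               "speed": {"speed_range": (0.9, 1.1)},
#               "lowpass_filter": {"cutoff_freq": 3000},
#           }
#       }
#   )
#   effects = get_audio_effects(cfg)  # unlisted effects keep their defaults
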
def audio_effect_return(
    tensor: torch.Tensor, mask: tp.Optional[torch.Tensor]
) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Return the mask if it was in the input, otherwise only the output tensor."""
    if mask is None:
        return tensor
    else:
        return tensor, mask

def generate_pink_noise(length: int) -> torch.Tensor:
    """Generate pink noise using the Voss-McCartney algorithm with PyTorch."""
    num_rows = 16
    array = torch.randn(num_rows, length // num_rows + 1)
    reshaped_array = torch.cumsum(array, dim=1)
    reshaped_array = reshaped_array.reshape(-1)
    reshaped_array = reshaped_array[:length]
    # Normalize to unit peak amplitude
    pink_noise = reshaped_array / torch.max(torch.abs(reshaped_array))
    return pink_noise
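
# Example (illustrative): the noise is returned at unit peak amplitude, so
# callers scale it themselves, as `AudioEffects.pink_noise` does with `noise_std`.
#
#   noise = generate_pink_noise(16000)  # 1 second at 16 kHz, shape (16000,)
#   scaled = noise * 0.01               # scale before mixing into a waveform
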
def compress_with_encodec(
    tensor: torch.Tensor,
    n_q: int,
    model: "CompressionModel",
    sample_rate: int,
    mask: tp.Optional[torch.Tensor] = None,
) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Special augmentation function that compresses and decompresses a wav tensor
    using a compression model with n_q codebooks.
    """
    model.to(tensor.device)
    model.set_num_codebooks(n_q)
    codes, scale = model.encode(
        julius.resample_frac(tensor, old_sr=sample_rate, new_sr=model.sample_rate)
    )
    compressed = model.decode(codes=codes, scale=scale)
    return audio_effect_return(
        tensor=julius.resample_frac(
            compressed, old_sr=model.sample_rate, new_sr=sample_rate
        ),
        mask=mask,
    )
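
# Example (hedged sketch): running the compression round-trip as an
# augmentation. Assumes a `CompressionModel` instance `compression_model` was
# loaded elsewhere; this is not a complete recipe.
#
#   wav = torch.randn(4, 1, 16000)  # [bsz, channels, frames]
#   out = compress_with_encodec(
#       wav, n_q=4, model=compression_model, sample_rate=16000
#   )
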
def apply_compression_skip_grad(tensor: torch.Tensor, compression_fn, **kwargs):
    """Applies a specified compression function to the audio tensor,
    while carrying the gradients over to the output tensor with a
    straight-through estimator. This makes mp3/aac compression differentiable;
    see Yin et al. 2019, https://arxiv.org/pdf/1903.05662.pdf

    Args:
        tensor (torch.Tensor): The input audio tensor.
        compression_fn (function): The compression function to apply.
        **kwargs: Additional keyword arguments for the compression function.

    Returns:
        torch.Tensor: The output tensor after applying compression and the
        straight-through estimator.
    """
    compressed = compression_fn(tensor.detach(), **kwargs)
    # Trim compressed output if needed
    compressed = compressed[:, :, : tensor.size(-1)]
    # Straight-through estimator for differentiable compression
    out = tensor + (compressed - tensor).detach()
    # Check that gradients are not broken
    if out.requires_grad:
        assert (
            out.grad_fn
        ), "The computation graph might be broken due to compression augmentation."
    return out
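
# Example (illustrative): the straight-through estimator in action. The
# "compression" function below is a stand-in for get_mp3/get_aac, not a real codec.
#
#   wav = torch.randn(1, 1, 8000, requires_grad=True)
#   out = apply_compression_skip_grad(wav, lambda t: t.clamp(-0.5, 0.5))
#   out.sum().backward()  # gradients flow as if compression were the identity
#   assert wav.grad is not None
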

class AudioEffects:
    def speed(
        tensor: torch.Tensor,
        speed_range: tuple = (0.5, 1.5),
        sample_rate: int = 16000,
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Function to change the speed of a batch of audio data.
        The output will have a different length!

        Args:
            tensor (torch.Tensor): The batch of audio data in torch tensor format.
            speed_range (tuple): Range from which the speed factor is sampled uniformly.
            sample_rate (int): Sample rate of the audio signal.
            mask: Masks for the input wave.

        Returns:
            torch.Tensor: The batch of audio data with the speed changed.
        """
        speed = torch.FloatTensor(1).uniform_(*speed_range)
        new_sr = int(sample_rate * 1 / speed)
        resampled_tensor = julius.resample.resample_frac(tensor, sample_rate, new_sr)
        if mask is None:
            return resampled_tensor
        else:
            return resampled_tensor, torch.nn.functional.interpolate(
                mask, size=resampled_tensor.size(-1), mode="nearest-exact"
            )
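
    # Example (illustrative): `speed` changes the output length, so any mask is
    # resized alongside it.
    #
    #   wav = torch.randn(2, 1, 16000)
    #   wav_mask = torch.ones(2, 1, 16000)
    #   out, out_mask = AudioEffects.speed(wav, mask=wav_mask)
    #   assert out.shape[-1] == out_mask.shape[-1]  # lengths stay in sync
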
    def updownresample(
        tensor: torch.Tensor,
        sample_rate: int = 16000,
        intermediate_freq: int = 32000,
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Resample the audio up to an intermediate frequency and back down to
        the original sample rate, simulating resampling artifacts while
        preserving the tensor shape.
        """
        orig_shape = tensor.shape
        # upsample
        tensor = resample_frac(tensor, sample_rate, intermediate_freq)
        # downsample
        tensor = resample_frac(tensor, intermediate_freq, sample_rate)
        assert tensor.shape == orig_shape
        return audio_effect_return(tensor=tensor, mask=mask)

    def echo(
        tensor: torch.Tensor,
        volume_range: tuple = (0.1, 0.5),
        duration_range: tuple = (0.1, 0.5),
        sample_rate: int = 16000,
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Attenuate the audio volume by a randomly sampled factor, delay it by
        a randomly sampled duration, and then overlay it with the original.

        Args:
            tensor: 3D Tensor representing the audio signal [bsz, channels, frames]
            volume_range: Volume range of the echo signal.
            duration_range: Duration range of the echo signal, in seconds.
            sample_rate: Sample rate of the audio signal.
            mask: Masks for the input wave.

        Returns:
            Audio signal with the echo applied.
        """
        # Create a simple impulse response
        # Duration of the impulse response in seconds
        duration = torch.FloatTensor(1).uniform_(*duration_range)
        volume = torch.FloatTensor(1).uniform_(*volume_range)
        n_samples = int(sample_rate * duration)
        impulse_response = torch.zeros(n_samples).type(tensor.type()).to(tensor.device)
        # Define a few reflections with decreasing amplitude
        impulse_response[0] = 1.0  # Direct sound
        impulse_response[n_samples - 1] = volume  # First reflection after the sampled delay
        # Add batch and channel dimensions to the impulse response
        impulse_response = impulse_response.unsqueeze(0).unsqueeze(0)
        # Convolve the audio signal with the impulse response
        reverbed_signal = fft_conv1d(tensor, impulse_response)
        # Normalize to the original amplitude range for stability
        reverbed_signal = (
            reverbed_signal
            / torch.max(torch.abs(reverbed_signal))
            * torch.max(torch.abs(tensor))
        )
        # Ensure the tensor size is not changed
        tmp = torch.zeros_like(tensor)
        tmp[..., : reverbed_signal.shape[-1]] = reverbed_signal
        reverbed_signal = tmp
        return audio_effect_return(tensor=reverbed_signal, mask=mask)

    def random_noise(
        waveform: torch.Tensor,
        noise_std: float = 0.001,
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Add Gaussian noise to the waveform."""
        noise = torch.randn_like(waveform) * noise_std
        noisy_waveform = waveform + noise
        return audio_effect_return(tensor=noisy_waveform, mask=mask)

    def pink_noise(
        waveform: torch.Tensor,
        noise_std: float = 0.01,
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Add pink background noise to the waveform."""
        noise = generate_pink_noise(waveform.shape[-1]) * noise_std
        noise = noise.to(waveform.device)
        # Assuming waveform is of shape (bsz, channels, length)
        noisy_waveform = waveform + noise.unsqueeze(0).unsqueeze(0)
        return audio_effect_return(tensor=noisy_waveform, mask=mask)

    def lowpass_filter(
        waveform: torch.Tensor,
        cutoff_freq: float = 5000,
        sample_rate: int = 16000,
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Apply a low-pass filter to the waveform."""
        return audio_effect_return(
            tensor=julius.lowpass_filter(waveform, cutoff=cutoff_freq / sample_rate),
            mask=mask,
        )

    def highpass_filter(
        waveform: torch.Tensor,
        cutoff_freq: float = 500,
        sample_rate: int = 16000,
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Apply a high-pass filter to the waveform."""
        return audio_effect_return(
            tensor=julius.highpass_filter(waveform, cutoff=cutoff_freq / sample_rate),
            mask=mask,
        )

    def bandpass_filter(
        waveform: torch.Tensor,
        cutoff_freq_low: float = 300,
        cutoff_freq_high: float = 8000,
        sample_rate: int = 16000,
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Apply a band-pass filter to the waveform, keeping frequencies
        between the low and high cutoffs.

        Args:
            waveform (torch.Tensor): Input audio waveform.
            cutoff_freq_low (float): Lower cutoff frequency.
            cutoff_freq_high (float): Higher cutoff frequency.
            sample_rate (int): The sample rate of the waveform.
            mask: Masks for the input wave.

        Returns:
            torch.Tensor: Filtered audio waveform.
        """
        return audio_effect_return(
            tensor=julius.bandpass_filter(
                waveform,
                cutoff_low=cutoff_freq_low / sample_rate,
                cutoff_high=cutoff_freq_high / sample_rate,
            ),
            mask=mask,
        )
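
    # Example (illustrative): julius expects cutoffs normalized by the sample
    # rate, which is why the filters above divide by `sample_rate`.
    #
    #   wav = torch.randn(1, 1, 16000)
    #   out = AudioEffects.bandpass_filter(
    #       wav, cutoff_freq_low=300, cutoff_freq_high=4000, sample_rate=16000
    #   )  # keeps roughly 300 Hz - 4 kHz
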
    def smooth(
        tensor: torch.Tensor,
        window_size_range: tuple = (2, 10),
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Smooths the input tensor (audio signal) using a moving average filter
        with a randomly sampled window size.

        Args:
            tensor (torch.Tensor): Input audio tensor. Assumes tensor shape is
                (batch_size, channels, time).
            window_size_range (tuple): Range from which the window size is sampled.
            mask: Masks for the input wave.

        Returns:
            torch.Tensor: Smoothed audio tensor.
        """
        window_size = int(torch.FloatTensor(1).uniform_(*window_size_range))
        # Create a uniform smoothing kernel
        kernel = torch.ones(1, 1, window_size).type(tensor.type()) / window_size
        kernel = kernel.to(tensor.device)
        smoothed = fft_conv1d(tensor, kernel)
        # Ensure the tensor size is not changed
        tmp = torch.zeros_like(tensor)
        tmp[..., : smoothed.shape[-1]] = smoothed
        smoothed = tmp
        return audio_effect_return(tensor=smoothed, mask=mask)

    def boost_audio(
        tensor: torch.Tensor,
        amount: float = 20,
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Boost the audio volume by `amount` percent."""
        return audio_effect_return(tensor=tensor * (1 + amount / 100), mask=mask)

    def duck_audio(
        tensor: torch.Tensor,
        amount: float = 20,
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Duck (attenuate) the audio volume by `amount` percent."""
        return audio_effect_return(tensor=tensor * (1 - amount / 100), mask=mask)

    def identity(
        tensor: torch.Tensor, mask: tp.Optional[torch.Tensor] = None
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Return the input unchanged; no effect is applied."""
        return audio_effect_return(tensor=tensor, mask=mask)

    def mp3_compression(
        tensor: torch.Tensor,
        sample_rate: int = 16000,
        bitrate: str = "128k",
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Compress audio using the MP3 algorithm.

        Args:
            tensor (torch.Tensor): The input audio tensor.
            sample_rate (int): The sample rate of the audio.
            bitrate (str): The bitrate for MP3 compression.
            mask: Masks for the input wave.

        Returns:
            torch.Tensor: The output tensor after applying MP3 compression.
        """
        out = apply_compression_skip_grad(
            tensor, get_mp3, sr=sample_rate, bitrate=bitrate
        )
        return audio_effect_return(tensor=out, mask=mask)

    def aac_compression(
        tensor: torch.Tensor,
        sample_rate: int = 16000,
        bitrate: str = "128k",
        lowpass_freq: tp.Optional[int] = None,
        mask: tp.Optional[torch.Tensor] = None,
    ) -> tp.Union[tp.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Applies AAC compression to an audio tensor.

        Args:
            tensor (torch.Tensor): The input audio tensor.
            sample_rate (int): The sample rate of the audio.
            bitrate (str): The bitrate for AAC compression.
            lowpass_freq (Optional[int]): The frequency for a low-pass filter.
            mask: Masks for the input wave.

        Returns:
            torch.Tensor: The output tensor after applying AAC compression.
        """
        out = apply_compression_skip_grad(
            tensor, get_aac, sr=sample_rate, bitrate=bitrate, lowpass_freq=lowpass_freq
        )
        return audio_effect_return(tensor=out, mask=mask)
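
# Example (hedged end-to-end sketch): applying every selected effect in turn to
# a batch. `cfg` is assumed to follow the layout sketched above
# `get_audio_effects`.
#
#   effects = select_audio_effects(get_audio_effects(cfg), mode="all")
#   wav = torch.randn(8, 1, 16000)
#   for name, fn in effects.items():
#       wav = fn(wav)  # without a mask, each effect returns only the tensor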