# Copyright (c) 2023 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # This source file is copied from https://github.com/facebookresearch/encodec # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Normalization modules.""" import typing as tp import einops import torch from torch import nn class ConvLayerNorm(nn.LayerNorm): """ Convolution-friendly LayerNorm that moves channels to last dimensions before running the normalization and moves them back to original position right after. """ def __init__( self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs ): super().__init__(normalized_shape, **kwargs) def forward(self, x): x = einops.rearrange(x, "b ... t -> b t ...") x = super().forward(x) x = einops.rearrange(x, "b t ... -> b ... t") return