Spaces:
Running
Running
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is modified from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/layers.py
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange | |
from torch.nn.utils import weight_norm | |
def WNConv1d(*args, **kwargs):
    """Create a 1-D convolution with weight normalization applied.

    Every positional and keyword argument is passed through untouched to
    ``nn.Conv1d``; the resulting module is then wrapped by ``weight_norm``.
    """
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
def WNConvTranspose1d(*args, **kwargs):
    """Create a 1-D transposed convolution with weight normalization applied.

    All arguments are forwarded as-is to ``nn.ConvTranspose1d``, and the
    constructed layer is returned wrapped by ``weight_norm``.
    """
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
# NOTE: scripting this function (``@torch.jit.script``) reportedly speeds the
# model up ~1.4x, but no such decorator is applied here -- it runs in eager
# mode. Wrap it with ``torch.jit.script`` if that speed-up is wanted.
def snake(x, alpha):
    """Apply the Snake activation ``x + sin(alpha * x) ** 2 / alpha``.

    Args:
        x: Tensor with at least two dimensions. All dimensions after the
            second are flattened for the computation and restored afterwards.
        alpha: Tensor that must broadcast against the
            ``(shape[0], shape[1], -1)`` view of ``x`` -- presumably a
            per-channel ``(1, C, 1)`` tensor; confirm at call sites.

    Returns:
        Tensor with the same shape as ``x``.
    """
    shape = x.shape
    # Collapse trailing dims so a per-channel alpha of shape (1, C, 1) can
    # broadcast regardless of how many spatial/time dims ``x`` has.
    x = x.reshape(shape[0], shape[1], -1)
    # The 1e-9 offset keeps the reciprocal finite when alpha == 0.
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    return x.reshape(shape)
class Snake1d(nn.Module):
    """Snake activation module with a learnable ``alpha`` per channel.

    ``alpha`` is a parameter of shape ``(1, channels, 1)`` initialized to
    ones; ``forward`` delegates the actual computation to ``snake``.
    """

    def __init__(self, channels):
        super().__init__()
        initial_alpha = torch.ones(1, channels, 1)
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, x):
        """Return ``snake(x, self.alpha)``."""
        alpha = self.alpha
        return snake(x, alpha)