# VoiceRestore / voice_restore.py
# Hugging Face file-view metadata: uploaded by jadechoghari, commit 96e64e9 ("add initial files", verified), 8.24 kB
#
# Code is adapted from https://github.com/lucidrains/e2-tts-pytorch
#
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from typing import Dict, Any, Optional
from functools import partial
import torch
from torch import nn
from torch.nn import Module, ModuleList, Sequential, Linear
import torch.nn.functional as F
from torchdiffeq import odeint
from einops.layers.torch import Rearrange
from einops import rearrange, repeat, pack, unpack
from x_transformers import Attention, FeedForward, RMSNorm, AdaptiveRMSNorm
from x_transformers.x_transformers import RotaryEmbedding
from gateloop_transformer import SimpleGateLoopLayer
from tensor_typing import Float
class Identity(Module):
    """No-op module that returns its input unchanged.

    It accepts (and ignores) arbitrary keyword arguments so it can stand in
    for a layer that would normally receive extra inputs such as
    ``condition``.
    """

    def forward(self, x, **_unused):
        # Any keyword arguments are deliberately dropped.
        return x
class AdaLNZero(Module):
    """Sigmoid gate applied to a branch output, driven by a condition tensor.

    The gate projection is created with all-zero weights and a constant bias,
    so right after construction every gate value equals
    ``sigmoid(init_bias_value)`` whatever the condition is.

    Args:
        dim: Feature size of the tensor being gated.
        dim_condition: Feature size of the condition; falls back to ``dim``
            when not given (or falsy).
        init_bias_value: Constant the projection bias is initialised to.
    """

    def __init__(self, dim: int, dim_condition: Optional[int] = None, init_bias_value: float = -2.):
        super().__init__()
        cond_width = dim_condition or dim

        gate_proj = nn.Linear(cond_width, dim)
        nn.init.zeros_(gate_proj.weight)
        nn.init.constant_(gate_proj.bias, init_bias_value)
        self.to_gamma = gate_proj

    def forward(self, x: torch.Tensor, *, condition: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by ``sigmoid(to_gamma(condition))``.

        A 2-D condition ``(b, d)`` gets a singleton axis inserted at position
        1 (``b d -> b 1 d``) before the projection, so the resulting gate
        broadcasts against ``x``.
        """
        if condition.ndim == 2:
            condition = condition.unsqueeze(1)  # 'b d -> b 1 d'
        return x * self.to_gamma(condition).sigmoid()
def exists(v: Any) -> bool:
    """Return True for any value other than ``None`` (falsy values count as present)."""
    if v is None:
        return False
    return True
def default(v: Any, d: Any) -> Any:
    """Return ``v`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(v):
        return v
    return d
def divisible_by(num: int, den: int) -> bool:
    """Return True when ``num`` is an exact multiple of ``den``."""
    remainder = num % den
    return remainder == 0
class Transformer(Module):
    """Transformer stack with register tokens, skip connections and optional time conditioning.

    Every layer applies, in order: a residual gateloop layer, an attention
    branch and a GLU feed-forward branch. The first ``depth // 2`` layers
    push their input onto a stack; the remaining layers pop from it and,
    when ``skip_connect_type == 'concat'``, concatenate the popped tensor on
    the feature axis and project back to ``dim``.

    Learned register tokens are prepended to the sequence before the layers
    run and stripped again before the final RMSNorm.

    Args:
        dim: Model width.
        depth: Number of layers; must be even.
        cond_on_time: If True, norms are ``AdaptiveRMSNorm`` and each branch
            output is gated by ``AdaLNZero``; ``forward`` then requires
            ``times``. If False, plain ``RMSNorm`` / ``Identity`` are used
            and ``times`` must not be passed.
        skip_connect_type: Only ``'concat'`` triggers skip merging; for any
            other value the stored skips are popped and discarded.
        abs_pos_emb: Whether to add a learned absolute position embedding.
        max_seq_len: Size of the absolute position embedding table.
        heads: Number of attention heads.
        dim_head: Per-head dimension (also used for the rotary embedding).
        num_gateloop_layers: Accepted for interface compatibility; not read
            anywhere in this class (every layer gets a gateloop layer).
        dropout: Dropout passed to attention and feed-forward.
        num_registers: Number of learned register tokens.
        attn_kwargs: Extra keyword arguments for ``Attention``. ``None``
            means ``dict(gate_value_heads=True, softclamp_logits=True)``.
        ff_kwargs: Extra keyword arguments for ``FeedForward``. ``None``
            means no extras.
    """

    def __init__(
        self,
        *,
        dim: int,
        depth: int = 8,
        cond_on_time: bool = True,
        skip_connect_type: str = 'concat',
        abs_pos_emb: bool = True,
        max_seq_len: int = 8192,
        heads: int = 8,
        dim_head: int = 64,
        num_gateloop_layers: int = 1,
        dropout: float = 0.1,
        num_registers: int = 32,
        attn_kwargs: Optional[Dict[str, Any]] = None,
        ff_kwargs: Optional[Dict[str, Any]] = None
    ):
        super().__init__()
        assert divisible_by(depth, 2), 'depth needs to be even'

        # Defaults are resolved here instead of in the signature so that no
        # dict object is shared between instances (mutable default argument).
        attn_kwargs = default(attn_kwargs, dict(gate_value_heads=True, softclamp_logits=True))
        ff_kwargs = default(ff_kwargs, dict())

        self.max_seq_len = max_seq_len
        self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if abs_pos_emb else None
        self.dim = dim
        self.skip_connect_type = skip_connect_type
        needs_skip_proj = skip_connect_type == 'concat'
        self.depth = depth
        self.layers = ModuleList([])

        # Learned register tokens, prepended to every sequence in `forward`.
        self.num_registers = num_registers
        self.registers = nn.Parameter(torch.zeros(num_registers, dim))
        nn.init.normal_(self.registers, std=0.02)

        self.rotary_emb = RotaryEmbedding(dim_head)

        # Conditioned vs. unconditioned variants of the norm and of the
        # post-branch gate.
        self.cond_on_time = cond_on_time
        rmsnorm_klass = AdaptiveRMSNorm if cond_on_time else RMSNorm
        postbranch_klass = partial(AdaLNZero, dim=dim) if cond_on_time else Identity

        # Maps a time value per batch element to a `dim`-sized condition vector.
        self.time_cond_mlp = Sequential(
            Rearrange('... -> ... 1'),
            Linear(1, dim),
            nn.SiLU()
        ) if cond_on_time else nn.Identity()

        for ind in range(depth):
            is_later_half = ind >= (depth // 2)

            gateloop = SimpleGateLoopLayer(dim=dim)

            attn_norm = rmsnorm_klass(dim)
            attn = Attention(dim=dim, heads=heads, dim_head=dim_head, dropout=dropout, **attn_kwargs)
            attn_adaln_zero = postbranch_klass()

            ff_norm = rmsnorm_klass(dim)
            ff = FeedForward(dim=dim, glu=True, dropout=dropout, **ff_kwargs)
            ff_adaln_zero = postbranch_klass()

            # Only the second half merges a skip, so only it needs the
            # 2*dim -> dim projection.
            skip_proj = Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None

            self.layers.append(ModuleList([
                gateloop, skip_proj, attn_norm, attn, attn_adaln_zero,
                ff_norm, ff, ff_adaln_zero
            ]))

        self.final_norm = RMSNorm(dim)

    def forward(
        self,
        x: Float['b n d'],
        times: Optional[Float['b'] | Float['']] = None,
    ) -> torch.Tensor:
        """Run the stack over ``x`` and return the normed output without registers.

        Args:
            x: Input of shape ``(b, n, d)``.
            times: Time value(s); a 0-dim tensor is repeated across the
                batch. Must be given if and only if ``cond_on_time`` is True.

        Returns:
            Tensor with the register tokens removed, after the final RMSNorm.

        Raises:
            AssertionError: If ``times`` presence disagrees with ``cond_on_time``.
            ValueError: If the sequence is longer than the absolute position
                embedding table.
        """
        batch, seq_len, device = *x.shape[:2], x.device

        assert not (exists(times) ^ self.cond_on_time), '`times` must be passed in if `cond_on_time` is set to `True` and vice versa'

        norm_kwargs = {}

        if exists(self.abs_pos_emb):
            # Fail early with a readable message; otherwise the embedding
            # lookup below fails with an opaque index error. The table's
            # actual size is checked, so no input that the lookup could
            # serve is rejected.
            max_positions = self.abs_pos_emb.num_embeddings
            if seq_len > max_positions:
                raise ValueError(f'{seq_len} exceeds the set `max_seq_len` ({max_positions}) on Transformer')

            seq = torch.arange(seq_len, device=device)
            x = x + self.abs_pos_emb(seq)

        if exists(times):
            if times.ndim == 0:
                times = repeat(times, ' -> b', b=batch)

            times = self.time_cond_mlp(times)
            norm_kwargs['condition'] = times

        # Prepend the register tokens; the packed shape is kept to undo this later.
        registers = repeat(self.registers, 'r d -> b r d', b=batch)
        x, registers_packed_shape = pack((registers, x), 'b * d')

        # Rotary positions cover registers + sequence.
        rotary_pos_emb = self.rotary_emb.forward_from_seq_len(x.shape[-2])

        skips = []
        half_depth = self.depth // 2

        for ind, (
            gateloop, maybe_skip_proj, attn_norm, attn, maybe_attn_adaln_zero,
            ff_norm, ff, maybe_ff_adaln_zero
        ) in enumerate(self.layers):
            is_first_half = ind < half_depth

            if is_first_half:
                skips.append(x)
            else:
                # Last-in-first-out: the deepest stored activation is merged first.
                skip = skips.pop()

                if self.skip_connect_type == 'concat':
                    x = torch.cat((x, skip), dim=-1)
                    x = maybe_skip_proj(x)

            x = gateloop(x) + x

            attn_out = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb=rotary_pos_emb)
            x = x + maybe_attn_adaln_zero(attn_out, **norm_kwargs)

            ff_out = ff(ff_norm(x, **norm_kwargs))
            x = x + maybe_ff_adaln_zero(ff_out, **norm_kwargs)

        assert len(skips) == 0

        # Drop the register tokens again.
        _, x = unpack(x, registers_packed_shape, 'b * d')

        return self.final_norm(x)
class VoiceRestore(nn.Module):
    """Conditional ODE-based restoration model built around ``Transformer``.

    The input features and the conditioning features are each projected from
    ``num_channels`` to the transformer width and summed; the transformer
    output is projected back to ``num_channels``. ``sample`` integrates this
    prediction from Gaussian noise over ``t`` in ``[0, 1]`` with ``odeint``.

    Args:
        sigma: Stored on the instance; not read by any method in this class.
        transformer: Keyword arguments for ``Transformer`` (must at least
            contain ``dim``). ``cond_on_time`` is always forced to True; a
            redundant ``cond_on_time=True`` entry is accepted.
        odeint_kwargs: Keyword arguments for ``odeint``. Defaults to
            ``{'atol': 1e-5, 'rtol': 1e-5, 'method': 'midpoint'}`` when not
            given (or empty).
        num_channels: Feature size of the input, condition and prediction.

    Raises:
        TypeError: If ``transformer`` is not provided.
        ValueError: If ``transformer`` sets ``cond_on_time`` to a false value.
    """

    def __init__(
        self,
        sigma: float = 0.0,
        transformer: Optional[Dict[str, Any]] = None,
        odeint_kwargs: Optional[Dict[str, Any]] = None,
        num_channels: int = 100,
    ):
        super().__init__()

        # `None` cannot be unpacked with `**`; report the real problem instead
        # of "argument after ** must be a mapping".
        if transformer is None:
            raise TypeError('`transformer` must be a dict of Transformer keyword arguments (at least `dim`)')

        # Work on a copy so the caller's dict is never modified.
        transformer_kwargs = dict(transformer)

        # This model always passes `times`, so time conditioning is mandatory.
        # Tolerate a redundant True (would otherwise be a duplicate-keyword
        # error) and reject an explicit False.
        if not transformer_kwargs.pop('cond_on_time', True):
            raise ValueError('VoiceRestore requires `cond_on_time=True` for its transformer')

        self.sigma = sigma
        self.num_channels = num_channels

        self.transformer = Transformer(**transformer_kwargs, cond_on_time=True)

        self.odeint_kwargs = odeint_kwargs or {'atol': 1e-5, 'rtol': 1e-5, 'method': 'midpoint'}

        self.proj_in = nn.Linear(num_channels, self.transformer.dim)
        self.cond_proj = nn.Linear(num_channels, self.transformer.dim)
        self.to_pred = nn.Linear(self.transformer.dim, num_channels)

    def transformer_with_pred_head(self, x: torch.Tensor, times: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Project ``x`` (plus ``cond`` if given), run the transformer, project to channels."""
        x = self.proj_in(x)

        if cond is not None:
            cond_proj = self.cond_proj(cond)
            x = x + cond_proj

        attended = self.transformer(x, times=times)
        return self.to_pred(attended)

    def cfg_transformer_with_pred_head(
        self,
        *args,
        cond=None,
        mask=None,
        cfg_strength: float = 0.5,
        **kwargs,
    ):
        """Prediction with classifier-free guidance.

        Returns ``pred + (pred - null_pred) * cfg_strength`` where ``pred``
        uses ``cond`` and ``null_pred`` uses no condition. When
        ``cfg_strength`` is below ``1e-5`` the unconditioned pass is skipped
        and ``pred`` is returned. If ``mask`` is given, the result is
        multiplied by ``mask.unsqueeze(-1)``.
        """
        pred = self.transformer_with_pred_head(*args, **kwargs, cond=cond)

        if cfg_strength < 1e-5:
            return pred * mask.unsqueeze(-1) if mask is not None else pred

        null_pred = self.transformer_with_pred_head(*args, **kwargs, cond=None)

        result = pred + (pred - null_pred) * cfg_strength
        return result * mask.unsqueeze(-1) if mask is not None else result

    @torch.no_grad()
    def sample(self, processed: torch.Tensor, steps: int = 32, cfg_strength: float = 0.5) -> torch.Tensor:
        """Generate a restored tensor conditioned on ``processed``.

        Starts from Gaussian noise shaped like ``processed`` and integrates
        the guided prediction over ``steps`` evenly spaced time points from 0
        to 1, returning the final state of the trajectory.

        Note: this switches the module to eval mode and leaves it there.
        """
        self.eval()

        times = torch.linspace(0, 1, steps, device=processed.device)

        def ode_fn(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
            return self.cfg_transformer_with_pred_head(x, times=t, cond=processed, cfg_strength=cfg_strength)

        y0 = torch.randn_like(processed)

        trajectory = odeint(ode_fn, y0, times, **self.odeint_kwargs)
        restored = trajectory[-1]
        return restored