# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class LookupFreeQuantize(nn.Module):
    """Lookup-free quantizer (LFQ).

    Projects the input to ``codebook_dim`` channels, squashes them with a
    sigmoid, and rounds each channel to {0, 1}. The implicit codebook is the
    set of 2 ** codebook_dim binary corners, so no learned codebook lookup
    is needed.
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Each binary channel contributes one bit, so the codebook size must
        # be exactly 2 ** codebook_dim.
        assert 2**codebook_dim == codebook_size

        # 1x1 convolutions map between the model dimension and the binary
        # code dimension when the two differ.
        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(
                self.codebook_dim, self.input_dim, kernel_size=1
            )
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

    def forward(self, z):
        # z: (B, input_dim, T) -> z_e: (B, codebook_dim, T), squashed to (0, 1).
        z_e = self.in_project(z)
        z_e = F.sigmoid(z_e)

        # Round each channel to {0, 1} with a straight-through estimator: the
        # forward value is round(z_e), while gradients flow through z_e.
        z_q = z_e + (torch.round(z_e) - z_e).detach()

        z_q = self.out_project(z_q)

        # LFQ has no learned codebook, so the commitment and codebook losses
        # are zero placeholders kept for interface compatibility.
        commit_loss = torch.zeros(z.shape[0], device=z.device)
        codebook_loss = torch.zeros(z.shape[0], device=z.device)

        # Pack the binary channels into integer indices: channel i carries bit
        # weight 2**i, so indices lie in [0, codebook_size - 1].
        bits = (
            2
            ** torch.arange(self.codebook_dim, device=z.device)
            .unsqueeze(0)
            .unsqueeze(-1)
            .long()
        )  # (1, codebook_dim, 1)
        indices = (torch.round(z_e.clone().detach()).long() * bits).sum(1).long()  # (B, T)

        return z_q, commit_loss, codebook_loss, indices, z_e

    def vq2emb(self, vq, out_proj=True):
        # Invert the bit packing: decode integer indices (B, T) back into the
        # binary code tensor (B, codebook_dim, T), least significant bit first.
        emb = torch.zeros(
            vq.shape[0], self.codebook_dim, vq.shape[-1], device=vq.device
        )  # (B, codebook_dim, T)
        for i in range(self.codebook_dim):
            emb[:, i, :] = (vq % 2).float()
            vq = vq // 2
        if out_proj:
            emb = self.out_project(emb)
        return emb
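

# The block below is a minimal usage sketch, not part of the original module:
# it assumes illustrative sizes (input_dim=256, codebook_size=1024,
# codebook_dim=10, batch 2, 50 frames), runs a random feature map through the
# quantizer, and checks that vq2emb reconstructs z_q from the returned indices.
if __name__ == "__main__":
    torch.manual_seed(0)
    lfq = LookupFreeQuantize(input_dim=256, codebook_size=1024, codebook_dim=10)
    z = torch.randn(2, 256, 50)  # (B, input_dim, T)

    z_q, commit_loss, codebook_loss, indices, z_e = lfq(z)
    print(z_q.shape)      # torch.Size([2, 256, 50])
    print(indices.shape)  # torch.Size([2, 50]), values in [0, 1023]

    # Decoding the indices and re-projecting should match z_q up to
    # floating-point noise from the straight-through trick.
    emb = lfq.vq2emb(indices)
    print((emb - z_q).abs().max())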