import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from salad.model_components.transformer import TimeMLP


class TimePointwiseLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_ctx,
        mlp_ratio=2,
        act=F.leaky_relu,
        dropout=0.0,
        use_time=False,
    ):
        super().__init__()
        self.use_time = use_time
        self.act = act
        self.mlp1 = TimeMLP(
            dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time
        )
        self.norm1 = nn.LayerNorm(dim_in)

        self.mlp2 = TimeMLP(
            dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time
        )
        self.norm2 = nn.LayerNorm(dim_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, ctx=None):
        # Two post-norm residual blocks:
        # time-conditioned MLP -> dropout -> add & norm.
        res = x
        x = self.mlp1(x, ctx=ctx)
        x = self.norm1(self.dropout(x) + res)

        res = x
        x = self.mlp2(x, ctx=ctx)
        x = self.norm2(self.dropout(x) + res)
        return x
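
# A minimal usage sketch for TimePointwiseLayer (illustrative, not from the
# original codebase). The shapes are assumptions: a per-point feature tensor
# of width `dim_in` and a context tensor (e.g. a timestep embedding) of width
# `dim_ctx` that TimeMLP can broadcast over the point dimension:
#
#   layer = TimePointwiseLayer(dim_in=512, dim_ctx=128, use_time=True)
#   x = torch.randn(8, 2048, 512)   # (batch, num_points, dim_in)
#   ctx = torch.randn(8, 1, 128)    # (batch, 1, dim_ctx), hypothetical shape
#   out = layer(x, ctx=ctx)         # same shape as x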


class TimePointWiseEncoder(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_ctx=None,
        mlp_ratio=2,
        act=F.leaky_relu,
        dropout=0.0,
        use_time=True,
        num_layers=6,
        last_fc=False,
        last_fc_dim_out=None,
    ):
        super().__init__()
        self.last_fc = last_fc
        if last_fc:
            # The final projection needs an explicit output width.
            assert last_fc_dim_out is not None, "last_fc requires last_fc_dim_out"
            self.fc = nn.Linear(dim_in, last_fc_dim_out)
        self.layers = nn.ModuleList(
            [
                TimePointwiseLayer(
                    dim_in,
                    dim_ctx=dim_ctx,
                    mlp_ratio=mlp_ratio,
                    act=act,
                    dropout=dropout,
                    use_time=use_time,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x, ctx=None):
        for layer in self.layers:
            x = layer(x, ctx=ctx)
        if self.last_fc:
            x = self.fc(x)
        return x
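
# Usage sketch for TimePointWiseEncoder (illustrative; all widths are made
# up). With last_fc=True the stack ends in a linear head, e.g. to project
# the hidden features to a different output width:
#
#   enc = TimePointWiseEncoder(dim_in=512, dim_ctx=128, num_layers=6,
#                              last_fc=True, last_fc_dim_out=16)
#   x = torch.randn(8, 2048, 512)
#   ctx = torch.randn(8, 1, 128)
#   out = enc(x, ctx=ctx)  # (8, 2048, 16)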


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
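        # Frequency schedule: freqs[k] = exp(-ln(max_period) * k / half) for
        # k = 0..half-1, i.e. geometrically spaced from 1 down to roughly
        # 1/max_period; the embedding is then [cos(t * f), sin(t * f)].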
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
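

if __name__ == "__main__":
    # Minimal smoke test for TimestepEmbedder; the sizes are arbitrary and
    # this block is only a sketch, not part of the original module.
    embedder = TimestepEmbedder(hidden_size=128)
    t = torch.randint(0, 1000, (4,))  # four random diffusion timesteps
    emb = embedder(t)
    print(emb.shape)  # expected: torch.Size([4, 128])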