File size: 3,319 Bytes
9c3a994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from typing import Optional
from diffusers.models.embeddings import Timesteps
import math

from michelangelo.models.modules.transformer_blocks import MLP
from michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer


class ConditionalASLUDTDenoiser(nn.Module):
    """Transformer denoiser conditioned on a timestep and a sequence of context tokens.

    The noisy input tokens are projected to ``width`` and appended after one
    timestep token and the projected context tokens. The joint sequence is run
    through a ``UNetDiffusionTransformer`` backbone, layer-normed, and only the
    trailing ``n_data`` positions are projected back to ``output_channels``.

    Args:
        device: device on which parameters are created.
        dtype: dtype of the parameters.
        input_channels: channel count of ``model_input`` tokens.
        output_channels: channel count of the returned sample tokens.
        n_ctx: forwarded to the backbone. NOTE(review): presumably the sequence
            length the backbone expects (``1 + context_tokens + n_data``);
            confirm against ``UNetDiffusionTransformer``.
        width: hidden size of the backbone and of all projections.
        layers: forwarded to the backbone as its number of layers.
        heads: forwarded to the backbone as its number of attention heads.
        context_dim: channel count of ``context`` tokens.
        context_ln: if True, apply a LayerNorm to the context before projecting
            it to ``width``; otherwise use a bare Linear projection.
        skip_ln: forwarded to the backbone.
        init_scale: base init scale. It is multiplied by ``sqrt(1 / width)``
            before being passed to the backbone and to the timestep MLP.
        flip_sin_to_cos: forwarded to the sinusoidal ``Timesteps`` embedding.
        use_checkpoint: stored on the module and forwarded to the backbone
            (presumably enables gradient checkpointing there).
    """

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 input_channels: int,
                 output_channels: int,
                 n_ctx: int,
                 width: int,
                 layers: int,
                 heads: int,
                 context_dim: int,
                 context_ln: bool = True,
                 skip_ln: bool = False,
                 init_scale: float = 0.25,
                 flip_sin_to_cos: bool = False,
                 use_checkpoint: bool = False):
        super().__init__()

        self.use_checkpoint = use_checkpoint

        # Scale the base init by sqrt(1 / width) before handing it to submodules.
        init_scale = init_scale * math.sqrt(1.0 / width)

        self.backbone = UNetDiffusionTransformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            layers=layers,
            heads=heads,
            skip_ln=skip_ln,
            init_scale=init_scale,
            use_checkpoint=use_checkpoint
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)

        # timestep embedding: parameter-free sinusoidal features followed by an MLP
        self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
        self.time_proj = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale
        )

        # Context projector, built exactly once. No throw-away module is created,
        # so parameter init only draws RNG state for the layers that are kept.
        if context_ln:
            self.context_embed = nn.Sequential(
                nn.LayerNorm(context_dim, device=device, dtype=dtype),
                nn.Linear(context_dim, width, device=device, dtype=dtype),
            )
        else:
            self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)

    def forward(self,
                model_input: torch.FloatTensor,
                timestep: torch.LongTensor,
                context: torch.FloatTensor):

        r"""
        Args:
            model_input (torch.FloatTensor): [bs, n_data, c]
            timestep (torch.LongTensor): [bs,], or a 0-dim tensor shared by the whole batch
            context (torch.FloatTensor): [bs, context_tokens, c]

        Returns:
            sample (torch.FloatTensor): [bs, n_data, output_channels]

        """

        bs, n_data, _ = model_input.shape

        # 1. time
        if timestep.ndim == 0:
            # Broadcast a single shared timestep to every batch element.
            timestep = timestep.expand(bs)
        # Timesteps has no weights and always produces float32. Cast to the
        # parameter dtype so the MLP also works when built in half precision.
        t_freq = self.time_embed(timestep).to(dtype=self.input_proj.weight.dtype)
        t_emb = self.time_proj(t_freq).unsqueeze(dim=1)  # [bs, 1, width]

        # 2. conditions projector
        context = self.context_embed(context)  # [bs, context_tokens, width]

        # 3. denoiser: the token sequence is [time token | context tokens | data tokens]
        x = self.input_proj(model_input)
        x = torch.cat([t_emb, context, x], dim=1)
        x = self.backbone(x)
        x = self.ln_post(x)
        # Keep only the trailing n_data (data) tokens. Use an explicit start index
        # because `x[:, -n_data:]` would return the whole sequence when n_data == 0.
        x = x[:, x.shape[1] - n_data:]
        sample = self.output_proj(x)

        return sample