frankleeeee
commited on
Commit
•
96995a1
1
Parent(s):
df479cc
Upload STDiT
Browse files- config.json +1 -0
- modeling_stdit.py +6 -1
config.json
CHANGED
@@ -11,6 +11,7 @@
|
|
11 |
"depth": 28,
|
12 |
"drop_path": 0.0,
|
13 |
"enable_flash_attn": false,
|
|
|
14 |
"enable_layernorm_kernel": false,
|
15 |
"enable_sequence_parallelism": false,
|
16 |
"freeze": null,
|
|
|
11 |
"depth": 28,
|
12 |
"drop_path": 0.0,
|
13 |
"enable_flash_attn": false,
|
14 |
+
"enable_flashattn": false,
|
15 |
"enable_layernorm_kernel": false,
|
16 |
"enable_sequence_parallelism": false,
|
17 |
"freeze": null,
|
modeling_stdit.py
CHANGED
@@ -109,6 +109,10 @@ class STDiT(PreTrainedModel):
|
|
109 |
Returns:
|
110 |
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
|
111 |
"""
|
|
|
|
|
|
|
|
|
112 |
# embedding
|
113 |
x = self.x_embedder(x) # [B, N, C]
|
114 |
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
|
@@ -144,7 +148,8 @@ class STDiT(PreTrainedModel):
|
|
144 |
tpe = self.pos_embed_temporal
|
145 |
else:
|
146 |
tpe = None
|
147 |
-
x =
|
|
|
148 |
|
149 |
if self.enable_sequence_parallelism:
|
150 |
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
|
|
|
109 |
Returns:
|
110 |
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
|
111 |
"""
|
112 |
+
x = x.to(self.final_layer.linear.weight.dtype)
|
113 |
+
timestep = timestep.to(self.final_layer.linear.weight.dtype)
|
114 |
+
y = y.to(self.final_layer.linear.weight.dtype)
|
115 |
+
|
116 |
# embedding
|
117 |
x = self.x_embedder(x) # [B, N, C]
|
118 |
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
|
|
|
148 |
tpe = self.pos_embed_temporal
|
149 |
else:
|
150 |
tpe = None
|
151 |
+
x = block(x, y, t0, y_lens, tpe)
|
152 |
+
# x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
|
153 |
|
154 |
if self.enable_sequence_parallelism:
|
155 |
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
|