frankleeeee committed on
Commit
96995a1
1 Parent(s): df479cc

Upload STDiT

Browse files
Files changed (2) hide show
  1. config.json +1 -0
  2. modeling_stdit.py +6 -1
config.json CHANGED
@@ -11,6 +11,7 @@
11
  "depth": 28,
12
  "drop_path": 0.0,
13
  "enable_flash_attn": false,
 
14
  "enable_layernorm_kernel": false,
15
  "enable_sequence_parallelism": false,
16
  "freeze": null,
 
11
  "depth": 28,
12
  "drop_path": 0.0,
13
  "enable_flash_attn": false,
14
+ "enable_flashattn": false,
15
  "enable_layernorm_kernel": false,
16
  "enable_sequence_parallelism": false,
17
  "freeze": null,
modeling_stdit.py CHANGED
@@ -109,6 +109,10 @@ class STDiT(PreTrainedModel):
109
  Returns:
110
  x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
111
  """
 
 
 
 
112
  # embedding
113
  x = self.x_embedder(x) # [B, N, C]
114
  x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
@@ -144,7 +148,8 @@ class STDiT(PreTrainedModel):
144
  tpe = self.pos_embed_temporal
145
  else:
146
  tpe = None
147
- x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
 
148
 
149
  if self.enable_sequence_parallelism:
150
  x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
 
109
  Returns:
110
  x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
111
  """
112
+ x = x.to(self.final_layer.linear.weight.dtype)
113
+ timestep = timestep.to(self.final_layer.linear.weight.dtype)
114
+ y = y.to(self.final_layer.linear.weight.dtype)
115
+
116
  # embedding
117
  x = self.x_embedder(x) # [B, N, C]
118
  x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
 
148
  tpe = self.pos_embed_temporal
149
  else:
150
  tpe = None
151
+ x = block(x, y, t0, y_lens, tpe)
152
+ # x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
153
 
154
  if self.enable_sequence_parallelism:
155
  x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")