ydshieh committed
Commit a01b02a
1 Parent(s): f082d66

remove only_self_attn

Files changed (1)
  1. vit_gpt2/modeling_flax_gpt2.py +9 -10
vit_gpt2/modeling_flax_gpt2.py CHANGED
@@ -299,15 +299,13 @@ class FlaxGPT2Block(nn.Module):
 
     def setup(self):
 
-        self.only_self_attn = not self.config.add_cross_attention
-
         hidden_size = self.config.hidden_size
         inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
 
         self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
         self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
 
-        if not self.only_self_attn:
+        if self.config.add_cross_attention:
             self.cross_attn_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
             # [IMPORTANT] Cross attention requires ``causal=False``! This is a bug I made previously.
             self.cross_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, self_attn=False)
@@ -343,16 +341,17 @@
         attn_output = outputs[0]
         hidden_states = attn_output + residual
 
-        # sanity check
-        if not self.only_self_attn:
-            assert encoder_hidden_states is not None
-        else:
-            assert encoder_hidden_states is None
-
         # Cross-Attention Block
         cross_attn_weights = None
         if encoder_hidden_states is not None:
 
+            # add one self-attention block for cross-attention
+            if not hasattr(self, "cross_attn"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+                    "cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+
             project_encoder = getattr(self.config, "project_encoder", None)
             if project_encoder:
                 residual = encoder_hidden_states
@@ -393,7 +392,7 @@
         if output_attentions:
             self_attn_weights = attn_output[1]
             outputs += (self_attn_weights,)
-            if not self.only_self_attn:
+            if cross_attn_weights is not None:
                 outputs += (cross_attn_weights,)
 
         return outputs
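
For context, a minimal, self-contained sketch of the pattern this commit moves to. This is not code from the repository: ToyConfig, ToyBlock and the nn.Dense layers are illustrative stand-ins for GPT2Config, FlaxGPT2Block and the attention modules. It shows the three pieces of the change: cross-attention sub-modules are created in setup() only when config.add_cross_attention is set, a ValueError replaces the old flag-based asserts when encoder states reach a block without those sub-modules, and the extra output is appended only when it was actually computed.

# Illustrative sketch only -- ToyConfig/ToyBlock/nn.Dense stand in for
# GPT2Config/FlaxGPT2Block/FlaxGPT2Attention; this is not the repo's code.
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.linen as nn


@dataclass
class ToyConfig:
    hidden_size: int = 8
    add_cross_attention: bool = False


class ToyBlock(nn.Module):
    config: ToyConfig

    def setup(self):
        self.dense = nn.Dense(self.config.hidden_size)
        # As in the new setup(): read the config directly, no cached only_self_attn flag,
        # and only build the cross-attention sub-module when it is requested.
        if self.config.add_cross_attention:
            self.cross_dense = nn.Dense(self.config.hidden_size)

    def __call__(self, hidden_states, encoder_hidden_states=None, output_attentions=False):
        hidden_states = self.dense(hidden_states)
        cross_out = None
        if encoder_hidden_states is not None:
            # As in the new __call__: raise instead of asserting on a flag when encoder
            # states are passed to a block that was built without cross-attention layers.
            if not hasattr(self, "cross_dense"):
                raise ValueError(
                    "`encoder_hidden_states` were passed, but this block was built with "
                    "`config.add_cross_attention=False`."
                )
            cross_out = self.cross_dense(encoder_hidden_states)
            hidden_states = hidden_states + cross_out

        outputs = (hidden_states,)
        if output_attentions:
            # As in the last hunk: append the cross term only when it was actually computed.
            if cross_out is not None:
                outputs += (cross_out,)
        return outputs


x = jnp.ones((1, 4, 8))
enc = jnp.ones((1, 4, 8))

block = ToyBlock(ToyConfig(add_cross_attention=True))
params = block.init(jax.random.PRNGKey(0), x, enc)
out = block.apply(params, x, enc, output_attentions=True)  # two outputs: hidden + cross term

plain = ToyBlock(ToyConfig())                  # add_cross_attention=False
plain_params = plain.init(jax.random.PRNGKey(0), x)
# plain.apply(plain_params, x, enc)            # would raise the ValueError above

The hasattr check works in Flax linen because setup() only registers the attributes it actually assigns, so a block built with add_cross_attention=False simply has no cross-attention field at call time; that lets the runtime guard replace the removed only_self_attn bookkeeping.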