boris commited on
Commit
b7d8724
1 Parent(s): 803c7df

fix: correct use of dtype

Browse files
Files changed (1) hide show
  1. dalle_mini/model.py +4 -10
dalle_mini/model.py CHANGED
@@ -18,21 +18,20 @@ class CustomFlaxBartModule(FlaxBartModule):
18
  self.shared = nn.Embed(
19
  self.config.vocab_size,
20
  self.config.d_model,
21
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
22
- dtype=self.dtype,
23
  )
24
  # a separate embedding is used for the decoder
25
  self.decoder_embed = nn.Embed(
26
  self.config.image_vocab_size + 1,
27
  self.config.d_model,
28
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
29
- dtype=self.dtype,
30
  )
31
  self.encoder = FlaxBartEncoder(
32
  self.config, dtype=self.dtype, embed_tokens=self.shared
33
  )
34
 
35
  # the decoder has a different config
 
36
  decoder_config = BartConfig(self.config.to_dict())
37
  decoder_config.max_position_embeddings = (
38
  self.config.image_length + 1 # image tokens + BOS
@@ -47,16 +46,11 @@ class CustomFlaxBartForConditionalGenerationModule(
47
  FlaxBartForConditionalGenerationModule
48
  ):
49
  def setup(self):
50
- # check config is valid, otherwise set default values
51
- # TODO: simplify with custom config class
52
- self.config.text_normalized = True / False
53
-
54
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
55
  self.lm_head = nn.Dense(
56
  self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
57
  use_bias=False,
58
- dtype=self.dtype,
59
- kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
60
  )
61
  self.final_logits_bias = self.param(
62
  "final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
 
18
  self.shared = nn.Embed(
19
  self.config.vocab_size,
20
  self.config.d_model,
21
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
 
22
  )
23
  # a separate embedding is used for the decoder
24
  self.decoder_embed = nn.Embed(
25
  self.config.image_vocab_size + 1,
26
  self.config.d_model,
27
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
 
28
  )
29
  self.encoder = FlaxBartEncoder(
30
  self.config, dtype=self.dtype, embed_tokens=self.shared
31
  )
32
 
33
  # the decoder has a different config
34
+ # TODO: should not be needed once we have custom config/module
35
  decoder_config = BartConfig(self.config.to_dict())
36
  decoder_config.max_position_embeddings = (
37
  self.config.image_length + 1 # image tokens + BOS
 
46
  FlaxBartForConditionalGenerationModule
47
  ):
48
  def setup(self):
 
 
 
 
49
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
50
  self.lm_head = nn.Dense(
51
  self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
52
  use_bias=False,
53
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
 
54
  )
55
  self.final_logits_bias = self.param(
56
  "final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)