Spaces:
Running
Running
fix: correct use of dtype
Browse files- dalle_mini/model.py +4 -10
dalle_mini/model.py
CHANGED
@@ -18,21 +18,20 @@ class CustomFlaxBartModule(FlaxBartModule):
|
|
18 |
self.shared = nn.Embed(
|
19 |
self.config.vocab_size,
|
20 |
self.config.d_model,
|
21 |
-
embedding_init=jax.nn.initializers.normal(self.config.init_std
|
22 |
-
dtype=self.dtype,
|
23 |
)
|
24 |
# a separate embedding is used for the decoder
|
25 |
self.decoder_embed = nn.Embed(
|
26 |
self.config.image_vocab_size + 1,
|
27 |
self.config.d_model,
|
28 |
-
embedding_init=jax.nn.initializers.normal(self.config.init_std
|
29 |
-
dtype=self.dtype,
|
30 |
)
|
31 |
self.encoder = FlaxBartEncoder(
|
32 |
self.config, dtype=self.dtype, embed_tokens=self.shared
|
33 |
)
|
34 |
|
35 |
# the decoder has a different config
|
|
|
36 |
decoder_config = BartConfig(self.config.to_dict())
|
37 |
decoder_config.max_position_embeddings = (
|
38 |
self.config.image_length + 1 # image tokens + BOS
|
@@ -47,16 +46,11 @@ class CustomFlaxBartForConditionalGenerationModule(
|
|
47 |
FlaxBartForConditionalGenerationModule
|
48 |
):
|
49 |
def setup(self):
|
50 |
-
# check config is valid, otherwise set default values
|
51 |
-
# TODO: simplify with custom config class
|
52 |
-
self.config.text_normalized = True / False
|
53 |
-
|
54 |
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
|
55 |
self.lm_head = nn.Dense(
|
56 |
self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
|
57 |
use_bias=False,
|
58 |
-
|
59 |
-
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
|
60 |
)
|
61 |
self.final_logits_bias = self.param(
|
62 |
"final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
|
|
|
18 |
self.shared = nn.Embed(
|
19 |
self.config.vocab_size,
|
20 |
self.config.d_model,
|
21 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
22 |
)
|
23 |
# a separate embedding is used for the decoder
|
24 |
self.decoder_embed = nn.Embed(
|
25 |
self.config.image_vocab_size + 1,
|
26 |
self.config.d_model,
|
27 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
28 |
)
|
29 |
self.encoder = FlaxBartEncoder(
|
30 |
self.config, dtype=self.dtype, embed_tokens=self.shared
|
31 |
)
|
32 |
|
33 |
# the decoder has a different config
|
34 |
+
# TODO: should not be needed once we have custom config/module
|
35 |
decoder_config = BartConfig(self.config.to_dict())
|
36 |
decoder_config.max_position_embeddings = (
|
37 |
self.config.image_length + 1 # image tokens + BOS
|
|
|
46 |
FlaxBartForConditionalGenerationModule
|
47 |
):
|
48 |
def setup(self):
|
|
|
|
|
|
|
|
|
49 |
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
|
50 |
self.lm_head = nn.Dense(
|
51 |
self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
|
52 |
use_bias=False,
|
53 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
54 |
)
|
55 |
self.final_logits_bias = self.param(
|
56 |
"final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
|