feat: add cogview
- README.md +12 -0
- src/dalle_mini/model/configuration.py +2 -1
- src/dalle_mini/model/modeling.py +8 -8
README.md
CHANGED
@@ -124,6 +124,7 @@ Sequence to sequence model based on "[BART: Denoising Sequence-to-Sequence Pre-t
 - "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
 - "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
 - "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
+- "[CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/abs/2105.13290v2)"
 - "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
 
 Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
@@ -225,6 +226,17 @@ Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization f
 }
 ```
 
+```text
+@misc{ding2021cogview,
+    title = {CogView: Mastering Text-to-Image Generation via Transformers},
+    author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
+    year = {2021},
+    eprint = {2105.13290},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
 ```text
 @misc{zhang2019root,
     title = {Root Mean Square Layer Normalization},
src/dalle_mini/model/configuration.py
CHANGED
@@ -60,7 +60,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         # transformer variants
         head_scale=False,  # used in NormFormer
         ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
-        ln_positions="deepnet",  # layer normalization positions, "normformer", "swinv2", "deepnet" (same as post-ln)
+        ln_positions="deepnet",  # layer normalization positions, "normformer", "swinv2", "cogview", "deepnet" (same as post-ln)
         use_cosine_attention=False,  # used in Swin v2
         tau_init=0.05,  # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False,  # used in Deepnet
@@ -80,6 +80,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         assert ln_positions in [
             "normformer",
             "swinv2",
+            "cogview",
             "deepnet",
         ], "ln_positions must be 'normformer', 'swinv2' or 'deepnet'"
         self.ln_positions = ln_positions
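For reference, a minimal usage sketch (not part of this commit) showing how the new value would be selected; it assumes the remaining `DalleBartConfig` constructor arguments have usable defaults, which is not shown in this diff.

```python
# Hypothetical usage sketch, assuming the other DalleBartConfig arguments
# keep their defaults; only the layer-norm options are passed explicitly.
from dalle_mini.model.configuration import DalleBartConfig

config = DalleBartConfig(
    ln_type="layernorm",     # or "rmsnorm"
    ln_positions="cogview",  # newly accepted by the assert above
)
print(config.ln_positions)  # "cogview"
```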
src/dalle_mini/model/modeling.py
CHANGED
@@ -373,7 +373,7 @@ class GLU(nn.Module):
             self.config
         )
 
-        if self.config.ln_positions in ["normformer"]:
+        if self.config.ln_positions in ["normformer", "cogview"]:
             x = norm(
                 self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
             )(x)
@@ -411,7 +411,7 @@ class GLU(nn.Module):
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
         )(x)
-        if self.config.ln_positions in ["swinv2"]:
+        if self.config.ln_positions in ["swinv2", "cogview"]:
             x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
         x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
         return x
@@ -432,7 +432,7 @@ class FFN(nn.Module):
         gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
             self.config
         )
-        if self.config.ln_positions in ["normformer"]:
+        if self.config.ln_positions in ["normformer", "cogview"]:
             x = norm(
                 self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
             )(x)
@@ -460,7 +460,7 @@ class FFN(nn.Module):
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
         )(x)
-        if self.config.ln_positions in ["swinv2"]:
+        if self.config.ln_positions in ["swinv2", "cogview"]:
             x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
         x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
         return x
@@ -593,7 +593,7 @@ class FlaxBartDecoderLayer(nn.Module):
         residual = hidden_states
 
         # Self Attention
-        if self.config.ln_positions in ["normformer"]:
+        if self.config.ln_positions in ["normformer", "cogview"]:
             hidden_states = norm(
                 self.config.ln_type,
                 dtype=self.dtype,
@@ -615,7 +615,7 @@ class FlaxBartDecoderLayer(nn.Module):
             init_cache=init_cache,
         )
 
-        if self.config.ln_positions in ["normformer", "swinv2"]:
+        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                 hidden_states
             )
@@ -632,7 +632,7 @@ class FlaxBartDecoderLayer(nn.Module):
         cross_attn_weights = None
         if encoder_hidden_states is not None:
             residual = hidden_states
-            if self.config.ln_positions in ["normformer"]:
+            if self.config.ln_positions in ["normformer", "cogview"]:
                 hidden_states = norm(
                     self.config.ln_type,
                     dtype=self.dtype,
@@ -652,7 +652,7 @@ class FlaxBartDecoderLayer(nn.Module):
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
             )
-            if self.config.ln_positions in ["normformer", "swinv2"]:
+            if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
                 hidden_states = norm(
                     self.config.ln_type, dtype=self.dtype, epsilon=1e-05
                 )(hidden_states)
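Taken together, `ln_positions="cogview"` now enables both the NormFormer-style pre-norm branch and the Swin-v2-style post-norm branch: a layer normalization before each sub-layer and another on its output before the residual add, which corresponds to CogView's Sandwich-LN. Below is a minimal, self-contained Flax sketch of that placement for a feed-forward sub-layer; the module name, dimensions, and plain `nn.LayerNorm` (in place of the repo's `norm` helper) are illustrative assumptions, not repo code.

```python
# Minimal sketch of the Sandwich-LN placement selected by ln_positions="cogview".
# Hypothetical module, not from the dalle_mini repo.
import jax
import jax.numpy as jnp
import flax.linen as nn


class FeedForwardSandwichLN(nn.Module):
    hidden_dim: int = 64
    ffn_dim: int = 256
    dropout: float = 0.0

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        residual = x
        # Pre-LN, as in the "normformer" branch of GLU/FFN above.
        x = nn.LayerNorm(epsilon=1e-05, use_scale=False)(x)
        x = nn.Dense(self.ffn_dim)(x)
        x = jax.nn.gelu(x)
        x = nn.Dense(self.hidden_dim)(x)
        # Post-LN on the sub-layer output, as in the "swinv2" branch above.
        x = nn.LayerNorm(epsilon=1e-05)(x)
        x = nn.Dropout(rate=self.dropout)(x, deterministic=deterministic)
        return residual + x


if __name__ == "__main__":
    model = FeedForwardSandwichLN()
    x = jnp.ones((1, 8, 64))
    params = model.init(jax.random.PRNGKey(0), x)
    y = model.apply(params, x)
    print(y.shape)  # (1, 8, 64)
```

In the diff above, the same pattern is applied to the decoder's self-attention and cross-attention sub-layers in `FlaxBartDecoderLayer`, not just to `GLU` and `FFN`.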