make checkpointing optional
dalle_mini/modeling_bart_flax.py
CHANGED
@@ -252,8 +252,7 @@ class FlaxBartEncoderLayer(nn.Module):
             kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
         self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
-
-    @nn.remat
+
     def __call__(
         self,
         hidden_states: jnp.ndarray,
@@ -283,8 +282,9 @@ class FlaxBartEncoderLayerCollection(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
+        layer_module = nn.remat(FlaxBartEncoderLayer) if self.config.gradient_checkpointing else FlaxBartEncoderLayer
         self.layers = [
-            FlaxBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
+            layer_module(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
         ]

     def __call__(
@@ -344,8 +344,7 @@ class FlaxBartDecoderLayer(nn.Module):
             kernel_init=jax.nn.initializers.normal(self.config.init_std),
         )
         self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
-
-    @nn.remat
+
     def __call__(
         self,
         hidden_states: jnp.ndarray,
@@ -394,8 +393,9 @@ class FlaxBartDecoderLayerCollection(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
+        layer_module = nn.remat(FlaxBartDecoderLayer) if self.config.gradient_checkpointing else FlaxBartDecoderLayer
         self.layers = [
-            FlaxBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
+            layer_module(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
        ]

     def __call__(
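For context on the pattern used in this commit: flax.linen.remat can wrap a Module class, and the wrapped class keeps the original constructor signature, which is why only the class reference needs to change inside the list comprehension. Below is a minimal, self-contained sketch of the same toggle using hypothetical Block/Stack modules and a plain boolean flag rather than the dalle-mini classes and config.

# Minimal sketch of the conditional-remat pattern above; Block, Stack and
# gradient_checkpointing are illustrative stand-ins, not the dalle-mini code.
import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.gelu(nn.Dense(self.features)(x))


class Stack(nn.Module):
    features: int
    n_layers: int
    gradient_checkpointing: bool = False

    def setup(self):
        # nn.remat returns a wrapped Module class with the same constructor,
        # so only the class reference changes when checkpointing is enabled.
        block_cls = nn.remat(Block) if self.gradient_checkpointing else Block
        self.layers = [block_cls(self.features, name=str(i)) for i in range(self.n_layers)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


x = jnp.ones((2, 16))
model = Stack(features=16, n_layers=4, gradient_checkpointing=True)
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)

With the flag enabled, activations inside each wrapped layer are recomputed during the backward pass instead of being stored, trading extra compute for lower memory use; with the flag off, the plain layer class is used and behaviour is unchanged, which is what makes checkpointing optional here.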