Spaces:
Running
Running
feat(modeling): simplify abstract_init
Browse files
src/dalle_mini/model/modeling.py
CHANGED
@@ -334,7 +334,9 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
334 |
|
335 |
# init weights on CPU
|
336 |
if load_on_cpu:
|
337 |
-
init_fn = jax.jit(
|
|
|
|
|
338 |
else:
|
339 |
init_fn = self.init_weights
|
340 |
|
@@ -343,10 +345,11 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
343 |
# init the model weights only abstractly: eval_shape will return a pytree
|
344 |
# with the same structure as the weights but without any actual values; it will only contain
|
345 |
# the shape information. Weights need to be loaded later.
|
346 |
-
|
347 |
-
|
|
|
348 |
else:
|
349 |
-
random_params = init_fn(self.key, input_shape)
|
350 |
|
351 |
# save required_params as set
|
352 |
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|
|
|
334 |
|
335 |
# init weights on CPU
|
336 |
if load_on_cpu:
|
337 |
+
init_fn = jax.jit(
|
338 |
+
self.init_weights, static_argnames="input_shape", backend="cpu"
|
339 |
+
)
|
340 |
else:
|
341 |
init_fn = self.init_weights
|
342 |
|
|
|
345 |
# init the model weights only abstractly: eval_shape will return a pytree
|
346 |
# with the same structure as the weights but without any actual values; it will only contain
|
347 |
# the shape information. Weights need to be loaded later.
|
348 |
+
random_params = jax.eval_shape(
|
349 |
+
init_fn, rng=self.key, input_shape=input_shape
|
350 |
+
)
|
351 |
else:
|
352 |
+
random_params = init_fn(rng=self.key, input_shape=input_shape)
|
353 |
|
354 |
# save required_params as set
|
355 |
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|