Update layers.py for JAX's deprecation of `jax.core.NamedShape`
Browse files
Used the generic `tuple` annotation instead of `jax.core.NamedShape`.
Used `np.prod` to compute the total size of the shape.
- whisper_jax/layers.py +2 -2
whisper_jax/layers.py
CHANGED
@@ -60,7 +60,7 @@ default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", o
|
|
60 |
# Temporary inlined JAX N-d initializer code
|
61 |
# TODO(levskaya): remove once new JAX release is out.
|
62 |
# ------------------------------------------------------------------------------
|
63 |
-
def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
|
64 |
"""Inlined JAX `nn.initializer._compute_fans`."""
|
65 |
if isinstance(in_axis, int):
|
66 |
in_size = shape[in_axis]
|
@@ -70,7 +70,7 @@ def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
|
|
70 |
out_size = shape[out_axis]
|
71 |
else:
|
72 |
out_size = int(np.prod([shape[i] for i in out_axis]))
|
73 |
-
receptive_field_size = shape.total / in_size / out_size
|
74 |
fan_in = in_size * receptive_field_size
|
75 |
fan_out = out_size * receptive_field_size
|
76 |
return fan_in, fan_out
|
|
|
60 |
# Temporary inlined JAX N-d initializer code
|
61 |
# TODO(levskaya): remove once new JAX release is out.
|
62 |
# ------------------------------------------------------------------------------
|
63 |
+
def _compute_fans(shape: tuple, in_axis=-2, out_axis=-1):
|
64 |
"""Inlined JAX `nn.initializer._compute_fans`."""
|
65 |
if isinstance(in_axis, int):
|
66 |
in_size = shape[in_axis]
|
|
|
70 |
out_size = shape[out_axis]
|
71 |
else:
|
72 |
out_size = int(np.prod([shape[i] for i in out_axis]))
|
73 |
+
receptive_field_size = np.prod(shape) / in_size / out_size
|
74 |
fan_in = in_size * receptive_field_size
|
75 |
fan_out = out_size * receptive_field_size
|
76 |
return fan_in, fan_out
|