Update layers.py for JAX's deprecation of "shape"
#4 opened by zxyse
Used a generic tuple instead of jax.core.NamedShape.
Used np to get the total size of the shape.
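A minimal sketch of the pattern these two changes describe, assuming layers.py builds parameter shapes and needs their total element count. The helper names `param_shape`, `total_size` and `init_dense` are hypothetical and are not taken from the PR diff; the sketch only shows a plain tuple standing in for `jax.core.NamedShape` and `np.prod` computing the size.

```python
import numpy as np
import jax
import jax.numpy as jnp


def param_shape(*dims: int) -> tuple[int, ...]:
    # Hypothetical helper: a plain tuple of ints replaces jax.core.NamedShape.
    return tuple(int(d) for d in dims)


def total_size(shape: tuple[int, ...]) -> int:
    # np.prod returns a NumPy scalar (1.0 for an empty tuple), so cast to int.
    # An empty shape () describes a scalar, whose size is 1.
    return int(np.prod(shape))


def init_dense(key, in_features: int, out_features: int) -> jnp.ndarray:
    # Hypothetical dense-layer init: jax.random.normal accepts a tuple shape.
    shape = param_shape(in_features, out_features)
    w = jax.random.normal(key, shape) / jnp.sqrt(shape[0])
    return w


if __name__ == "__main__":
    w = init_dense(jax.random.PRNGKey(0), 3, 5)
    print(w.shape, total_size(w.shape))
```

Because a tuple is an ordinary Python sequence, it works anywhere JAX and NumPy expect a shape, so the code no longer depends on an internal JAX type that may change between releases.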