Update layers.py for JAX deprecation of "shape"

#4

Use a plain tuple instead of jax.core.NamedShape.
Use NumPy (np) to compute the total size of a shape.
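As a rough illustration of these two changes, here is a minimal sketch. It assumes the layer code only needs a shape's dimensions and its element count, so the shape can be kept as an ordinary tuple of ints and the size computed with NumPy. The helper name total_size is hypothetical and is not taken from layers.py.

```python
import numpy as np


# Hypothetical helper (not from layers.py): with jax.core.NamedShape no longer
# used, a shape is treated as a plain tuple of ints.
def total_size(shape: tuple[int, ...]) -> int:
    """Return the number of elements implied by `shape`."""
    # np.prod of an empty tuple is 1.0, which matches the size of a scalar.
    # Cast to int because np.prod returns a NumPy scalar, not a Python int.
    return int(np.prod(shape))


# Example: a (batch, features) shape and a scalar shape.
print(total_size((32, 128)))  # 4096
print(total_size(()))         # 1
```

Casting to int keeps the result a plain Python integer, which is convenient if the size is later used in reshapes or arithmetic that expects built-in types.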
