Model does not work with device set to `mps`
#2
Opened by akbir
Inference on an M1 GPU (`device="mps"`) does not work.
The tensor shapes end up wrong here. I'm not sure whether this is a PyTorch bug or a problem with the model implementation.
Can you share a snippet of code to reproduce?
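A minimal sketch of such a reproduction might look like the following. It runs the same model and input on CPU and on MPS and compares the output shapes and values. It assumes PyTorch 1.12 or newer, where `torch.backends.mps` exists. The thread does not name the model, so the small `nn.Sequential` network here is a hypothetical stand-in. Swap in the real model and input to see whether the MPS output diverges from the CPU reference.

```python
import torch
import torch.nn as nn


def pick_device() -> torch.device:
    # Prefer Apple's Metal (MPS) backend when this PyTorch build and the hardware support it.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def compare_cpu_vs_device(model: nn.Module, x: torch.Tensor, device: torch.device):
    """Run model on CPU and on `device`; return (cpu_shape, device_shape, values_match)."""
    model.eval()
    with torch.no_grad():
        ref = model.to("cpu")(x.to("cpu"))                 # CPU reference output
        out = model.to(device)(x.to(device)).to("cpu")     # same forward pass on the target device
    model.to("cpu")  # .to() moves the module in place, so move it back afterwards
    same_shape = ref.shape == out.shape
    # Only compare values when shapes agree; allclose would raise on mismatched shapes.
    close = same_shape and torch.allclose(ref, out, atol=1e-4, rtol=1e-4)
    return ref.shape, out.shape, close


# Hypothetical stand-in for the model under discussion.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 8))
x = torch.randn(4, 16)

if __name__ == "__main__":
    device = pick_device()
    cpu_shape, dev_shape, close = compare_cpu_vs_device(model, x, device)
    print(f"device={device} cpu_shape={tuple(cpu_shape)} "
          f"{device.type}_shape={tuple(dev_shape)} values_match={close}")
```

If the two shapes differ, or the values do not match, on an M1 machine, that output together with the PyTorch version would make a concrete bug report.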