Apple supports JAX[0] along with PyTorch[1] and TensorFlow[2] on macOS, for both Apple Silicon and AMD GPUs (on x86 Macs). That said, the perf isn't great. I write most of my experimental ML code in JAX on an M2 MacBook Air and then move to a proper multi-GPU Linux box for full training runs.
[0]: https://developer.apple.com/metal/jax/
[1]: https://developer.apple.com/metal/pytorch/
[2]: https://developer.apple.com/metal/tensorflow-plugin/
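Part of what makes the Mac-then-Linux workflow practical is that plain JAX code doesn't have to name a device. Here's a minimal sketch of that, assuming `jax` is installed with whatever backend the machine has (e.g. Apple's Metal plugin from [0] on the Mac, a CUDA-enabled JAX build on the Linux box, or plain CPU otherwise). The toy linear-regression model and the names `mse_loss` and `grad_fn` are just illustrative, and whether every op runs on Metal depends on the plugin's op coverage.

```python
import jax
import jax.numpy as jnp

# Report which backend JAX picked up on this machine. No device is chosen
# explicitly anywhere below; JAX uses its default backend.
print("backend:", jax.default_backend())
print("devices:", jax.devices())

def mse_loss(w, x, y):
    # Toy linear model (illustrative only): predictions = x @ w
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# Compile the gradient of the loss with respect to w.
grad_fn = jax.jit(jax.grad(mse_loss))

# Synthetic, noise-free data generated from known weights.
key = jax.random.PRNGKey(0)
kx, _ = jax.random.split(key)
x = jax.random.normal(kx, (128, 4))
true_w = jnp.array([1.0, -2.0, 0.5, 3.0])
y = x @ true_w

# Plain gradient descent; identical code on the laptop and the GPU box.
w = jnp.zeros(4)
for _ in range(200):
    w = w - 0.1 * grad_fn(w, x, y)

print("learned w:", w)
```

Once something like this works on the laptop, scaling up on the Linux box is mostly a matter of bigger data and models, plus explicit sharding across GPUs if you want it.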