PyTorch is a generationally important project. I've never seen a tool that is so in line with how researchers learn and internalize a subject. Teaching machine learning before and after its adoption has been a completely different experience. It can never be said enough how cool it is that Meta fosters and supports it.
This is exactly why I gravitated to it so quickly. The first time I looked at PyTorch code, it was immediately obvious what the abstractions meant and how to use them to write a model architecture.
JAX looks like something completely different to me. Maybe I’m dumb and probably not the target audience, but it occurs to me that very few people are. When I read about using JAX, I find recommendations for a handful of other libraries that make it more usable. Which of those I should learn is not obvious, because together they seem to form a very fragmented ecosystem with code that isn’t portable.
I’m still not sure why I’d spend my time learning JAX, especially when most of the author’s complaints don’t really separate training from inference, which don’t necessarily need to happen in the same framework.
Honestly, when I turn to JAX, I generally do it without a framework. To me, asking for one is like asking for a framework to wrap NumPy. Just JAX plus optax is sufficient in the cases where I turn to it.
Python isn't hated AFAICT, though people will profess to hate building large projects in it (myself included). Many of those same people also love it for shorter programs and scripts.
Additionally, nowadays it also has Java and C++ bindings to the same native libraries, so others can enjoy performance without having to rewrite their research afterwards.
Viva PyTorch! (JAX rocks too)