Yes, but is it really the same? AFAIK the `unroll=n` parameter turns `n` iterations into a vanilla `for` loop, which is then unrolled into sequential statements (in contrast to a JAX `fori_loop`). So, strictly speaking, there is still no loop on the accelerator?
I think this is up to XLA to handle, not JAX. The whole selling point of the `tf.function` decorator in TF (which can also use XLA underneath, via `jit_compile=True`) is that it fuses arithmetic to lower the kernel launch count.
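One way to check this rather than guess is to lower a small `lax.scan` with different `unroll` values and look at what gets handed to XLA. The following is only a minimal sketch: `step`, `cumsum_scan`, and the `xs` input are made-up names for illustration, and whether a `while` op survives in the lowered text may depend on the JAX version, so no expected output is claimed for that part.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # Hypothetical loop body: running sum over the inputs.
    new_carry = carry + x
    return new_carry, new_carry


def cumsum_scan(xs, unroll):
    # `unroll` controls how many scan iterations are emitted per loop trip.
    init = jnp.zeros((), dtype=xs.dtype)
    return lax.scan(step, init, xs, unroll=unroll)


xs = jnp.arange(8.0)

# Same numerical result regardless of the unroll factor.
final_rolled, ys_rolled = cumsum_scan(xs, 1)
final_unrolled, ys_unrolled = cumsum_scan(xs, 4)

# Lower (without running) and inspect the text that is passed on to XLA.
jitted = jax.jit(cumsum_scan, static_argnums=1)
for unroll in (1, 4):
    lowered_text = jitted.lower(xs, unroll).as_text()
    print(f"unroll={unroll}: contains a while op: {'while' in lowered_text}")
```

Comparing the lowered text for a few `unroll` values shows how much of the loop JAX itself flattens before XLA gets to decide anything about fusion.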