
There is an `unroll` parameter in `scan` that lets you control how many iterations of the loop are fused into a single kernel.
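A minimal sketch of the parameter in use, assuming a toy running-sum step function (`step` is hypothetical, chosen only for illustration). `unroll` changes how the loop is compiled, not what it computes, so both calls should produce the same outputs:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # Hypothetical step: keep a running sum and emit it at every iteration.
    new_carry = carry + x
    return new_carry, new_carry

xs = jnp.arange(10.0)
init = jnp.array(0.0)  # float32 scalar carry, matching the dtype of xs

# Default unroll=1: one copy of `step` per loop iteration.
final_1, ys_1 = jax.lax.scan(step, init, xs)

# unroll=4: XLA sees 4 consecutive copies of `step` per loop iteration,
# which gives it room to fuse work across those iterations.
final_4, ys_4 = jax.lax.scan(step, init, xs, unroll=4)

print(final_1, final_4)
```

Note that the length (10) does not need to be a multiple of `unroll`; JAX handles the leftover iterations itself.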


Yes, but is it really the same? AFAIK the `unroll=n` parameter turns `n` iterations into a vanilla `for` loop, which is then unrolled into sequential statements (in contrast to a JAX `fori_loop`). So, strictly speaking, there is still no loop on the accelerator?
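One way to check this empirically is to lower the program and look for a loop construct in the output. A rough sketch, assuming a recent JAX where `jax.jit(...).lower(...).as_text()` returns the StableHLO text of the un-optimized program (what XLA does with it afterwards can still differ per backend). `step` and `run` are hypothetical helpers:

```python
from functools import partial

import jax
import jax.numpy as jnp

def step(carry, x):
    # Hypothetical running-sum step.
    carry = carry + x
    return carry, carry

def run(xs, unroll):
    return jax.lax.scan(step, jnp.array(0.0), xs, unroll=unroll)

xs = jnp.arange(10.0)
has_while = {}
for unroll in (1, 4):
    # Lower only (no compilation) and inspect the emitted StableHLO text.
    text = jax.jit(partial(run, unroll=unroll)).lower(xs).as_text()
    has_while[unroll] = "while" in text
    print(f"unroll={unroll}: while loop in lowered program -> {has_while[unroll]}")
```

If the lowered text still contains a `while`, the unrolled statements live inside the loop body rather than replacing the loop entirely.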


I think this is up to XLA to handle, not JAX. The whole selling point of TF's `tf.function` decorator (which can also compile through XLA, e.g. with `jit_compile=True`) is that it fuses arithmetic ops to reduce the kernel launch count.
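The same idea can be seen in JAX: under `jax.jit`, the whole function is handed to XLA, which is free to fuse a chain of elementwise ops into one kernel. A sketch, assuming the ahead-of-time API (`lower(...).compile().as_text()`) to dump XLA's optimized HLO; `f` is a hypothetical example, and whether the word `fusion` shows up depends on the backend and XLA version:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Three elementwise ops (mul, add, tanh); run eagerly, JAX dispatches
    # them roughly one at a time.
    return jnp.tanh(x * 2.0 + 1.0)

x = jnp.linspace(-1.0, 1.0, 1024)

# Compile once, then inspect the HLO after XLA's optimization passes.
compiled = jax.jit(f).lower(x).compile()
hlo = compiled.as_text()
print("fusion" in hlo)

# The compiled object is callable like the original function.
y = compiled(x)
```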


There's a paper about it that I just found, enjoy: https://arxiv.org/pdf/2301.13062



