JAX lax.scan in RNN Iteration
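When implementing recurrent architectures such as Long Short-Term Memory (LSTM) networks in JAX, iterating over the sequence time steps with a standard Python for-loop makes the first Just-In-Time (JIT) compilation very slow. During tracing the loop is unrolled, so the compiled program contains one copy of the step computation per time step and grows with the sequence length. To avoid this bottleneck, JAX provides the jax.lax.scan control-flow primitive. It takes a step function f(carry, x) -> (new_carry, y), an initial carry state, and an input array. It applies f along the leading axis of the inputs, passing the carry from one step to the next, and returns the final carry together with the per-step outputs y stacked along a new leading axis. Because the step function is traced and compiled only once, compilation time does not grow with the number of time steps.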
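The following is a minimal sketch, not D2L's own code. It assumes a single-layer LSTM whose gate weights are packed into a hypothetical params tuple (W_x, W_h, b), with the gates ordered input, forget, output, candidate. The function names lstm_step, lstm_scan and lstm_loop are illustrative. The sketch scans the step over inputs of shape (num_steps, batch, num_inputs) and, for comparison, shows the equivalent Python loop that JIT unrolls.

```python
import jax
import jax.numpy as jnp

def lstm_step(params, carry, x_t):
    """One LSTM time step. carry = (H, C); x_t has shape (batch, num_inputs)."""
    H, C = carry
    W_x, W_h, b = params  # hypothetical packed weights for all four gates
    gates = x_t @ W_x + H @ W_h + b          # shape (batch, 4 * num_hiddens)
    i, f, o, g = jnp.split(gates, 4, axis=-1)
    I = jax.nn.sigmoid(i)                    # input gate
    F = jax.nn.sigmoid(f)                    # forget gate
    O = jax.nn.sigmoid(o)                    # output gate
    C_tilde = jnp.tanh(g)                    # candidate memory cell
    C = F * C + I * C_tilde                  # update the memory cell
    H = O * jnp.tanh(C)                      # new hidden state
    return (H, C), H                         # (new carry, per-step output)

@jax.jit
def lstm_scan(params, inputs, init_carry):
    # scan traces lstm_step once and iterates over axis 0 of inputs
    step = lambda carry, x_t: lstm_step(params, carry, x_t)
    final_carry, outputs = jax.lax.scan(step, init_carry, inputs)
    return outputs, final_carry              # outputs: (num_steps, batch, num_hiddens)

@jax.jit
def lstm_loop(params, inputs, init_carry):
    # Python loop: unrolled at trace time, one copy of the step per time step
    carry, outputs = init_carry, []
    for t in range(inputs.shape[0]):
        carry, h = lstm_step(params, carry, inputs[t])
        outputs.append(h)
    return jnp.stack(outputs), carry

# Toy sizes, chosen only for illustration
num_steps, batch, num_inputs, num_hiddens = 35, 2, 8, 16
kx, kh, kin = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    0.01 * jax.random.normal(kx, (num_inputs, 4 * num_hiddens)),
    0.01 * jax.random.normal(kh, (num_hiddens, 4 * num_hiddens)),
    jnp.zeros(4 * num_hiddens),
)
inputs = jax.random.normal(kin, (num_steps, batch, num_inputs))
init_carry = (jnp.zeros((batch, num_hiddens)), jnp.zeros((batch, num_hiddens)))

outputs, (H, C) = lstm_scan(params, inputs, init_carry)
print(outputs.shape)  # (35, 2, 16)
```

Both functions should return the same values up to floating-point rounding. The difference lies in the traced program: lstm_loop contains num_steps copies of the LSTM step, while lstm_scan contains a single scan over one copy. This is why the scan version's compile time stays flat as sequences get longer.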
Updated 2026-05-14
Dive into Deep Learning @ D2L