Concept

JAX lax.scan in RNN Iteration

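When implementing recurrent architectures such as Long Short-Term Memory (LSTM) networks in JAX, iterating over the time steps of a sequence with a standard Python for-loop leads to very long Just-In-Time (JIT) compilation times on the first run. The reason is that jax.jit traces the Python loop and unrolls it, so the compiled program contains a separate copy of the recurrent cell for every time step, and compilation time grows with the sequence length. To avoid this bottleneck, JAX provides the control-flow primitive jax.lax.scan. It takes a step function with the signature f(carry, x) -> (new_carry, y), an initial carry state, and an input array, and applies f along the leading axis of the inputs: the carry is threaded from one step to the next, and the per-step outputs y are stacked into a single array. scan returns both the final carry and these stacked outputs. Because the step function is traced once and compiled as a loop body, compilation cost does not grow with the number of time steps.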
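The contrast can be illustrated with a minimal sketch of a single-layer LSTM. It assumes that the four gates share one packed weight matrix, split in the order input, forget, output, candidate; the helper names init_lstm_params, lstm_step, lstm_scan and lstm_loop are hypothetical and chosen only for illustration. lstm_step has the (carry, x) -> (carry, y) shape that jax.lax.scan expects, with the carry being the pair (H, C).

```python
import jax
import jax.numpy as jnp

def init_lstm_params(key, num_inputs, num_hiddens, sigma=0.01):
    """Hypothetical helper: weights for all four gates packed into one matrix."""
    k1, k2 = jax.random.split(key)
    W_x = jax.random.normal(k1, (num_inputs, 4 * num_hiddens)) * sigma
    W_h = jax.random.normal(k2, (num_hiddens, 4 * num_hiddens)) * sigma
    b = jnp.zeros(4 * num_hiddens)
    return W_x, W_h, b

def lstm_step(params, carry, x_t):
    """One LSTM time step. carry = (H, C); x_t has shape (batch, num_inputs)."""
    W_x, W_h, b = params
    H, C = carry
    gates = x_t @ W_x + H @ W_h + b
    # Assumed packing order: input, forget, output, candidate memory.
    i, f, o, c_tilde = jnp.split(gates, 4, axis=-1)
    I, F, O = jax.nn.sigmoid(i), jax.nn.sigmoid(f), jax.nn.sigmoid(o)
    C = F * C + I * jnp.tanh(c_tilde)
    H = O * jnp.tanh(C)
    return (H, C), H  # (new carry, per-step output)

@jax.jit
def lstm_scan(params, inputs, H0, C0):
    """inputs: (num_steps, batch, num_inputs). lstm_step is traced only once."""
    step = lambda carry, x_t: lstm_step(params, carry, x_t)
    (H, C), outputs = jax.lax.scan(step, (H0, C0), inputs)
    return outputs, (H, C)  # outputs: (num_steps, batch, num_hiddens)

@jax.jit
def lstm_loop(params, inputs, H0, C0):
    """Same computation with a Python loop: unrolled into num_steps copies."""
    carry, outputs = (H0, C0), []
    for t in range(inputs.shape[0]):  # shape is static, so tracing unrolls this
        carry, y = lstm_step(params, carry, inputs[t])
        outputs.append(y)
    return jnp.stack(outputs), carry

num_steps, batch, num_inputs, num_hiddens = 35, 2, 8, 16
params = init_lstm_params(jax.random.PRNGKey(0), num_inputs, num_hiddens)
inputs = jax.random.normal(jax.random.PRNGKey(1), (num_steps, batch, num_inputs))
H0 = jnp.zeros((batch, num_hiddens))
C0 = jnp.zeros((batch, num_hiddens))
outputs, (H, C) = lstm_scan(params, inputs, H0, C0)
print(outputs.shape)  # (35, 2, 16)
```

Both functions compute the same values; the difference lies in what is compiled. Under jax.jit, lstm_loop records one copy of the cell per time step, so its trace grows with num_steps, whereas lstm_scan hands a single traced step function to jax.lax.scan, which compiles it as a loop body.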


Updated 2026-05-14


Tags

D2L

Dive into Deep Learning @ D2L