JAX Gradient Tracking and Transformation
Unlike many other deep learning frameworks, JAX does not store gradients on the parameter objects themselves, because the parameters and the network architecture are decoupled. Instead of tracking gradients on neural network parameters directly, the user expresses the computation as a pure Python function that takes the parameters as an argument. Applying the grad transformation to that function returns a new function that computes its derivatives.
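As an illustration of the idea above, here is a minimal sketch that applies `jax.grad` to a pure loss function. The linear model, the `loss` function, the parameter names `w` and `b`, and the toy data are all assumptions chosen for this example; they are not taken from the source.

```python
import jax
import jax.numpy as jnp

# Hypothetical pure function: the parameters are passed in explicitly,
# so nothing is stored on (or mutated in) the parameter objects.
def loss(params, x, y):
    pred = jnp.dot(x, params["w"]) + params["b"]
    return jnp.mean((pred - y) ** 2)

# Toy parameters and data, made up for illustration.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, 2.0])

# jax.grad returns a new function; by default it differentiates
# with respect to the first argument (here, params).
grad_fn = jax.grad(loss)
grads = grad_fn(params, x, y)

# grads has the same dictionary structure as params.
print(grads["w"], grads["b"])
```

The gradients come back as a separate value with the same structure as `params`, rather than being attached to the parameters, so an update step would combine `params` and `grads` explicitly.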
Updated 2026-05-08
Tags
D2L
Dive into Deep Learning @ D2L