Concept

JAX Gradient Tracking and Transformation

Unlike frameworks such as PyTorch, which accumulate gradients in an attribute on each parameter tensor, JAX does not store gradients inside parameter objects, because the parameters and the network architecture are decoupled. Instead of tracking gradients on parameters, JAX has the user express the computation (typically the loss) as a pure Python function of the parameters. It then applies the `grad` transformation (`jax.grad`), which returns a new function that computes the derivative with respect to the first argument by default.
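A minimal sketch of this pattern follows. The parameter dict `params`, the functions `linear_model` and `loss_fn`, the toy data, and the learning rate are illustrative assumptions, not part of the original text. The parameters live in a plain dict (a pytree) that is separate from the model function. `jax.grad` turns the loss function into a gradient function whose output has the same structure as `params`.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters: a plain dict (pytree), decoupled from the model code.
params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}

def linear_model(params, x):
    # Architecture as a pure function of params and inputs; no hidden state.
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    # Mean squared error between predictions and targets.
    preds = linear_model(params, x)
    return jnp.mean((preds - y) ** 2)

# jax.grad returns a new function computing d(loss)/d(params),
# differentiating with respect to argument 0 by default.
grad_fn = jax.grad(loss_fn)

# Toy data, chosen only for illustration.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])

loss = loss_fn(params, x, y)
grads = grad_fn(params, x, y)  # dict with the same keys/shapes as params

# Plain SGD step: build new parameters instead of mutating stored gradients.
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

Because gradients are returned as values rather than stored on the parameters, nothing needs to be zeroed between steps. The update is simply a new pytree computed from `params` and `grads`.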

Updated 2026-05-08

Tags

D2L

Dive into Deep Learning @ D2L