Detaching Computation in Deep Learning Frameworks
Modern deep learning libraries provide built-in functions to detach a variable from the computational graph, stopping the backward flow of gradients through it. In PyTorch and MXNet, a tensor's computational history is discarded with the .detach() method. In JAX, the same effect is achieved by wrapping an operation in jax.lax.stop_gradient(), and TensorFlow provides tf.stop_gradient() for the same purpose.
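A minimal PyTorch sketch of this behavior (the variable names x, y, u, z are illustrative, not from the source): detaching y yields a tensor u with the same values but no gradient history, so downstream gradients treat u as a constant.

```python
import torch

x = torch.arange(4.0, requires_grad=True)
y = x * x            # y depends on x in the graph
u = y.detach()       # u has y's values but no computational history
z = u * x            # u is treated as a constant here
z.sum().backward()

# Because u was detached, dz/dx = u rather than 3 * x**2
assert torch.equal(x.grad, u)
```

Had y not been detached, backpropagating through z = y * x = x**3 would instead give a gradient of 3 * x**2.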
Updated 2026-05-02
Tags
D2L
Dive into Deep Learning @ D2L