Multi-GPU Minibatch Training Implementation
The train_batch function implements data-parallel training for a single minibatch across multiple GPUs. The procedure follows four sequential stages:
- Data splitting: The minibatch of features and labels is divided across the available devices using split_batch.
- Per-GPU forward pass and loss: Each GPU independently computes the model's output and the loss on its local data shard. Each device sums its per-example losses rather than averaging them.
- Per-GPU backpropagation: Backpropagation is performed separately on each GPU to compute local gradients.
- Gradient synchronization and update: Inside a torch.no_grad() context, an allreduce operation sums the gradients of each parameter across all GPUs and broadcasts the result back to every GPU. Each GPU then updates its own copy of the model parameters with SGD, dividing the summed gradient by the full (unsplit) batch size X.shape[0]. Because the per-device losses are sums, this produces the same update a single GPU would compute on the whole minibatch.
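The stages rely on two helpers, split_batch and allreduce, that are not shown in this section and that operate on tensors living on different GPUs. As a minimal CPU-only sketch of their arithmetic, assuming NumPy arrays stand in for per-device tensors, the hypothetical versions below split a batch along its first axis and perform a sum-then-broadcast allreduce in place. Note that this split_batch takes a device count rather than a list of devices.

```python
import numpy as np

def split_batch(X, y, num_devices):
    """Hypothetical stand-in: cut a minibatch into num_devices shards along axis 0."""
    assert X.shape[0] == y.shape[0], "features and labels must share the batch size"
    return np.array_split(X, num_devices), np.array_split(y, num_devices)

def allreduce(data):
    """Hypothetical stand-in: sum every array into data[0], then copy that sum back
    into every other entry in place, so all 'devices' end up holding the total."""
    for i in range(1, len(data)):
        data[0][:] += data[i]
    for i in range(1, len(data)):
        data[i][:] = data[0]

# Usage: two simulated devices, each holding a local gradient for the same parameter.
grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
allreduce(grads)
print(grads)  # → [array([4., 6.]), array([4., 6.])]
```

Modifying the arrays in place mirrors how the section's allreduce writes into each parameter's existing .grad, so the subsequent SGD step on every device reads the synchronized values.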
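import torch
from d2l import torch as d2l

# split_batch, allreduce, lenet, and loss are assumed to be defined earlier (not shown here).
def train_batch(X, y, device_params, devices, lr):
    X_shards, y_shards = split_batch(X, y, devices)
    # The loss is computed separately on each GPU and summed over that GPU's shard
    ls = [loss(lenet(X_shard, device_W), y_shard).sum()
          for X_shard, y_shard, device_W in zip(
              X_shards, y_shards, device_params)]
    for l in ls:  # Backpropagation is performed separately on each GPU
        l.backward()
    # Sum each parameter's gradient across GPUs and broadcast the result to all GPUs
    with torch.no_grad():
        for i in range(len(device_params[0])):
            allreduce([device_params[c][i].grad for c in range(len(devices))])
    # Each GPU updates its own parameter copy, dividing by the full batch size
    for param in device_params:
        d2l.sgd(param, lr, X.shape[0])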
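Why divide by X.shape[0] rather than by the shard size? The following minimal NumPy sketch checks that the data-parallel step lands on the same parameters as a single-device SGD step over the whole minibatch. It assumes a linear model with a summed squared loss in place of LeNet and plain arrays in place of per-GPU parameter copies; shard_grad and simulated_train_batch are hypothetical names, not part of the section's code.

```python
import numpy as np

# Assumptions: linear model y_hat = X @ w, per-example loss 0.5 * (y_hat - y)**2,
# and NumPy arrays standing in for the per-GPU parameter copies.

def shard_grad(w, X_shard, y_shard):
    """Gradient of the *summed* squared loss over one shard (mirrors .sum() in train_batch)."""
    residual = X_shard @ w - y_shard
    return X_shard.T @ residual

def simulated_train_batch(X, y, replicas, lr):
    """One data-parallel SGD step over len(replicas) simulated devices."""
    X_shards = np.array_split(X, len(replicas))
    y_shards = np.array_split(y, len(replicas))
    # Forward pass, loss, and backprop: each replica computes a gradient on its own shard.
    grads = [shard_grad(w, Xs, ys) for w, Xs, ys in zip(replicas, X_shards, y_shards)]
    # Allreduce: every replica receives the sum of all local gradients.
    total = np.sum(grads, axis=0)
    # Update: each replica steps its own copy, dividing by the FULL batch size.
    return [w - lr * total / X.shape[0] for w in replicas]

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
y = rng.normal(size=8)
w0 = np.zeros(3)

replicas = simulated_train_batch(X, y, [w0.copy() for _ in range(4)], lr=0.1)
# Reference: a single-device step on the mean loss over the whole minibatch.
single = w0 - 0.1 * shard_grad(w0, X, y) / X.shape[0]
print(all(np.allclose(w, single) for w in replicas))  # → True
```

The check works because the shards partition the rows of X, so the summed shard gradients equal the full-batch summed gradient; dividing by the shard size instead would inflate the step by the number of devices. The sketch returns new arrays rather than updating in place as d2l.sgd does, but the arithmetic is the same.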
Because there are no cross-device dependencies within the computational graph for a single minibatch, the per-GPU computations execute in parallel automatically.
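The independence argument can be illustrated on a CPU. In this hypothetical sketch a thread pool stands in for the GPUs: each task reads only its own data shard and its own parameter copy, so the tasks can run concurrently, and the only point where their results meet is the final sum that plays the role of the allreduce. On actual GPUs the overlap comes largely from PyTorch launching CUDA kernels asynchronously, not from Python threads; the model and loss here are the same assumed linear model with summed squared loss.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def local_grad(args):
    """Summed squared-loss gradient on one shard; touches no other shard's data."""
    w, X_shard, y_shard = args
    return X_shard.T @ (X_shard @ w - y_shard)

rng = np.random.default_rng(1)
X, y = rng.normal(size=(12, 4)), rng.normal(size=12)
num_devices = 3
replicas = [np.zeros(4) for _ in range(num_devices)]
tasks = list(zip(replicas,
                 np.array_split(X, num_devices),
                 np.array_split(y, num_devices)))

# No task depends on another, so they can be dispatched concurrently.
with ThreadPoolExecutor(max_workers=num_devices) as pool:
    grads = list(pool.map(local_grad, tasks))

# The only cross-device step: combine the local gradients.
total = np.sum(grads, axis=0)
print(np.allclose(total, X.T @ (X @ np.zeros(4) - y)))  # → True
```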
Updated 2026-05-19
Tags
D2L
Dive into Deep Learning @ D2L