
Multi-GPU Minibatch Training Implementation

The train_batch function implements data-parallel training for a single minibatch across multiple GPUs. The procedure follows four sequential stages:

  1. Data splitting: The minibatch of features and labels is divided across the available devices using split_batch.
  2. Per-GPU forward pass and loss: Each GPU runs the forward pass on its local data shard and computes the loss. It then sums the per-example losses into a single scalar for that device.
  3. Per-GPU backpropagation: Backpropagation is performed separately on each GPU to compute local gradients.
  4. Gradient synchronization and update: Inside a torch.no_grad() context, an allreduce operation sums each parameter's gradients across all GPUs and broadcasts the result back to every GPU. Each GPU then updates its own copy of the model parameters with minibatch SGD (d2l.sgd). The update is normalized by the full (unsplit) batch size X.shape[0], not by the size of the local shard.
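A short derivation shows why stage 4 normalizes by the full batch size. The notation here is introduced only for this check and is not from the original code. The minibatch has B examples (X.shape[0]) split into disjoint shards S_1 to S_K, one per GPU. The loss on example i is ℓ_i(w), η is the learning rate (lr), and every GPU starts from the same parameters w:

$$
\begin{aligned}
g_k &= \sum_{i \in S_k} \nabla_{\mathbf{w}} \ell_i(\mathbf{w}), \qquad k = 1, \dots, K \\
g &= \sum_{k=1}^{K} g_k = \sum_{i=1}^{B} \nabla_{\mathbf{w}} \ell_i(\mathbf{w}) \\
\mathbf{w} &\leftarrow \mathbf{w} - \frac{\eta}{B}\, g = \mathbf{w} - \eta\, \nabla_{\mathbf{w}} \left( \frac{1}{B} \sum_{i=1}^{B} \ell_i(\mathbf{w}) \right)
\end{aligned}
$$

Dividing the allreduced sum g by B therefore reproduces single-GPU minibatch SGD on the mean loss. If each GPU divided by its own shard size B/K instead (assuming equal shards), every step would be K times too large.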
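```python
import torch
from d2l import torch as d2l

# split_batch, allreduce, lenet, and loss are not defined here;
# they must be provided elsewhere (they are helpers used by this training loop).
def train_batch(X, y, device_params, devices, lr):
    # Stage 1: split the minibatch of features and labels across the GPUs
    X_shards, y_shards = split_batch(X, y, devices)
    # Stage 2: forward pass and summed loss, computed separately on each GPU
    ls = [loss(lenet(X_shard, device_W), y_shard).sum()
          for X_shard, y_shard, device_W in zip(
              X_shards, y_shards, device_params)]
    # Stage 3: backpropagation is performed separately on each GPU
    for l in ls:
        l.backward()
    # Stage 4a: for each parameter, sum its gradients across GPUs and
    # broadcast the sum back to every GPU
    with torch.no_grad():
        for i in range(len(device_params[0])):
            allreduce([device_params[c][i].grad for c in range(len(devices))])
    # Stage 4b: each GPU updates its own parameter copy; d2l.sgd divides
    # by the full batch size X.shape[0], not by the shard size
    for param in device_params:
        d2l.sgd(param, lr, X.shape[0])
```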
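The following minimal pure-Python sketch runs the four stages end to end without GPUs or PyTorch. It is not the D2L implementation, and every name in it is a hypothetical stand-in. Each "device" is a list [w, b] that holds a replica of a one-dimensional linear model. local_grads computes the gradient of the summed squared loss on one shard. split_batch and allreduce are simplified list-based versions of the helpers used above. The example checks that two replicas trained this way end up identical to each other and to a single replica trained on the whole minibatch.

```python
# Hypothetical pure-Python stand-ins; not the PyTorch/D2L helpers.

def split_batch(X, y, num_devices):
    """Split X and y into num_devices contiguous shards whose sizes differ by at most 1."""
    base, extra = divmod(len(X), num_devices)
    X_shards, y_shards, start = [], [], 0
    for k in range(num_devices):
        size = base + (1 if k < extra else 0)
        X_shards.append(X[start:start + size])
        y_shards.append(y[start:start + size])
        start += size
    return X_shards, y_shards

def local_grads(params, X_shard, y_shard):
    """Gradient of the summed loss 0.5 * (w*x + b - t)**2 over one shard."""
    w, b = params
    gw = gb = 0.0
    for x, t in zip(X_shard, y_shard):
        err = w * x + b - t
        gw += err * x
        gb += err
    return [gw, gb]

def allreduce(grads_per_device):
    """Sum the gradient lists elementwise, then copy the sum into every device's list."""
    total = [sum(g[i] for g in grads_per_device)
             for i in range(len(grads_per_device[0]))]
    for g in grads_per_device:
        g[:] = total

def train_batch(X, y, device_params, lr):
    # Stage 1: split the minibatch, one shard per replica
    X_shards, y_shards = split_batch(X, y, len(device_params))
    # Stages 2-3: per-replica loss gradients on the local shard only
    grads = [local_grads(p, xs, ys)
             for p, xs, ys in zip(device_params, X_shards, y_shards)]
    # Stage 4a: sum gradients across replicas and broadcast the sum
    allreduce(grads)
    # Stage 4b: each replica takes an SGD step normalized by the FULL batch size
    for p, g in zip(device_params, grads):
        for i in range(len(p)):
            p[i] -= lr * g[i] / len(X)

X = [1.0, 2.0, 3.0, 4.0, 5.0]
y = [3.0, 5.0, 7.0, 9.0, 11.0]              # generated by y = 2x + 1
replicas = [[0.0, 0.0] for _ in range(2)]   # two "devices", identical initial [w, b]
single = [[0.0, 0.0]]                       # one "device" for comparison
train_batch(X, y, replicas, lr=0.1)
train_batch(X, y, single, lr=0.1)
print(replicas)  # → [[2.5, 0.7], [2.5, 0.7]]
print(single)    # → [[2.5, 0.7]]
```

The shards here have sizes 3 and 2. Both replicas still match the single-device update, because the allreduced sum equals the full-batch gradient sum and each step divides by len(X) = 5. This allreduce also simplifies the one used above: it reduces all of a device's gradients at once, whereas train_batch reduces one parameter at a time across devices.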

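The computational graph for a single minibatch has no dependencies across devices, so the per-GPU computations run in parallel automatically, without explicit threading code. PyTorch launches CUDA kernels asynchronously, which lets the Python loop issue work to the next GPU before the previous GPU has finished its share.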


Updated 2026-05-19


Dive into Deep Learning @ D2L