Generic Optimization Training Function
The train_ch11 function provides a generic, from-scratch training harness for benchmarking different optimization algorithms on a linear regression model. It takes five main arguments: a trainer function (trainer_fn), the optimizer states, a dictionary of hyperparameters, a data iterator, and the feature dimensionality. An optional sixth argument, num_epochs, sets the number of epochs. The function initializes the weights from a normal distribution with mean 0 and standard deviation 0.01 and sets the bias to 0. During each epoch, it iterates over minibatches, computes the squared loss averaged over the minibatch, backpropagates, and calls trainer_fn to update the parameters. It periodically evaluates the loss on the full dataset and records the cumulative training time. Because of this design, any optimizer with a compatible call signature can be benchmarked simply by passing a different trainer_fn. A companion function, train_concise_ch11, provides an equivalent harness that delegates to the framework's built-in optimizer APIs instead of custom update functions.
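The harness pattern can be illustrated with a small, framework-free sketch. It assumes plain NumPy instead of a deep learning framework, so the gradients of the halved squared loss are computed by hand and passed to the trainer explicitly. This differs from train_ch11, whose trainer_fn reads the gradients produced by autograd. All names here (train_sketch, data_iter, sgd, sgd_momentum, init_momentum_states) are hypothetical illustrations, not the D2L API, and the synthetic data and hyperparameter values are arbitrary choices for demonstration.

```python
import time
import numpy as np

# Hypothetical trainer functions. Unlike D2L's versions, they receive the
# gradients explicitly because plain NumPy has no autograd.
def sgd(params, grads, states, hyperparams):
    """Minibatch SGD; `states` is unused."""
    for p, g in zip(params, grads):
        p -= hyperparams['lr'] * g  # in-place update of the shared array

def init_momentum_states(feature_dim):
    """One zero-initialized velocity buffer per parameter (weights, bias)."""
    return [np.zeros((feature_dim, 1)), np.zeros(1)]

def sgd_momentum(params, grads, states, hyperparams):
    """SGD with heavy-ball momentum: v <- beta * v + g; p <- p - lr * v."""
    for p, g, v in zip(params, grads, states):
        v[:] = hyperparams['momentum'] * v + g
        p -= hyperparams['lr'] * v

def data_iter(X, y, batch_size, rng):
    """Yield shuffled minibatches (hypothetical helper)."""
    idx = rng.permutation(len(X))
    for i in range(0, len(X), batch_size):
        batch = idx[i:i + batch_size]
        yield X[batch], y[batch]

def train_sketch(trainer_fn, states, hyperparams, X, y, feature_dim,
                 num_epochs=2, batch_size=10, eval_every=200, seed=0):
    """Hypothetical NumPy analogue of the train_ch11 harness."""
    rng = np.random.default_rng(seed)
    w = rng.normal(0.0, 0.01, size=(feature_dim, 1))  # weights ~ N(0, 0.01^2)
    b = np.zeros(1)                                   # zero bias
    net = lambda A: A @ w + b
    full_loss = lambda: float(np.mean((net(X) - y) ** 2 / 2))
    n, elapsed, times, losses = 0, 0.0, [], []
    for _ in range(num_epochs):
        for Xb, yb in data_iter(X, y, batch_size, rng):
            start = time.perf_counter()
            err = net(Xb) - yb                 # residuals, shape (batch, 1)
            grad_w = Xb.T @ err / len(Xb)      # gradient of mean loss w.r.t. w
            grad_b = err.mean(axis=0)          # gradient of mean loss w.r.t. b
            trainer_fn([w, b], [grad_w, grad_b], states, hyperparams)
            elapsed += time.perf_counter() - start  # time only training steps
            n += len(Xb)
            if n % eval_every == 0:            # periodic full-dataset evaluation
                times.append(elapsed)
                losses.append(full_loss())
    return times, losses

# Synthetic linear-regression data (arbitrary illustrative values).
data_rng = np.random.default_rng(42)
X = data_rng.normal(size=(1000, 5))
true_w = np.array([[2.0], [-3.4], [1.0], [0.5], [-1.2]])
y = X @ true_w + 4.2 + data_rng.normal(0.0, 0.01, size=(1000, 1))

# Plugging in a different optimizer only means passing a different trainer_fn.
times, losses = train_sketch(sgd, None, {'lr': 0.05}, X, y, feature_dim=5)
times_m, losses_m = train_sketch(sgd_momentum, init_momentum_states(5),
                                 {'lr': 0.05, 'momentum': 0.9},
                                 X, y, feature_dim=5)
print(f"sgd:      loss {losses[-1]:.5f}, {times[-1]:.4f} s")
print(f"momentum: loss {losses_m[-1]:.5f}, {times_m[-1]:.4f} s")
```

Switching from sgd to sgd_momentum changes only the trainer function and its states; the harness itself is untouched. The returned cumulative times and losses can then be plotted against each other to compare optimizers on equal footing.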
Tags
D2L
Dive into Deep Learning @ D2L