Concise Optimization Training Function
The train_concise_ch11 function is the high-level-API counterpart of the from-scratch train_ch11 harness. Instead of manually managing parameter tensors and calling a custom optimizer function, it builds a single-layer linear model with nn.Sequential(nn.Linear(5, 1)), initializes its weights from a normal distribution with standard deviation 0.01, and instantiates a framework-provided optimizer (e.g., torch.optim.SGD) by passing net.parameters() together with a dictionary of hyperparameters. The loss is computed with nn.MSELoss(reduction='none'), which returns per-element squared errors without the 1/2 factor. In each epoch the function iterates over minibatches: it zeroes the gradients with optimizer.zero_grad(), runs a forward pass, reshapes the labels to match the output, computes the per-example loss, calls l.mean().backward() to obtain gradients of the averaged loss, and updates the parameters with optimizer.step(). Whenever the cumulative number of processed examples reaches a multiple of 200, the loss is evaluated on the full dataset and divided by 2 to align with the half-squared-error convention. The concise version follows the same training procedure as the from-scratch harness with substantially less boilerplate code. One caveat: because MSELoss omits the 1/2 factor, its gradients are twice those of the half-squared-error loss, so a given learning rate produces updates twice as large as in a harness that trains on half the squared error.
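The loss convention can be checked directly. The sketch below uses arbitrary, made-up tensors (not data from the chapter) to confirm that nn.MSELoss(reduction='none') returns plain per-element squared errors, and that halving their mean recovers the half-squared-error value that the plotted curve reports.

```python
import torch
from torch import nn

# Arbitrary illustrative predictions and targets, shaped like the model output (N, 1)
out = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[1.5], [2.0], [2.0]])

loss = nn.MSELoss(reduction='none')
per_example = loss(out, y)  # shape (3, 1): (out - y)^2 with no 1/2 factor

# Same values computed explicitly
manual = (out - y) ** 2
assert torch.allclose(per_example, manual)

# Half-squared-error convention: mean of (out - y)^2 / 2
half_sq = ((out - y) ** 2 / 2).mean()
assert torch.allclose(per_example.mean() / 2, half_sq)

print(per_example.squeeze().tolist(), per_example.mean().item() / 2)
```

Dividing by 2 after averaging, as the training function does when plotting, only rescales the reported number; it does not change the gradients, which are still computed from l.mean() without the 1/2.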
import torch
from torch import nn
from d2l import torch as d2l

def train_concise_ch11(trainer_fn, hyperparams, data_iter, num_epochs=4):
    # Single linear layer mapping 5 input features to 1 output
    net = nn.Sequential(nn.Linear(5, 1))

    def init_weights(module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, std=0.01)

    net.apply(init_weights)
    # Framework optimizer, e.g. torch.optim.SGD, built from the model parameters
    optimizer = trainer_fn(net.parameters(), **hyperparams)
    # Per-element squared error, without the 1/2 factor
    loss = nn.MSELoss(reduction='none')
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[0, num_epochs], ylim=[0.22, 0.35])
    n, timer = 0, d2l.Timer()
    for _ in range(num_epochs):
        for X, y in data_iter:
            optimizer.zero_grad()
            out = net(X)
            y = y.reshape(out.shape)
            l = loss(out, y)
            l.mean().backward()
            optimizer.step()
            n += X.shape[0]
            if n % 200 == 0:
                timer.stop()
                # Divide by 2 to report half the squared error,
                # matching the from-scratch harness
                animator.add(n / X.shape[0] / len(data_iter),
                             (d2l.evaluate_loss(net, data_iter, loss) / 2,))
                timer.start()
    print(f'loss: {animator.Y[0][-1]:.3f}, '
          f'{timer.sum() / num_epochs:.3f} sec/epoch')
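To see the loop body in isolation, here is a minimal, self-contained sketch under stated assumptions: it drops the d2l plotting, timing, and evaluation helpers, and replaces the chapter's dataset with hypothetical synthetic data (200 examples, 5 features). The helper names make_net, train_concise, and train_scratch are hypothetical. It trains once with torch.optim.SGD on the MSELoss objective and once with a hand-written update p <- p - lr * grad on half the squared error, using twice the learning rate; from the same initial weights and data order, the two runs should end with the same parameters up to floating-point error, which illustrates the factor-of-2 caveat.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical synthetic regression data: fixed linear target plus small noise
X_all = torch.randn(200, 5)
true_w = torch.tensor([[1.0], [-2.0], [0.5], [0.0], [3.0]])
y_all = X_all @ true_w + 0.01 * torch.randn(200, 1)
# shuffle=False keeps the minibatch order identical across both runs
data_iter = DataLoader(TensorDataset(X_all, y_all), batch_size=10, shuffle=False)

def make_net():
    torch.manual_seed(1)  # reseed so every call starts from the same weights
    net = nn.Sequential(nn.Linear(5, 1))
    nn.init.normal_(net[0].weight, std=0.01)
    return net

def train_concise(trainer_fn, hyperparams, num_epochs=2):
    """Same loop body as train_concise_ch11, without plotting or timing."""
    net = make_net()
    optimizer = trainer_fn(net.parameters(), **hyperparams)
    loss = nn.MSELoss(reduction='none')
    for _ in range(num_epochs):
        for X, y in data_iter:
            optimizer.zero_grad()
            out = net(X)
            l = loss(out, y.reshape(out.shape))
            l.mean().backward()
            optimizer.step()
    return net

def train_scratch(lr, num_epochs=2):
    """Manual minibatch SGD on half the squared error."""
    net = make_net()
    params = list(net.parameters())
    for _ in range(num_epochs):
        for X, y in data_iter:
            out = net(X)
            l = ((out - y.reshape(out.shape)) ** 2 / 2).mean()
            for p in params:
                p.grad = None  # clear gradients from the previous step
            l.backward()
            with torch.no_grad():
                for p in params:
                    p -= lr * p.grad
    return net

net_a = train_concise(torch.optim.SGD, {'lr': 0.05})
net_b = train_scratch(lr=0.1)  # doubled lr compensates for the 1/2 factor
for pa, pb in zip(net_a.parameters(), net_b.parameters()):
    assert torch.allclose(pa, pb, atol=1e-5)
print(net_a[0].weight.data.squeeze())
```

Because the optimizer is passed in as trainer_fn, swapping in another torch.optim class (for example torch.optim.Adam with {'lr': 0.01}) requires no change to the loop itself.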
Tags
D2L
Dive into Deep Learning @ D2L