Training Loop in CNN-Based Style Transfer
The training loop for CNN-based style transfer repeatedly updates the synthesized image to minimize the combined loss. Only the image pixels are trained; the weights of the pretrained CNN stay fixed. In each epoch, the loop extracts the content and style features of the current synthesized image, computes the total loss (a weighted sum of content, style, and total variation losses), and backpropagates that loss to the image pixels. The optimizer then updates the pixels, and a learning rate scheduler (StepLR) multiplies the learning rate by 0.8 every lr_decay_epoch epochs.
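Written out, the quantity being minimized and the learning rate schedule look roughly as follows. The notation is introduced here for illustration and is not from the original. $\mathbf{X}$ is the synthesized image. $\mathcal{C}$ and $\mathcal{S}$ are the sets of content and style layers (`content_layers` and `style_layers` in the code). $w_c$, $w_s$ and $w_{tv}$ are the loss weights applied inside `compute_loss`. $\eta_0$ is `lr`, $T_d$ is `lr_decay_epoch`, and $t$ counts calls to `scheduler.step()`, which happens once per epoch.

$$
\mathcal{L}(\mathbf{X}) = w_c \sum_{\ell \in \mathcal{C}} \mathcal{L}_{\text{content}}^{(\ell)}(\mathbf{X}) + w_s \sum_{\ell \in \mathcal{S}} \mathcal{L}_{\text{style}}^{(\ell)}(\mathbf{X}) + w_{tv}\, \mathcal{L}_{\text{TV}}(\mathbf{X}),
\qquad
\eta_t = \eta_0 \cdot 0.8^{\lfloor t / T_d \rfloor}
$$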
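```python
# PyTorch
import torch
from d2l import torch as d2l

# get_inits, extract_features, compute_loss, postprocess, content_layers and
# style_layers are assumed to be defined earlier in the style-transfer chapter.

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    # X becomes the trainable synthesized image; trainer optimizes its pixels
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    # Multiply the learning rate by 0.8 every lr_decay_epoch epochs
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(num_epochs):
        trainer.zero_grad()
        # Features of the current synthesized image at the chosen layers
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        # Backpropagate to the pixels of X and take one optimizer step
        l.backward()
        trainer.step()
        scheduler.step()
        # Every 10 epochs, show the current image and log the three losses
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(X))
            animator.add(epoch + 1, [float(sum(contents_l)),
                                     float(sum(styles_l)), float(tv_l)])
    return X
```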
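The following is a minimal, self-contained sketch of the same optimization pattern that runs without the chapter's helpers (`get_inits`, `extract_features`, `compute_loss`, `postprocess`) or the pretrained VGG network. It rests on several assumptions that are not from the original. A tiny, randomly initialized and frozen conv stack stands in for the pretrained feature extractor. Random 16×16 tensors stand in for the content and style images. The loss weights, learning rate, and decay step are illustrative values. The structure of `train` is preserved: the image is the only trainable tensor, Adam updates its pixels, and `StepLR` multiplies the learning rate by 0.8 every 50 epochs.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the pretrained CNN: small, random, frozen conv stack
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
)
for p in net.parameters():
    p.requires_grad_(False)

def gram(Y):
    # Gram matrix of a (1, c, h, w) feature map, normalized by its size
    c, n = Y.shape[1], Y.numel() // Y.shape[1]
    Y = Y.reshape(c, n)
    return Y @ Y.T / (c * n)

def tv_loss(X):
    # Total variation: mean absolute difference between neighboring pixels
    return 0.5 * ((X[:, :, 1:, :] - X[:, :, :-1, :]).abs().mean()
                  + (X[:, :, :, 1:] - X[:, :, :, :-1]).abs().mean())

# Hypothetical content and style images
content_img = torch.rand(1, 3, 16, 16)
style_img = torch.rand(1, 3, 16, 16)
with torch.no_grad():
    content_Y = net(content_img)
    style_gram = gram(net(style_img))

# The synthesized image, initialized from the content image, is the only parameter
X = nn.Parameter(content_img.clone())
trainer = torch.optim.Adam([X], lr=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(trainer, step_size=50, gamma=0.8)

w_c, w_s, w_tv = 1.0, 1e3, 10.0  # illustrative loss weights
history = []
for epoch in range(200):
    trainer.zero_grad()
    Y_hat = net(X)
    l = (w_c * ((Y_hat - content_Y) ** 2).mean()
         + w_s * ((gram(Y_hat) - style_gram) ** 2).mean()
         + w_tv * tv_loss(X))
    l.backward()      # gradients flow only into the pixels of X
    trainer.step()
    scheduler.step()  # one scheduler step per epoch, as in train
    history.append(l.item())
```

After 200 epochs with `step_size=50`, the scheduler has decayed the learning rate four times, to 0.05 × 0.8⁴ = 0.02048.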
Updated 2026-05-21
Tags
D2L
Dive into Deep Learning @ D2L