Code

Training Loop in CNN-Based Style Transfer

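The training loop for CNN-based style transfer repeatedly updates the synthesized image to minimize the combined loss. Unlike ordinary network training, the variable being optimized is the image itself, not the CNN weights. In each epoch, the loop extracts content and style features from the current synthesized image, computes the total loss (a weighted sum of the content, style, and total variation losses), and backpropagates that loss to the image pixels, which the optimizer then updates. A learning rate scheduler decays the learning rate as training proceeds: in the code below, StepLR multiplies it by 0.8 every lr_decay_epoch epochs.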
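The key point, that gradient descent acts on the pixels rather than on network weights, can be sketched without a CNN. The toy below is an illustration under loud assumptions, not the method used here: the "image" is a 1-D list of floats, the content term compares raw pixels instead of CNN features, the style term is omitted, the smoothness term is a squared-difference stand-in for total variation, and the names `total_loss`, `gradient`, `optimize_pixels` and both weights are hypothetical.

```python
def total_loss(x, content, content_weight=1.0, tv_weight=0.1):
    """Weighted sum of a content term and a smoothness (TV-like) term."""
    # Content term: squared distance between the pixels and the target.
    content_l = sum((xi - ci) ** 2 for xi, ci in zip(x, content))
    # Smoothness term: squared differences between neighboring pixels.
    tv_l = sum((x[i + 1] - x[i]) ** 2 for i in range(len(x) - 1))
    return content_weight * content_l + tv_weight * tv_l


def gradient(x, content, content_weight=1.0, tv_weight=0.1):
    """Analytic gradient of total_loss with respect to each pixel."""
    grad = [2 * content_weight * (xi - ci) for xi, ci in zip(x, content)]
    for i in range(len(x) - 1):
        diff = 2 * tv_weight * (x[i + 1] - x[i])
        grad[i + 1] += diff
        grad[i] -= diff
    return grad


def optimize_pixels(x, content, lr=0.1, num_epochs=100):
    """Plain gradient descent applied to the pixels themselves."""
    x = list(x)  # work on a copy so the caller's image is unchanged
    for _ in range(num_epochs):
        grad = gradient(x, content)
        x = [xi - lr * gi for xi, gi in zip(x, grad)]
    return x


start = [0.0, 0.0, 0.0, 0.0]
target = [1.0, 0.0, 1.0, 0.0]
result = optimize_pixels(start, target)
print(total_loss(start, target), total_loss(result, target))
```

In the real loop, automatic differentiation (`l.backward()`) replaces the hand-written `gradient`, and the optimizer replaces the manual update line.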

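# PyTorch
import torch
from d2l import torch as d2l

# get_inits, extract_features, compute_loss, postprocess, content_layers and
# style_layers are assumed to be defined earlier in the same chapter.
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    # Initialize the synthesized image, the style Gram matrices and the optimizer
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    # Multiply the learning rate by 0.8 every lr_decay_epoch epochs
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(num_epochs):
        trainer.zero_grad()
        # Features of the current synthesized image
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        # Backpropagate to the image pixels, then update them
        l.backward()
        trainer.step()
        scheduler.step()
        # Every 10 epochs, show the image and plot the three loss terms
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(X))
            animator.add(epoch + 1, [float(sum(contents_l)),
                                     float(sum(styles_l)), float(tv_l)])
    return X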
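To make the effect of the scheduler concrete, the step decay can be written as a closed-form function. This is a minimal sketch of the schedule that `StepLR` produces when `scheduler.step()` is called once per epoch, as in `train` above. The helper name `step_lr` and the values 0.3 and 50 are illustrative assumptions, not settings taken from this page.

```python
def step_lr(initial_lr, epoch, step_size, gamma=0.8):
    """Learning rate in effect during 0-based `epoch` under step decay."""
    # The rate is multiplied by gamma once for every completed block
    # of step_size epochs.
    return initial_lr * gamma ** (epoch // step_size)


# Illustrative values only: initial rate 0.3, decay every 50 epochs.
for epoch in (0, 49, 50, 100):
    print(epoch, step_lr(0.3, epoch, 50))
```

With these assumed values, the rate stays at 0.3 for epochs 0 through 49, drops to about 0.24 at epoch 50, and to about 0.192 at epoch 100.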


Updated 2026-05-21


Tags

D2L

Dive into Deep Learning @ D2L