Code Implementation of Learning a Convolution Kernel
A convolution kernel can be learned from data: compare the output of a convolutional layer with a target tensor and use gradient descent to update the kernel weights. Each of the implementations below (PyTorch, MXNet, JAX with Flax, and TensorFlow) creates a 2D convolutional layer with a randomly initialized kernel of shape (1, 2) and no bias. In each training iteration, we compute the squared error between the predicted output Y_hat and the target Y, sum it over all elements, take the gradient of that sum with respect to the kernel, and subtract the gradient scaled by the learning rate lr from the kernel weights.
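The update performed in each iteration can be written out explicitly. This is a sketch under the assumption that the layer computes a cross-correlation without padding and with stride 1; the symbols K (the kernel), k_h and k_w (its height and width, here 1 and 2), and \ell (the summed loss) are introduced here for illustration and use zero-based indices.

```latex
\begin{align}
\hat{Y}_{i,j} &= \sum_{a=0}^{k_h-1} \sum_{b=0}^{k_w-1} K_{a,b}\, X_{i+a,\,j+b} \\
\ell(K) &= \sum_{i,j} \bigl(\hat{Y}_{i,j} - Y_{i,j}\bigr)^2 \\
\frac{\partial \ell}{\partial K_{a,b}} &= 2 \sum_{i,j} \bigl(\hat{Y}_{i,j} - Y_{i,j}\bigr)\, X_{i+a,\,j+b} \\
K &\leftarrow K - \mathrm{lr} \cdot \frac{\partial \ell}{\partial K}
\end{align}
```

The frameworks obtain the third line by automatic differentiation, so none of the implementations needs to code it by hand.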
PyTorch Implementation:
from torch import nn

# X (6x8 input) and Y (6x7 target output) are assumed to be defined earlier

# Construct a two-dimensional convolutional layer with 1 output channel and a
# kernel of shape (1, 2). For the sake of simplicity, we ignore the bias here
conv2d = nn.LazyConv2d(1, kernel_size=(1, 2), bias=False)

# The two-dimensional convolutional layer uses four-dimensional input and
# output in the format of (example, channel, height, width), where the batch
# size (number of examples in the batch) and the number of channels are both 1
X = X.reshape((1, 1, 6, 8))
Y = Y.reshape((1, 1, 6, 7))
lr = 3e-2  # Learning rate

for i in range(10):
    Y_hat = conv2d(X)
    l = (Y_hat - Y) ** 2
    conv2d.zero_grad()
    l.sum().backward()
    # Update the kernel
    conv2d.weight.data[:] -= lr * conv2d.weight.grad
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {l.sum():.3f}')
MXNet Implementation:
from mxnet import autograd
from mxnet.gluon import nn

# X (6x8 input) and Y (6x7 target output) are assumed to be defined earlier

# Construct a two-dimensional convolutional layer with 1 output channel and a
# kernel of shape (1, 2). For the sake of simplicity, we ignore the bias here
conv2d = nn.Conv2D(1, kernel_size=(1, 2), use_bias=False)
conv2d.initialize()

# The two-dimensional convolutional layer uses four-dimensional input and
# output in the format of (example, channel, height, width), where the batch
# size (number of examples in the batch) and the number of channels are both 1
X = X.reshape(1, 1, 6, 8)
Y = Y.reshape(1, 1, 6, 7)
lr = 3e-2  # Learning rate

for i in range(10):
    with autograd.record():
        Y_hat = conv2d(X)
        l = (Y_hat - Y) ** 2
    l.backward()
    # Update the kernel
    conv2d.weight.data()[:] -= lr * conv2d.weight.grad()
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {float(l.sum()):.3f}')
JAX Implementation:
import jax
from flax import linen as nn
from d2l import jax as d2l  # Provides get_seed()

# X (6x8 input) and Y (6x7 target output) are assumed to be defined earlier

# Construct a two-dimensional convolutional layer with 1 output channel and a
# kernel of shape (1, 2). For the sake of simplicity, we ignore the bias here
conv2d = nn.Conv(1, kernel_size=(1, 2), use_bias=False, padding='VALID')

# The two-dimensional convolutional layer uses four-dimensional input and
# output in the format of (example, height, width, channel), where the batch
# size (number of examples in the batch) and the number of channels are both 1
X = X.reshape((1, 6, 8, 1))
Y = Y.reshape((1, 6, 7, 1))
lr = 3e-2  # Learning rate

params = conv2d.init(jax.random.PRNGKey(d2l.get_seed()), X)

def loss(params, X, Y):
    Y_hat = conv2d.apply(params, X)
    return ((Y_hat - Y) ** 2).sum()

for i in range(10):
    l, grads = jax.value_and_grad(loss)(params, X, Y)
    # Update the kernel. jax.tree_util.tree_map replaces the deprecated
    # jax.tree_map alias and applies the step to every parameter leaf
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {l:.3f}')
TensorFlow Implementation:
import tensorflow as tf

# X (6x8 input) and Y (6x7 target output) are assumed to be defined earlier

# Construct a two-dimensional convolutional layer with 1 output channel and a
# kernel of shape (1, 2). For the sake of simplicity, we ignore the bias here
conv2d = tf.keras.layers.Conv2D(1, (1, 2), use_bias=False)

# The two-dimensional convolutional layer uses four-dimensional input and
# output in the format of (example, height, width, channel), where the batch
# size (number of examples in the batch) and the number of channels are both 1
X = tf.reshape(X, (1, 6, 8, 1))
Y = tf.reshape(Y, (1, 6, 7, 1))
lr = 3e-2  # Learning rate

# Call the layer once so that its weights are created before the loop
Y_hat = conv2d(X)
for i in range(10):
    with tf.GradientTape(watch_accessed_variables=False) as g:
        g.watch(conv2d.weights[0])
        Y_hat = conv2d(X)
        l = (abs(Y_hat - Y)) ** 2
    # Update the kernel
    update = tf.multiply(lr, g.gradient(l, conv2d.weights[0]))
    weights = conv2d.get_weights()
    weights[0] = conv2d.weights[0] - update
    conv2d.set_weights(weights)
    if (i + 1) % 2 == 0:
        print(f'epoch {i + 1}, loss {tf.reduce_sum(l):.3f}')
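The reshape calls above differ only in where the channel axis goes, and the target is one column narrower than the input because a (1, 2) kernel without padding has one fewer valid position per row. The following framework-free sketch spells out that shape arithmetic. The helper names valid_conv_output_hw and tensor_shape are hypothetical, and stride 1 with no padding is assumed.

```python
def valid_conv_output_hw(input_hw, kernel_hw):
    """Output (height, width) of a stride-1 convolution without padding."""
    h, w = input_hw
    kh, kw = kernel_hw
    return (h - kh + 1, w - kw + 1)


def tensor_shape(hw, layout, batch=1, channels=1):
    """Four-dimensional shape for a given axis layout (hypothetical helper)."""
    h, w = hw
    if layout == "NCHW":  # (example, channel, height, width)
        return (batch, channels, h, w)
    if layout == "NHWC":  # (example, height, width, channel)
        return (batch, h, w, channels)
    raise ValueError(f"unknown layout: {layout}")


out_hw = valid_conv_output_hw((6, 8), (1, 2))
print(out_hw)                        # (6, 7)
print(tensor_shape((6, 8), "NCHW"))  # (1, 1, 6, 8)
print(tensor_shape(out_hw, "NHWC"))  # (1, 6, 7, 1)
```

The channels-first shapes match the PyTorch and MXNet reshapes, and the channels-last shapes match the JAX and TensorFlow ones.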
After 10 iterations, the error drops to a small value, and the learned kernel tensor is close to the target weights.
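To see this convergence without any framework, here is a minimal pure-Python sketch of the same loop with a hand-written gradient. The toy data are assumptions for illustration, not taken from the text above: a 6x8 image of ones whose four middle columns are zero, and a target produced by the kernel [1, -1]. The function names and the uniform initialization range are likewise hypothetical.

```python
import random


def corr2d_1x2(X, k):
    """Cross-correlate a 2D list X with a 1x2 kernel k (no padding, stride 1)."""
    return [[k[0] * row[j] + k[1] * row[j + 1] for j in range(len(row) - 1)]
            for row in X]


def learn_kernel(X, Y, lr=3e-2, num_steps=10, seed=0):
    """Gradient descent on the summed squared error; returns (kernel, losses)."""
    rng = random.Random(seed)
    # Assumed initialization: small random weights in [-0.5, 0.5]
    k = [rng.uniform(-0.5, 0.5), rng.uniform(-0.5, 0.5)]
    losses = []
    for _ in range(num_steps):
        Y_hat = corr2d_1x2(X, k)
        loss, grad = 0.0, [0.0, 0.0]
        for i, row in enumerate(X):
            for j in range(len(row) - 1):
                err = Y_hat[i][j] - Y[i][j]
                loss += err ** 2
                # d(err^2)/dk[b] = 2 * err * X[i][j + b]
                grad[0] += 2 * err * row[j]
                grad[1] += 2 * err * row[j + 1]
        losses.append(loss)  # loss measured before this step's update
        k = [k[0] - lr * grad[0], k[1] - lr * grad[1]]
    return k, losses


# Assumed toy data: ones with four zeroed middle columns, target kernel [1, -1]
X = [[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0] for _ in range(6)]
Y = corr2d_1x2(X, [1.0, -1.0])

learned, losses = learn_kernel(X, Y)
print("learned kernel:", [round(w, 3) for w in learned])
print("first loss:", round(losses[0], 3), "last loss:", round(losses[-1], 3))
```

On this data the loss is an exact quadratic in the two weights, so with lr = 3e-2 each step shrinks the error, and after 10 steps the learned weights should land within about 0.1 of [1, -1] for any initialization in the assumed range.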
Dive into Deep Learning @ D2L