Example of Patch Embedding Output Shape
To verify the output shape of a patch embedding, apply an instance of the embedding module to a dummy input tensor. For example, given images with a height and width of 96, a patch size of 16, a batch size of 4, and a hidden vector length of 512, the expected output shape is (4, (96/16)^2, 512), which simplifies to (4, 36, 512).
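# PyTorch
import torch
from d2l import torch as d2l

# PatchEmbedding is the patch embedding module under test, defined elsewhere.
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(img_size, patch_size, num_hiddens)
# Dummy input in channels-first layout: (batch, channels, height, width)
X = torch.zeros(batch_size, 3, img_size, img_size)
d2l.check_shape(patch_emb(X),
                (batch_size, (img_size//patch_size)**2, num_hiddens))

# JAX
from jax import numpy as jnp
from d2l import jax as d2l

# PatchEmbedding is the patch embedding module under test, defined elsewhere.
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(img_size, patch_size, num_hiddens)
# Dummy input in channels-last layout: (batch, height, width, channels)
X = jnp.zeros((batch_size, img_size, img_size, 3))
# init_with_output initializes the parameters and runs the forward pass
output, _ = patch_emb.init_with_output(d2l.get_key(), X)
d2l.check_shape(output,
                (batch_size, (img_size//patch_size)**2, num_hiddens))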
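The shape arithmetic behind this check can also be reproduced without any deep learning framework. The following is a minimal sketch in plain Python; the helper name `patch_embedding_shape` is hypothetical (it is not part of d2l), and it assumes square images split into non-overlapping square patches.

```python
def patch_embedding_shape(batch_size, img_size, patch_size, num_hiddens):
    """Return the expected (batch, num_patches, num_hiddens) output shape.

    Hypothetical helper for illustration only; assumes square images and
    non-overlapping square patches.
    """
    # Number of whole patches that fit along one side of the image.
    # Floor division mirrors the img_size//patch_size used in the check.
    patches_per_side = img_size // patch_size
    # Patches form a square grid, so the sequence length is the square.
    num_patches = patches_per_side ** 2
    return (batch_size, num_patches, num_hiddens)


shape = patch_embedding_shape(batch_size=4, img_size=96,
                              patch_size=16, num_hiddens=512)
print(shape)  # (4, 36, 512)
```

With the values used here, each side holds 96 // 16 = 6 patches, so every image becomes a sequence of 6 * 6 = 36 patch vectors of length 512.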
Updated 2026-05-15
Tags
D2L
Dive into Deep Learning @ D2L