Example of Patch Embedding Output Shape

To verify the output shape of a patch embedding, apply an instance of the embedding module to a dummy input tensor. For example, given images with a height and width of 96, a patch size of 16, a batch size of 4, and a hidden vector length of 512, the expected output shape is (4, (96 // 16)^2, 512), which simplifies to (4, 36, 512).

# PyTorch
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(img_size, patch_size, num_hiddens)
X = torch.zeros(batch_size, 3, img_size, img_size)
d2l.check_shape(patch_emb(X),
                (batch_size, (img_size//patch_size)**2, num_hiddens))

# JAX
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(img_size, patch_size, num_hiddens)
X = jnp.zeros((batch_size, img_size, img_size, 3))
output, _ = patch_emb.init_with_output(d2l.get_key(), X)
d2l.check_shape(output,
                (batch_size, (img_size//patch_size)**2, num_hiddens))
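
The shape arithmetic can also be checked without any deep learning framework. The following is a minimal sketch in plain Python, not part of the d2l library: the helper names `patch_embedding_output_shape` and `patch_corners` are hypothetical, and the sketch assumes square images, square non-overlapping patches (stride equal to the patch size), and that any leftover border pixels are dropped.

```python
def patch_embedding_output_shape(batch_size, img_size, patch_size, num_hiddens):
    """Return the expected (batch, num_patches, num_hiddens) shape.

    Hypothetical helper: assumes square images and square,
    non-overlapping patches.
    """
    patches_per_side = img_size // patch_size  # 96 // 16 = 6
    num_patches = patches_per_side ** 2        # 6 ** 2 = 36
    return (batch_size, num_patches, num_hiddens)


def patch_corners(img_size, patch_size):
    """List the top-left (row, col) corner of every full patch."""
    last_start = img_size - patch_size  # last offset where a full patch fits
    starts = range(0, last_start + 1, patch_size)
    return [(row, col) for row in starts for col in starts]


img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
shape = patch_embedding_output_shape(batch_size, img_size, patch_size, num_hiddens)
corners = patch_corners(img_size, patch_size)

print(shape)         # (4, 36, 512)
print(len(corners))  # 36
```

Enumerating the patch corners gives the same count as the closed-form expression (img_size // patch_size)^2, which is why the sequence dimension of the output is 36 for these settings.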

Updated 2026-05-15

Tags

D2L

Dive into Deep Learning @ D2L