Learn Before
Code
Implementation of Patch Embedding in Vision Transformers
In deep learning frameworks, patch embedding can be implemented as a neural network module. The core mechanism is a 2D convolution layer where both the kernel size and stride are set to the desired patch size. The output of the convolution is then flattened spatially and transposed to produce a sequence of patch representations.
# PyTorch
# NOTE(review): `nn` is presumably `torch.nn` here (LazyConv2d) -- confirm
# against the imports of the surrounding file.
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and project each to a vector.

    A single convolution whose kernel size and stride both equal the patch
    size visits every patch exactly once, so each output position is the
    linear projection of one patch.

    Args:
        img_size: Image height/width as an int (square) or a
            (height, width) list/tuple.
        patch_size: Patch height/width as an int (square) or a
            (height, width) list/tuple.
        num_hiddens: Number of output channels, i.e. the embedding size of
            each patch.

    Attributes:
        num_patches: Number of whole patches that fit into ``img_size``.
        conv: The patch-projection convolution; its input channel count is
            inferred lazily on the first forward pass.
    """

    def __init__(self, img_size=96, patch_size=16, num_hiddens=512):
        super().__init__()

        def _make_tuple(x):
            # Promote a scalar size to a (height, width) pair.
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x

        img_size, patch_size = _make_tuple(img_size), _make_tuple(patch_size)
        # Floor division: only patches that fit entirely inside the image.
        self.num_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        self.conv = nn.LazyConv2d(num_hiddens, kernel_size=patch_size,
                                  stride=patch_size)

    def forward(self, X):
        """Embed a batch of images.

        Flattens the two spatial axes of the convolution output into one
        patch axis, then swaps it with the channel axis.

        Returns:
            Tensor of shape (batch size, no. of patches, no. of channels).
        """
        return self.conv(X).flatten(2).transpose(1, 2)


# JAX
# NOTE(review): `nn` is presumably `flax.linen` here (setup()/nn.Conv with
# `strides`) -- confirm against the imports of the surrounding file.
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and project each to a vector.

    Same idea as the PyTorch variant: a convolution whose kernel size and
    strides both equal the patch size.

    Attributes:
        img_size: Image height/width; a scalar is treated as a square, a
            list/tuple is used as (height, width).
        patch_size: Patch height/width; same convention as ``img_size``.
        num_hiddens: Number of output features per patch.
    """
    img_size: int = 96
    patch_size: int = 16
    num_hiddens: int = 512

    def setup(self):
        """Compute the patch count and create the projection convolution."""

        def _make_tuple(x):
            # Promote a scalar size to a (height, width) pair.
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x

        img_size, patch_size = (_make_tuple(self.img_size),
                                _make_tuple(self.patch_size))
        # Floor division: only patches that fit entirely inside the image.
        self.num_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        # 'VALID' (no padding) keeps the number of output positions equal to
        # the floor-divided `num_patches` above and matches the unpadded
        # PyTorch convolution. 'SAME' would pad and emit ceil(size / patch)
        # positions per axis whenever the image size is not a multiple of
        # the patch size. For evenly divisible sizes both give the same result.
        self.conv = nn.Conv(self.num_hiddens, kernel_size=patch_size,
                            strides=patch_size, padding='VALID')

    def __call__(self, X):
        """Embed a batch of images.

        The convolution output is indexed as a 4-D array whose last axis is
        the channel axis; all axes between batch and channel are merged into
        a single patch axis.

        Returns:
            Array of shape (batch size, no. of patches, no. of channels).
        """
        X = self.conv(X)
        return X.reshape((X.shape[0], -1, X.shape[3]))
0
1
Updated 2026-05-15
Tags
D2L
Dive into Deep Learning @ D2L