Implementation of the Vision Transformer Encoder Block
The vision Transformer encoder block combines layer normalization, multi-head attention, and a multilayer perceptron (MLP) tailored to the vision Transformer (ViTMLP). Following the pre-normalization design, the input tensor is normalized before it enters the attention mechanism, and the attention output is added to the original, unnormalized input through a residual connection. This intermediate result is normalized again, passed through the MLP, and added back through a second residual connection. Because each sublayer maps its input back to num_hiddens features before the residual addition, the block's output has the same shape as its input.
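In symbols, the data flow described above can be written as the two residual updates below. The names LN, MHA, and MLP are shorthand introduced here for layer normalization, multi-head self-attention, and the MLP; X is the block input and Z the block output.

```latex
\begin{aligned}
\mathbf{Y} &= \mathbf{X} + \mathrm{MHA}\big(\mathrm{LN}_1(\mathbf{X})\big), \\
\mathbf{Z} &= \mathbf{Y} + \mathrm{MLP}\big(\mathrm{LN}_2(\mathbf{Y})\big).
\end{aligned}
```

Both additions require the sublayer output to match the shape of its input, which is why X, Y, and Z all share the same shape.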
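# PyTorch
from torch import nn
from d2l import torch as d2l

# ViTMLP (the vision Transformer MLP) is assumed to be defined beforehand.
class ViTBlock(nn.Module):
    def __init__(self, num_hiddens, norm_shape, mlp_num_hiddens,
                 num_heads, dropout, use_bias=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(norm_shape)
        self.attention = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                dropout, use_bias)
        self.ln2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)

    def forward(self, X, valid_lens=None):
        # Pre-normalization: normalize, self-attend (queries, keys, and
        # values are the same tensor), then add the residual.
        X = X + self.attention(*([self.ln1(X)] * 3), valid_lens)
        # Normalize again, apply the MLP, then add the second residual.
        return X + self.mlp(self.ln2(X))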
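# JAX
from flax import linen as nn
from d2l import jax as d2l

# ViTMLP (the vision Transformer MLP) is assumed to be defined beforehand.
class ViTBlock(nn.Module):
    num_hiddens: int
    mlp_num_hiddens: int
    num_heads: int
    dropout: float
    use_bias: bool = False

    def setup(self):
        self.attention = d2l.MultiHeadAttention(self.num_hiddens, self.num_heads,
                                                self.dropout, self.use_bias)
        self.mlp = ViTMLP(self.mlp_num_hiddens, self.num_hiddens, self.dropout)

    @nn.compact
    def __call__(self, X, valid_lens=None, training=False):
        # Pre-normalization: normalize, self-attend, then add the residual.
        # [0] selects the first element of the value returned by attention.
        X = X + self.attention(*([nn.LayerNorm()(X)] * 3),
                               valid_lens, training=training)[0]
        # Normalize again, apply the MLP, then add the second residual.
        return X + self.mlp(nn.LayerNorm()(X), training=training)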
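The block can also be sketched without the d2l helpers, which makes it easy to check that the input shape is preserved. The version below is a minimal illustration, not the D2L implementation: PreNormBlock is a hypothetical name, torch.nn.MultiheadAttention stands in for d2l.MultiHeadAttention, and the small nn.Sequential stands in for ViTMLP under the assumption that it is a two-layer MLP with GELU and dropout. The sizes in the usage lines are arbitrary.

```python
import torch
from torch import nn


class PreNormBlock(nn.Module):
    """Pre-normalization encoder block (illustrative stand-in for ViTBlock)."""

    def __init__(self, num_hiddens, mlp_num_hiddens, num_heads, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(num_hiddens)
        # Assumption: PyTorch's built-in attention replaces d2l.MultiHeadAttention.
        # batch_first=True means inputs are (batch, sequence, features).
        self.attention = nn.MultiheadAttention(
            num_hiddens, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(num_hiddens)
        # Assumption: a two-layer MLP with GELU replaces ViTMLP.
        # It expands to mlp_num_hiddens and projects back to num_hiddens.
        self.mlp = nn.Sequential(
            nn.Linear(num_hiddens, mlp_num_hiddens),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_num_hiddens, num_hiddens),
            nn.Dropout(dropout),
        )

    def forward(self, X):
        # First sublayer: normalize, self-attend, add the residual.
        normed = self.ln1(X)
        attn_out, _ = self.attention(normed, normed, normed, need_weights=False)
        X = X + attn_out
        # Second sublayer: normalize, apply the MLP, add the residual.
        return X + self.mlp(self.ln2(X))


# Usage: a batch of 2 sequences, each with 100 tokens of 24 features.
block = PreNormBlock(num_hiddens=24, mlp_num_hiddens=48, num_heads=8)
block.eval()  # disable dropout for a deterministic forward pass
X = torch.ones((2, 100, 24))
Y = block(X)
print(Y.shape)  # torch.Size([2, 100, 24])
```

Since the output shape equals the input shape, several such blocks can be applied one after another without any reshaping in between.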
Updated 2026-05-15
Tags
D2L
Dive into Deep Learning @ D2L