Implementation of MLP in Vision Transformers
The multilayer perceptron (MLP) in a vision Transformer encoder block can be implemented with standard neural network modules in PyTorch or in JAX (here via Flax). It consists of two dense layers. The first applies a linear transformation to mlp_num_hiddens units, followed by a Gaussian Error Linear Unit (GELU) activation and a dropout layer for regularization. The second projects the features to the output dimension mlp_num_outputs and is followed by another dropout layer. Both dropout layers share one rate, which defaults to 0.5.
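To make the activation concrete, here is a minimal standard-library sketch of GELU, defined as x times the standard normal CDF, next to its common tanh approximation. The function names gelu_exact and gelu_tanh are illustrative and not part of any framework. To my knowledge, PyTorch's nn.GELU defaults to the exact form and Flax's nn.gelu to the tanh approximation, so the two implementations may differ slightly in their outputs; check this against the installed versions.

```python
import math


def gelu_exact(x: float) -> float:
    """GELU(x) = x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Widely used tanh approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Compare the two forms on a few sample inputs.
for v in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={v:+.1f}  exact={gelu_exact(v):+.4f}  tanh={gelu_tanh(v):+.4f}")
```

Unlike ReLU, GELU is smooth and lets small negative inputs pass through with a reduced weight instead of zeroing them.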
# PyTorch
from torch import nn

class ViTMLP(nn.Module):
    def __init__(self, mlp_num_hiddens, mlp_num_outputs, dropout=0.5):
        super().__init__()
        # LazyLinear infers the input feature size on the first forward pass
        self.dense1 = nn.LazyLinear(mlp_num_hiddens)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.dense2 = nn.LazyLinear(mlp_num_outputs)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # dense1 -> GELU -> dropout1 -> dense2 -> dropout2
        return self.dropout2(self.dense2(self.dropout1(self.gelu(
            self.dense1(x)))))
# JAX (Flax)
from flax import linen as nn

class ViTMLP(nn.Module):
    mlp_num_hiddens: int
    mlp_num_outputs: int
    dropout: float = 0.5

    @nn.compact
    def __call__(self, x, training=False):
        x = nn.Dense(self.mlp_num_hiddens)(x)
        x = nn.gelu(x)
        # Dropout is active only when training=True
        x = nn.Dropout(self.dropout, deterministic=not training)(x)
        x = nn.Dense(self.mlp_num_outputs)(x)
        x = nn.Dropout(self.dropout, deterministic=not training)(x)
        return x
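Both versions switch dropout off at inference time: PyTorch through the module's train/eval mode, and the Flax version through the training flag passed to deterministic. The following plain-Python sketch illustrates the inverted-dropout behavior that both frameworks implement, as far as I am aware. The dropout helper, its arguments, and the list-based input are hypothetical and chosen only for illustration; they are not a library API.

```python
import random


def dropout(values, rate, training, rng):
    """Inverted dropout on a list of floats (illustrative helper only)."""
    if not training or rate == 0.0:
        return list(values)  # inference: identity mapping
    keep = 1.0 - rate
    # Zero each element with probability `rate` and rescale survivors by
    # 1/keep so the expected value matches the inference-time output.
    return [v / keep if rng.random() < keep else 0.0 for v in values]


rng = random.Random(0)
x = [1.0] * 10_000
train_out = dropout(x, rate=0.5, training=True, rng=rng)
eval_out = dropout(x, rate=0.5, training=False, rng=rng)

print(sum(train_out) / len(train_out))  # close to 1.0 on average
print(eval_out == x)
```

Because the surviving activations are rescaled during training, no extra scaling is needed when the layer is disabled for evaluation.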
Updated 2026-05-15
Tags
D2L
Dive into Deep Learning @ D2L