Code

Implementation of the Vision Transformer

A complete vision Transformer can be implemented by assembling patch embeddings, learnable positional embeddings, a stack of encoder blocks, and a classification head. The implementation involves creating a sequential forward pass: extracting patch embeddings, concatenating a learnable <cls> token, adding learnable positional embeddings, applying dropout, passing the sequence through the stacked blocks, and finally using the <cls> token's output state for classification.

# PyTorch
class ViT(d2l.Classifier):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> prepend a learnable <cls> token -> add
    learnable positional embeddings -> dropout -> stacked ``ViTBlock``
    encoder blocks -> LayerNorm + Linear head applied to the <cls> position.

    Args:
        img_size: Forwarded to ``PatchEmbedding``.
        patch_size: Forwarded to ``PatchEmbedding``.
        num_hiddens: Embedding width; used for the <cls> token, the
            positional embeddings, every block, and the head's input size.
        mlp_num_hiddens: Forwarded to each ``ViTBlock``.
        num_heads: Forwarded to each ``ViTBlock``.
        num_blks: Number of ``ViTBlock`` instances to stack.
        emb_dropout: Dropout probability applied to the embedded sequence.
        blk_dropout: Forwarded to each ``ViTBlock``.
        lr: Not read in this class body; presumably consumed by the base
            class after ``save_hyperparameters`` -- TODO confirm.
        use_bias: Forwarded to each ``ViTBlock``.
        num_classes: Output size of the classification head.
    """

    def __init__(self, img_size, patch_size, num_hiddens, mlp_num_hiddens,
                 num_heads, num_blks, emb_dropout, blk_dropout, lr=0.1,
                 use_bias=False, num_classes=10):
        super().__init__()
        # Kept as the first statement after super().__init__(), before any
        # other local is bound, exactly as in the original.
        # NOTE(review): presumably records the constructor arguments on the
        # instance -- defined by d2l, verify there.
        self.save_hyperparameters()

        self.patch_embedding = PatchEmbedding(img_size, patch_size,
                                              num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))

        # One position per patch plus one for the <cls> token.
        seq_len = self.patch_embedding.num_patches + 1
        # Positional embeddings are learned rather than fixed.
        self.pos_embedding = nn.Parameter(
            torch.randn(1, seq_len, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)

        # ViTBlock receives num_hiddens twice, as in the original call.
        # Positional construction names the children "0", "1", ...
        encoder_blocks = [
            ViTBlock(num_hiddens, num_hiddens, mlp_num_hiddens, num_heads,
                     blk_dropout, use_bias)
            for _ in range(num_blks)
        ]
        self.blks = nn.Sequential(*encoder_blocks)

        self.head = nn.Sequential(
            nn.LayerNorm(num_hiddens),
            nn.Linear(num_hiddens, num_classes),
        )

    def forward(self, X):
        """Return the classification head's output for the <cls> position.

        Args:
            X: Input batch accepted by ``PatchEmbedding``.

        Returns:
            ``self.head`` applied to position 0 (the <cls> slot) of the
            sequence produced by the encoder blocks.
        """
        patches = self.patch_embedding(X)
        batch_size = patches.shape[0]
        # Broadcast the single learnable <cls> token across the batch and
        # place it at the front of every sequence.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, patches), 1)
        tokens = self.dropout(tokens + self.pos_embedding)
        # nn.Sequential applies the blocks in registration order.
        encoded = self.blks(tokens)
        return self.head(encoded[:, 0])
# JAX
class ViT(d2l.Classifier):
    """Vision Transformer classifier (Flax).

    Pipeline: patch embedding -> prepend a learnable <cls> token -> add
    learnable positional embeddings -> dropout -> stacked ``ViTBlock``
    encoder blocks -> LayerNorm + Dense head applied to the <cls> position.
    """

    img_size: int  # forwarded to PatchEmbedding
    patch_size: int  # forwarded to PatchEmbedding
    num_hiddens: int  # embedding width (cls token, positions, blocks)
    mlp_num_hiddens: int  # forwarded to each ViTBlock
    num_heads: int  # forwarded to each ViTBlock
    num_blks: int  # number of stacked ViTBlock instances
    emb_dropout: float  # dropout rate applied to the embedded sequence
    blk_dropout: float  # forwarded to each ViTBlock
    lr: float = 0.1  # not read in this class body
    use_bias: bool = False  # forwarded to each ViTBlock
    num_classes: int = 10  # output size of the classification head
    training: bool = False  # dropout is active only when this is True

    def setup(self):
        """Create the submodules and learnable parameters."""
        self.patch_embedding = PatchEmbedding(self.img_size, self.patch_size,
                                              self.num_hiddens)
        self.cls_token = self.param('cls_token', nn.initializers.zeros,
                                    (1, 1, self.num_hiddens))
        # One position per patch plus one for the <cls> token.
        num_steps = self.patch_embedding.num_patches + 1
        # Positional embeddings are learnable.
        self.pos_embedding = self.param('pos_embed',
                                        nn.initializers.normal(),
                                        (1, num_steps, self.num_hiddens))
        self.blks = [ViTBlock(self.num_hiddens, self.mlp_num_hiddens,
                              self.num_heads, self.blk_dropout,
                              self.use_bias)
                     for _ in range(self.num_blks)]
        self.head = nn.Sequential([nn.LayerNorm(),
                                   nn.Dense(self.num_classes)])

    # @nn.compact is kept because the Dropout module is created inline.
    @nn.compact
    def __call__(self, X):
        """Return the classification head's output for the <cls> position.

        Args:
            X: Input batch accepted by ``PatchEmbedding``.

        Returns:
            ``self.head`` applied to position 0 (the <cls> slot) of the
            sequence produced by the encoder blocks.
        """
        X = self.patch_embedding(X)
        # Repeat the single learnable <cls> token across the batch and place
        # it at the front of every sequence.
        X = jnp.concatenate(
            (jnp.tile(self.cls_token, (X.shape[0], 1, 1)), X), 1)
        # Use the module's own field: a bare ``emb_dropout`` is not defined
        # in this scope and would resolve to a module-level global (or raise
        # NameError), ignoring the rate this instance was configured with.
        X = nn.Dropout(self.emb_dropout,
                       deterministic=not self.training)(
                           X + self.pos_embedding)
        for blk in self.blks:
            X = blk(X, training=self.training)
        return self.head(X[:, 0])

0

1

Updated 2026-05-15

Contributors are:

Who are from:

Tags

D2L

Dive into Deep Learning @ D2L