
Multi-Head Attention Implementation

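The MultiHeadAttention class implements multi-head attention by running all attention heads in parallel. Its constructor creates linear projections for the queries, keys, and values (W_q, W_k, W_v) and for the final output (W_o). The forward method relies on two transposition helpers to rearrange tensor shapes. transpose_qkv splits the last dimension of a projected tensor into num_heads slices and folds the head axis into the batch axis. A tensor of shape (batch_size, number of queries or key-value pairs, num_hiddens) becomes one of shape (batch_size * num_heads, number of queries or key-value pairs, num_hiddens / num_heads). The DotProductAttention module can then treat every head as an independent batch example and process all of them in a single call. For the same reason, valid_lens is repeated num_heads times along its first axis. After attention pooling, transpose_output reverses this rearrangement. It concatenates the heads back into a tensor of shape (batch_size, number of queries, num_hiddens) before the final projection W_o. Because of the reshape, num_hiddens must be divisible by num_heads.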
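The shape bookkeeping is easiest to see on a small example. Below is a minimal NumPy sketch of the two helpers, rewritten as standalone functions that take num_heads as an argument. This is an assumption for illustration: the class methods operate on PyTorch tensors and read num_heads from self, and NumPy's transpose stands in for torch's permute here. With batch_size=2, 4 positions, num_hiddens=100, and 5 heads, each head receives a 20-dimensional slice of the features.

```python
import numpy as np

# Assumption: standalone NumPy stand-ins for the PyTorch methods
# transpose_qkv / transpose_output (np.transpose plays the role of permute).

def transpose_qkv(X, num_heads):
    """(batch_size, n, num_hiddens) -> (batch_size * num_heads, n, num_hiddens // num_heads)."""
    batch_size, n, num_hiddens = X.shape
    assert num_hiddens % num_heads == 0, "num_hiddens must be divisible by num_heads"
    X = X.reshape(batch_size, n, num_heads, -1)    # split features into heads
    X = X.transpose(0, 2, 1, 3)                    # (batch_size, num_heads, n, head_dim)
    return X.reshape(-1, X.shape[2], X.shape[3])   # merge batch and head axes

def transpose_output(X, num_heads):
    """Undo transpose_qkv: (batch_size * num_heads, n, head_dim) -> (batch_size, n, num_hiddens)."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])  # (batch_size, num_heads, n, head_dim)
    X = X.transpose(0, 2, 1, 3)                           # (batch_size, n, num_heads, head_dim)
    return X.reshape(X.shape[0], X.shape[1], -1)          # concatenate heads

batch_size, n, num_hiddens, num_heads = 2, 4, 100, 5
X = np.random.default_rng(0).normal(size=(batch_size, n, num_hiddens))

Y = transpose_qkv(X, num_heads)
print(Y.shape)  # (10, 4, 20)
# Row b * num_heads + h holds head h of example b, i.e. features [20h, 20h + 20)
print(np.array_equal(Y[1 * num_heads + 3], X[1, :, 60:80]))  # True
# The round trip recovers the input exactly
print(np.array_equal(transpose_output(Y, num_heads), X))  # True
```

Row b * num_heads + h of the folded tensor holds head h of example b. This is the same ordering that torch.repeat_interleave produces for valid_lens in forward, so each head is masked with its own example's valid length.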

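```python
import torch
from torch import nn
from d2l import torch as d2l


class MultiHeadAttention(d2l.Module):
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        # Lazy layers infer their input size on the first forward pass
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Inputs: (batch_size, no. of queries or key-value pairs, num_hiddens)
        # After transpose_qkv: (batch_size * num_heads,
        #   no. of queries or key-value pairs, num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))
        if valid_lens is not None:
            # Repeat each example's valid length num_heads times so that
            # every head of that example uses the same mask
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)
        # output: (batch_size * num_heads, no. of queries, num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)


@d2l.add_to_class(MultiHeadAttention)
def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    # n = no. of queries or key-value pairs
    # (batch_size, n, num_hiddens) -> (batch_size, n, num_heads, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
    # -> (batch_size, num_heads, n, num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)
    # -> (batch_size * num_heads, n, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


@d2l.add_to_class(MultiHeadAttention)
def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
    # (batch_size * num_heads, n, head_dim) -> (batch_size, num_heads, n, head_dim)
    X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
    # -> (batch_size, n, num_heads, head_dim)
    X = X.permute(0, 2, 1, 3)
    # -> (batch_size, n, num_hiddens): heads concatenated along the last axis
    return X.reshape(X.shape[0], X.shape[1], -1)
```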
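As a sanity check on the design, the following NumPy sketch computes multi-head attention two ways. The first folds the heads into the batch axis, as the class does. The second loops over the heads explicitly and concatenates their outputs. The sketch is a stand-in for the PyTorch class and rests on several assumptions: random weight matrices replace the nn.LazyLinear layers, there is no bias or dropout, and masking sets padded key positions to a large negative score before the softmax. All function names (masked_softmax, dot_product_attention, multi_head_attention, multi_head_attention_loop) are hypothetical. The two results should agree up to floating-point rounding.

```python
import numpy as np

# Assumption: NumPy stand-in for the PyTorch module; random weights replace
# nn.LazyLinear, and bias and dropout are omitted.

def masked_softmax(scores, valid_lens):
    """Softmax over the last axis; key positions >= valid_lens get ~zero weight."""
    if valid_lens is not None:
        key_pos = np.arange(scores.shape[-1])
        mask = key_pos[None, None, :] < valid_lens[:, None, None]
        scores = np.where(mask, scores, -1e6)  # large negative score for padding
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

def dot_product_attention(Q, K, V, valid_lens):
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])  # scaled dot product
    return masked_softmax(scores, valid_lens) @ V

def transpose_qkv(X, num_heads):
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]).transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

def multi_head_attention(queries, keys, values, valid_lens, W_q, W_k, W_v, W_o, num_heads):
    """Batched heads, mirroring MultiHeadAttention.forward."""
    Q = transpose_qkv(queries @ W_q, num_heads)
    K = transpose_qkv(keys @ W_k, num_heads)
    V = transpose_qkv(values @ W_v, num_heads)
    if valid_lens is not None:
        valid_lens = np.repeat(valid_lens, num_heads, axis=0)  # like torch.repeat_interleave
    out = dot_product_attention(Q, K, V, valid_lens)
    return transpose_output(out, num_heads) @ W_o

def multi_head_attention_loop(queries, keys, values, valid_lens, W_q, W_k, W_v, W_o, num_heads):
    """Reference: run each head on its own feature slice, then concatenate."""
    Q, K, V = queries @ W_q, keys @ W_k, values @ W_v
    head_dim = Q.shape[-1] // num_heads
    heads = []
    for h in range(num_heads):
        s = slice(h * head_dim, (h + 1) * head_dim)
        heads.append(dot_product_attention(Q[..., s], K[..., s], V[..., s], valid_lens))
    return np.concatenate(heads, axis=-1) @ W_o

rng = np.random.default_rng(0)
batch_size, num_queries, num_kvpairs = 2, 4, 6
num_hiddens, num_heads = 100, 5
X = rng.normal(size=(batch_size, num_queries, num_hiddens))
Y = rng.normal(size=(batch_size, num_kvpairs, num_hiddens))
valid_lens = np.array([3, 2])
W_q, W_k, W_v, W_o = (rng.normal(size=(num_hiddens, num_hiddens)) * 0.1 for _ in range(4))

out = multi_head_attention(X, Y, Y, valid_lens, W_q, W_k, W_v, W_o, num_heads)
ref = multi_head_attention_loop(X, Y, Y, valid_lens, W_q, W_k, W_v, W_o, num_heads)
print(out.shape)              # (2, 4, 100)
print(np.allclose(out, ref))  # True
```

Folding the heads into the batch axis lets one batched matrix multiplication replace num_heads separate ones. The cost is the reshape and permute bookkeeping in the two helpers.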


Updated 2026-05-14


Dive into Deep Learning @ D2L