Code

DotProductAttention Implementation

The DotProductAttention class implements scaled dot-product attention as a neural network module. During the forward pass, it receives queries, keys, and values as three-dimensional tensors with shapes (batch_size, n, d), (batch_size, m, d), and (batch_size, m, v) respectively, along with optional valid lengths for masking. The computation proceeds by first obtaining the key dimension $d$ from the last axis of the queries tensor. The raw attention scores are then computed using batch matrix multiplication of queries with the transposed keys, yielding a score tensor of shape (batch_size, n, m), which is divided by $\sqrt{d}$ for scaling. A masked_softmax operation converts these scaled scores into normalized attention weights, enforcing any valid-length constraints. Finally, dropout is applied to the attention weights for regularization, and a second batch matrix multiplication with the values produces the output of shape (batch_size, n, v).
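For a single batch element, the steps described above correspond to the standard scaled dot-product attention formula. The symbols $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$ are notation assumed here for the query, key, and value matrices; masking and dropout are left out to keep the expression short.

```latex
\[
\mathbf{Q} \in \mathbb{R}^{n \times d}, \quad
\mathbf{K} \in \mathbb{R}^{m \times d}, \quad
\mathbf{V} \in \mathbb{R}^{m \times v}
\]
\[
\mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V}
\in \mathbb{R}^{n \times v}
\]
```

The softmax is taken row by row over the $m$ key positions, so each of the $n$ queries gets a set of weights that sums to 1.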

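import math

import torch
from torch import nn


class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # queries: (batch_size, n, d), keys: (batch_size, m, d),
        # values: (batch_size, m, v)
        d = queries.shape[-1]
        # Scaled scores of shape (batch_size, n, m)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # masked_softmax is not defined in this snippet
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Output of shape (batch_size, n, v)
        return torch.bmm(self.dropout(self.attention_weights), values)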
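The class relies on a masked_softmax helper that is not shown here. The following is a minimal sketch of what such a helper could look like, not the original implementation. It assumes valid_lens is either None, a tensor of shape (batch_size,), or a tensor of shape (batch_size, n), and it masks with a large negative constant (-1e6), which is also an assumption. The shape check at the end repeats the forward-pass steps without dropout, using made-up sizes.

```python
import math

import torch


def masked_softmax(scores, valid_lens=None):
    """Softmax over the last axis, ignoring key positions >= valid length.

    Hypothetical sketch: assumes valid_lens is None, has shape
    (batch_size,), or has shape (batch_size, n).
    """
    if valid_lens is None:
        return torch.softmax(scores, dim=-1)
    batch_size, n, m = scores.shape
    if valid_lens.dim() == 1:
        # One length per batch element: reuse it for every query row.
        valid_lens = valid_lens[:, None].expand(batch_size, n)
    # keep[b, i, j] is True when key position j is valid for query i.
    positions = torch.arange(m, device=scores.device)
    keep = positions[None, None, :] < valid_lens[:, :, None]
    # A large negative score gets (numerically) zero weight after softmax.
    return torch.softmax(scores.masked_fill(~keep, -1e6), dim=-1)


# Shape check with batch_size=2, n=1, m=10, d=2, v=4 (illustrative sizes).
queries = torch.randn(2, 1, 2)
keys = torch.randn(2, 10, 2)
values = torch.randn(2, 10, 4)
valid_lens = torch.tensor([2, 6])

d = queries.shape[-1]
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
weights = masked_softmax(scores, valid_lens)  # (2, 1, 10)
output = torch.bmm(weights, values)           # (2, 1, 4)
```

With a helper like this in scope, DotProductAttention(dropout=0.0) can be called with the same four tensors; in that setting only the first 2 and first 6 key positions of the two batch elements receive nonzero weight.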

Updated 2026-05-14
