DotProductAttention Implementation
The DotProductAttention class implements scaled dot-product attention as a neural network module. Its forward pass receives queries, keys, and values as three-dimensional tensors of shape (batch_size, n, d), (batch_size, m, d), and (batch_size, m, v) respectively, along with optional valid lengths for masking. It first reads the feature dimension d, which queries and keys share, from the last axis of the queries tensor. It then computes the raw attention scores by batch matrix multiplication of the queries with the transposed keys, yielding a score tensor of shape (batch_size, n, m), and divides the scores by √d for scaling. A masked_softmax operation converts the scaled scores into normalized attention weights, enforcing any valid-length constraints; the weights are stored on the module as attention_weights. Finally, dropout is applied to the attention weights for regularization, and a second batch matrix multiplication with the values produces the output of shape (batch_size, n, v).
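For a single example in the batch, and leaving masking and dropout aside, the computation described above can be summarized in one formula. The symbols Q, K, and V are notation introduced here for the query, key, and value matrices; they do not appear in the original description.

```latex
\mathrm{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V})
  = \mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V},
\qquad
\mathbf{Q} \in \mathbb{R}^{n \times d},\;
\mathbf{K} \in \mathbb{R}^{m \times d},\;
\mathbf{V} \in \mathbb{R}^{m \times v}
```

The softmax is taken row-wise over the m key positions, so the weight matrix has shape n × m and the result has shape n × v.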
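import math

import torch
from torch import nn

# masked_softmax is not defined in this snippet; it is assumed to be
# available in scope (D2L provides a helper with this name).


class DotProductAttention(nn.Module):
    """Scaled dot product attention."""

    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # queries: (batch_size, n, d); keys: (batch_size, m, d)
        # values: (batch_size, m, v)
        d = queries.shape[-1]
        # scores: (batch_size, n, m), scaled by sqrt(d)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # output: (batch_size, n, v)
        return torch.bmm(self.dropout(self.attention_weights), values)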
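The class relies on a masked_softmax helper that is not shown here. The following is a minimal sketch of what such a helper might look like, followed by a shape check that mirrors the forward pass step by step. It assumes valid_lens is either None, a tensor of shape (batch_size,), or a tensor of shape (batch_size, n); the fill value -1e6 and the example tensor sizes are illustrative choices, and this is not necessarily identical to the D2L helper.

```python
import math

import torch


def masked_softmax(scores, valid_lens=None):
    """Softmax over the last axis, ignoring key positions >= valid_lens.

    scores: (batch_size, n, m)
    valid_lens: None, (batch_size,), or (batch_size, n)  [assumed layouts]
    """
    if valid_lens is None:
        return torch.softmax(scores, dim=-1)
    batch_size, n, m = scores.shape
    if valid_lens.dim() == 1:
        # One length per example: reuse it for every query row.
        valid_lens = valid_lens.unsqueeze(1).expand(batch_size, n)
    # keep[b, i, j] is True when key position j is valid for query i.
    positions = torch.arange(m, device=scores.device)
    keep = positions[None, None, :] < valid_lens[:, :, None]
    # A very negative score gets (numerically) zero weight after softmax.
    scores = scores.masked_fill(~keep, -1e6)
    return torch.softmax(scores, dim=-1)


# Shape check that follows the forward pass (dropout omitted).
torch.manual_seed(0)
queries = torch.randn(2, 1, 4)    # (batch_size, n, d)
keys = torch.randn(2, 10, 4)      # (batch_size, m, d)
values = torch.randn(2, 10, 6)    # (batch_size, m, v)
valid_lens = torch.tensor([2, 6])

d = queries.shape[-1]
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
weights = masked_softmax(scores, valid_lens)
output = torch.bmm(weights, values)

print(scores.shape)   # torch.Size([2, 1, 10])
print(output.shape)   # torch.Size([2, 1, 6])
```

Filling masked positions with a large negative number before the softmax, rather than zeroing the weights afterwards, keeps each row of weights summing to one over the valid positions.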
Tags
D2L
Dive into Deep Learning @ D2L
Related
Causal Attention Input Structure
Causal Attention Mask Matrix Definition
Causal Attention Weight Matrix Calculation
An engineer is implementing an attention mechanism where the output is a weighted sum of Value vectors, with weights determined by a Softmax function applied to scores. They observe that as the dimension (d) of the Query and Key vectors increases, the attention weights become extremely concentrated on a single position (e.g., [0.01, 0.98, 0.01]), causing training instability. The scores are derived from the dot product of Query (Q) and Key (K) matrices. What is the most likely cause of this issue?
Attention Mechanism Misapplication in Summarization
Analyzing the Role of the Mask in Attention
Selecting an Attention Design for Long-Context, Low-Latency Inference
Diagnosing and Redesigning Attention for a Long-Context, Cost-Constrained LLM Service
Choosing an Attention Stack for a Regulated, Long-Document Review Assistant
Attention Redesign for a Long-Context Customer-Support Copilot Under GPU Memory Pressure
Attention Redesign for a Multi-Tenant LLM with Long Context and Strict KV-Cache Budgets
Attention Architecture Choice for On-Device Meeting Summarization with 60k Context
You’re debugging an LLM inference service that mus...
You’re reviewing a design doc for a Transformer at...
Your team is deploying a chat-based LLM that must ...
You’re leading an LLM platform team that must supp...
Variance Control in Dot Product Attention