05 / attention

The heart of the transformer

This is the chapter the kid is built around. Everything before this — vocabulary, embeddings, position — is preparation. Everything after is plumbing. Attention is the actual interesting machinery, and it's the only part of train.py that we wrote ourselves; the rest is scaffolding.

The idea: each position in the sequence gets to look at itself and every earlier position, and decide how much attention to pay to each one. The model learns which of those positions matter for predicting what comes next.

Mechanically, every position produces three vectors: a query ("what am I looking for?"), a key ("what do I offer?"), and a value ("what would you take from me?"). Compare every query against every key to get attention weights; use those weights to mix the values.
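
To make the query/key/value recipe concrete, here is a minimal sketch of the same steps in plain Python, with no PyTorch. The function name causal_attention and the toy numbers are made up for illustration; they are not from train.py. It assumes the usual recipe: dot-product scores scaled by the square root of the head size, a causal mask, then softmax.

```python
import math

def causal_attention(q, k, v):
    """Single-head causal attention over plain Python lists.

    q, k, v: lists of T vectors (T positions). q and k share a length d.
    Returns (weights, output): a T x T weight matrix and T mixed value vectors.
    """
    T, d = len(q), len(q[0])
    weights, output = [], []
    for t in range(T):
        # Compare query t against keys 0..t only (that is the causal mask).
        scores = [sum(a * b for a, b in zip(q[t], k[j])) / math.sqrt(d)
                  for j in range(t + 1)]
        # Softmax, shifted by the max score for numerical stability.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        # Future positions get weight exactly 0.
        row = [e / total for e in exps] + [0.0] * (T - t - 1)
        weights.append(row)
        # Mix the value vectors using this row of weights.
        output.append([sum(row[j] * v[j][c] for j in range(T))
                       for c in range(len(v[0]))])
    return weights, output

# Toy inputs, invented for illustration: 3 positions, head size 2.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

weights, output = causal_attention(q, k, v)
for row in weights:
    print([round(w, 3) for w in row])
```

The first position has nothing earlier to look at, so its whole weight lands on itself and its output is just its own value vector. Every later row spreads its weight over itself and the positions before it.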

train.py lines 73–96 — the only piece you write yourself
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key   = nn.Linear(N_EMBD, head_size, bias=False)
        self.query = nn.Linear(N_EMBD, head_size, bias=False)
        self.value = nn.Linear(N_EMBD, head_size, bias=False)
        # position t can only attend to positions 0..t (causal)
        self.register_buffer("mask", torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)))
        self.head_size = head_size

    def forward(self, x):
        B, T, C = x.shape
        q = self.query(x)                                              # (B, T, head_size)
        k = self.key(x)                                                # (B, T, head_size)
        v = self.value(x)                                              # (B, T, head_size)
        scores = q @ k.transpose(-2, -1) / (self.head_size ** 0.5)     # (B, T, T)
        scores = scores.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)                            # rows sum to 1
        return weights @ v                                             # (B, T, head_size)
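
One line in forward deserves a second look: the division by self.head_size ** 0.5. The chapter doesn't spell out why it's there; the standard reason is that dot products of longer vectors have a larger spread, and large scores push softmax toward putting nearly all the weight on one position. The sketch below is an illustration rather than part of train.py: it assumes queries and keys with independent, unit-variance entries and measures the spread of their dot product with and without the scaling. The helper name dot_variance is hypothetical.

```python
import random
import statistics

random.seed(0)  # fixed seed so the run is repeatable

def dot_variance(head_size, scaled, trials=2000):
    """Variance of q . k for random q and k with unit-variance entries."""
    dots = []
    for _ in range(trials):
        q = [random.gauss(0.0, 1.0) for _ in range(head_size)]
        k = [random.gauss(0.0, 1.0) for _ in range(head_size)]
        dot = sum(a * b for a, b in zip(q, k))
        if scaled:
            dot /= head_size ** 0.5  # same scaling as in Head.forward
        dots.append(dot)
    return statistics.pvariance(dots)

for head_size in (4, 16, 64):
    raw = dot_variance(head_size, scaled=False)
    fixed = dot_variance(head_size, scaled=True)
    print(f"head_size={head_size:3d}  raw variance={raw:6.1f}  scaled variance={fixed:.2f}")
```

Under these assumptions the raw variance grows roughly in step with head_size, while the scaled variance stays near 1 whatever the head size is.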

See attention happen

We ran the prompt "ROMEO: To be" through the model and captured the actual attention weights from every head in every layer. Pick a layer and a head, then click any position to see what that position is looking at.

What changes across layers: early layers (0, 1) tend to learn local patterns such as "look at the previous letter" or "look at letters near you." Deeper layers (2, 3) build on those signals and attend in stranger, more diffuse ways. The kid invented every one of these patterns from scratch by getting better at predicting Shakespeare.
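
If you want a rough number for "diffuse," one option is the entropy of each row of attention weights: a row that puts everything on a single position scores 0, and a row spread evenly over many positions scores high. This is a sketch with invented patterns, not the weights captured from the model, and the helper names row_entropy and mean_entropy are hypothetical.

```python
import math

def row_entropy(row):
    """Entropy (in nats) of one row of attention weights; 0 means fully focused."""
    return sum(-w * math.log(w) for w in row if w > 0)

def mean_entropy(weights):
    """Average row entropy over all positions of one head."""
    return sum(row_entropy(row) for row in weights) / len(weights)

# Two made-up 4x4 causal patterns; each row sums to 1.
previous_letter = [
    [1.0, 0.0, 0.0, 0.0],   # position 0 can only see itself
    [1.0, 0.0, 0.0, 0.0],   # every later position looks one step back
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
]
diffuse = [
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],   # weight spread evenly over everything visible
    [1 / 3, 1 / 3, 1 / 3, 0.0],
    [0.25, 0.25, 0.25, 0.25],
]

print("previous-letter head:", round(mean_entropy(previous_letter), 3))
print("diffuse head:        ", round(mean_entropy(diffuse), 3))
```

Entropy only measures how spread out a row is, not how far back it looks, so it captures "diffuse" but says nothing on its own about "local."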