How the KV Cache works
July 26, 2025
Oh attention, you really are all that we need.
At least for accuracy. For performance? You're pretty damn slow. And as we're collectively creeping into context windows sized 1M and beyond, that fundamental latency can really bite us.
At this point no company (frontier lab or startup alike) is doing transformer inference with no optimizations. It's just not efficient enough. Frameworks like flash-attention and unsloth are adopted almost without thinking when you're deploying a fresh model.
The KV cache is another one of those de facto standards. Here's how it works.
Brief attention intuition
There are a million visualized transformer flows, many of which are quite good. So I'm not going to bore you with a full breakdown of the attention internals. But I will share the intuition that I have when it comes to attention. The examples I'm using here will be focused on text because I find it the most interpretable - but under the hood the same is true for all transformers whether you're looking at pixels, compressed waveform data, or whatever.
Here are the emergent properties that we want:
- Full quality recall of any previous word in the sequence
- The ability to dynamically recall different things based on where we are in the sequence.[1]
- Intermediary matrix byproducts can increase with the size of the sequence, but they eventually need to be projected down to the same standard size regardless of timestep.
In autoregressive generation, the output of each "timestep" of inference is the next token of the sequence that we want to generate. Specifically, it's a vector that serves as a dictionary lookup: for each token in the vocabulary, here's the probability of it coming next. From there we can just take an argmax[2] to get the actual token we feed in next.
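To make that last step concrete, here is a minimal sketch, assuming a hypothetical four-token vocabulary and made-up logits (the raw scores a model emits before normalization). A softmax turns the logits into the probability vector, and greedy decoding takes the argmax.

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical 4-token vocabulary and made-up logits for the next position.
vocab = ["the", "cat", "sat", "mat"]
logits = [1.2, 3.1, 0.4, 2.0]

probs = softmax(logits)
# Greedy decoding: pick the single most likely token.
next_id = max(range(len(probs)), key=lambda i: probs[i])
print(vocab[next_id], [round(p, 3) for p in probs])
```

Swapping the argmax line for a random draw weighted by `probs` is all it takes to move from greedy decoding to sampling.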
This probability vector in turn has to come from another vector, or a matrix multiplication that itself is compressed into a 1D vector. The point is that at some point we need to get everything into a single vector so we can project it into the final probabilities. How do we get this intermediary vector?
Well, to get a good vector, we need to be able to look back on every token. If a pronoun is vague (like "him"), we need the model to be able to resolve who we're talking about.[3] So for each new token that we generate, we want some score that denotes how much we actually care about each other token. Our final scores can be graphed as a heatmap.
Note that "he" is highly attending to "John". Within this attention matrix:
- Row i shows how token i attends to all tokens 0...i
- Future tokens are hidden by the causal mask, which gives attention its desired look-back-only property
- Each row sums to 1.0: Attention weights are normalized via softmax
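Here is a minimal sketch of how such a heatmap's rows are produced, assuming made-up raw scores for a 4-token sequence. The causal mask hides future positions (the large 9.0 entries are deliberately placed there and get ignored), and a per-row softmax makes each row sum to 1.

```python
import math

def causal_attention_weights(scores):
    """Apply a causal mask to raw scores, then softmax each row.

    scores[i][j] is the raw score for token i attending to token j.
    """
    n = len(scores)
    weights = []
    for i in range(n):
        visible = scores[i][: i + 1]  # row i may only look at tokens 0..i
        m = max(visible)
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        # Future positions (j > i) get exactly zero weight.
        weights.append([e / total for e in exps] + [0.0] * (n - i - 1))
    return weights

# Made-up raw scores for a 4-token sequence. The 9.0 entries sit in
# "future" positions and are ignored by the mask.
scores = [
    [0.5, 9.0, 9.0, 9.0],
    [2.0, 0.1, 9.0, 9.0],
    [0.3, 1.5, 0.2, 9.0],
    [3.0, 0.1, 0.4, 0.2],
]
weights = causal_attention_weights(scores)
for row in weights:
    print([round(w, 2) for w in row])
```

The first token can only attend to itself, so its row is all weight on position 0; the upper triangle is zero everywhere.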
But this doesn't seem very much like a vector at this stage. So how do we compress the current timestep's look-back into a single vector to be able to project it into the output space?
QKV
Now comes some information retrieval inspiration.[4] When we're searching for data in a search engine, we have some query text that we input and it delivers some results. Internally, most smart search engines don't just search for exact text; they're chunking your text into phrases, getting rid of suffixes, weighting words by how rare they are in the corpus, etc.
There are three components to this: your query, the result values, and the processed text. Since we're already using some hashmap language, let's call that processed text the keys. In some vague way, we want the query to look up the values by their keys.
Okay back to attention. We can apply this same construct, but instead of actually searching by words or implementing lemma chunking, we can just let all of these concepts be vectors. For each token going through this attention layer you'll have a QKV tuple of values.
For each new token that we have as input, we'll turn that token into a query vector. We'll then dot-product it against each of the keys to get a single scalar for how much each one matters, and normalize those scalars with a softmax so they sum to 1. Then we'll just weight the values by those scores and sum them into one vector. This effectively fuses the information contained within all the value vectors into one.[5]
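The query/key/value recipe fits in a few lines. This is a minimal sketch for a single query using plain Python lists and toy 2-d vectors (all values made up). It follows standard scaled dot-product attention: dot products, a division by √d to keep scores in a reasonable range, a softmax, then a weighted sum of the values.

```python
import math

def attend(query, keys, values):
    """A single query attending over lists of key and value vectors."""
    d = len(query)
    # 1. Dot the query against every key (scaled by sqrt(d), as in standard
    #    scaled dot-product attention).
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    # 2. Softmax the scores into weights that sum to 1.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # 3. Blend the value vectors into one output vector.
    out = [0.0] * len(values[0])
    for w, v in zip(weights, values):
        for i, x in enumerate(v):
            out[i] += w * x
    return out, weights

# Toy 2-d example: the query points the same way as the second key.
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
out, weights = attend([0.0, 4.0], keys, values)
print([round(w, 3) for w in weights], [round(x, 2) for x in out])
```

Unlike a real hashmap, nothing is an exact match: the second value dominates the output, but the others still leak in proportionally to their weights.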
KV Cache
All of this seems pretty reasonable for the prediction of a single token. You compute the K and V vectors, take their dot products with Q, blend the values, and get a single vector as output.
The challenge comes when you're trying to generate multiple tokens. In order to actually compute the query, key and value vectors for each input, we have to do a pretty sizable matrix multiplication. Even if we want the QKV vectors to be pretty small (most models use 64 or 128 dimensions per head), to get these we have to project much larger input states down into that space. GPT-3 used a 12,288-dimensional embedding vector, which means somewhere in your pipeline you minimally have a 64 × 12288 transformation.
And we have to do that for each token. That's a lot of matrix multiplying, because at every step, for every token (even ones long ago), you're calculating its K and V vectors from scratch.
A naive for loop for generation
Let me put some numbers to this. Say you're generating a 1000-token response and your model uses 128-dimensional Q, K, and V vectors. Without any caching, for each new token you need to:
- Calculate the query vector for the current token:
Q_current = x_current × W_q
- Recompute the key and value vectors for every token so far (x_i is token i's input vector), since nothing was saved from previous steps:
K_i = x_i × W_k and V_i = x_i × W_v, for i = 1, ..., current
- Calculate attention scores against all of those keys:
scores = Q_current · K_1, Q_current · K_2, ..., Q_current · K_current
- Weight and sum all the value vectors:
output = softmax(scores) × [V_1, V_2, ..., V_current]
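The naive per-token loop described above can be sketched as follows. This is an illustration under loud assumptions: a single layer and a single head, random weight matrices, and random stand-in input vectors instead of a real model feeding back its sampled tokens; `generate_naive` is a hypothetical name. It runs 200 steps instead of 1000 to keep pure Python fast, and counts how many K/V projections it performs.

```python
import math
import random

def matvec(W, x):
    """Multiply matrix W (one row per output dim) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    """Scaled dot-product attention for a single query."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[i] for e, v in zip(exps, values))
            for i in range(len(values[0]))]

def generate_naive(embeddings, W_q, W_k, W_v):
    """Each step recomputes K and V for *every* token seen so far."""
    kv_projections = 0
    outputs = []
    for t in range(1, len(embeddings) + 1):
        prefix = embeddings[:t]
        keys = [matvec(W_k, x) for x in prefix]    # all but the last are redundant
        values = [matvec(W_v, x) for x in prefix]  # same here
        kv_projections += t
        q = matvec(W_q, prefix[-1])                # only the newest token queries
        outputs.append(attend(q, keys, values))
    return outputs, kv_projections

random.seed(0)
d, n = 8, 200
def rand_matrix():
    return [[random.gauss(0, 1) for _ in range(d)] for _ in range(d)]
W_q, W_k, W_v = rand_matrix(), rand_matrix(), rand_matrix()
# Random stand-ins for the input vectors of the tokens we generate.
embeddings = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
outputs, kv_projections = generate_naive(embeddings, W_q, W_k, W_v)
print(kv_projections)  # 20100, i.e. n(n+1)/2 for n = 200
```

Only `n` of those projections were for genuinely new tokens; everything else recomputes vectors the loop already produced on an earlier step.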
Without caching, at token 500 you're doing 500 key-value computations. At token 999, you're doing 999. The total operations grow quadratically: O(n²), where n is your sequence length. For that 1000-token sequence, you end up doing roughly 500,000 redundant calculations.[6]
And that's just one attention head in one layer. Large models commonly have 32 or more heads per layer and dozens of layers, so you're repeating this whole thing for every head in every layer.
But don't they change?
At each timestep we're feeding something different into the transformer: we're adding the latest token. Wouldn't this affect some of the internal state of the model and therefore result in different keys and values? If so, there'd be no shortcut.
No, thankfully.[7] Because our attention is masked so that each token can only attend to the tokens before it, the K/V vectors the model computes for each position are always the same given the same prefix. Since in this autoregressive setting we're just appending the next selected token, we're not changing a thing in our past sequence.
Once we've computed K_1 and V_1 for the first token, those vectors never change. In general, a token's K and V are deterministic functions of that token, the tokens before it, and the learned weight matrices. Because of this, there's no need to recalculate these K/V values over and over again. The only new work at each step is for the newest token: its query, plus its own key and value.
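The "same prefix, same K/V" claim holds even in deeper layers, where K and V come from hidden states that already mixed in earlier tokens. Here is a minimal sketch that checks it, assuming a toy two-layer stack of causal self-attention layers with residual connections (no MLPs or normalization, random weights; `causal_layer` and `kv_per_layer` are hypothetical names). It runs the stack on a 5-token prefix, then again with one token appended, and compares every layer's keys and values for the original positions.

```python
import math
import random

def matvec(W, x):
    """Multiply matrix W (one row per output dim) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def causal_layer(xs, W_q, W_k, W_v):
    """One causal self-attention layer with a residual connection.

    Returns the new hidden states plus the keys and values it computed.
    """
    keys = [matvec(W_k, x) for x in xs]
    values = [matvec(W_v, x) for x in xs]
    out = []
    for i, x in enumerate(xs):
        q = matvec(W_q, x)
        # Causal mask: position i only sees keys/values 0..i.
        seen_k, seen_v = keys[: i + 1], values[: i + 1]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in seen_k]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        blended = [sum(e / total * v[j] for e, v in zip(exps, seen_v)) for j in range(len(x))]
        out.append([a + b for a, b in zip(x, blended)])
    return out, keys, values

def kv_per_layer(xs, layers):
    """Run the stack and collect each layer's (keys, values)."""
    collected = []
    for W_q, W_k, W_v in layers:
        xs, keys, values = causal_layer(xs, W_q, W_k, W_v)
        collected.append((keys, values))
    return collected

random.seed(1)
d = 4
def rand_matrix():
    return [[random.gauss(0, 0.5) for _ in range(d)] for _ in range(d)]
layers = [(rand_matrix(), rand_matrix(), rand_matrix()) for _ in range(2)]
prefix = [[random.gauss(0, 1) for _ in range(d)] for _ in range(5)]
new_token = [random.gauss(0, 1) for _ in range(d)]

before = kv_per_layer(prefix, layers)
after = kv_per_layer(prefix + [new_token], layers)

# Appending a token leaves every earlier position's K/V untouched, in every layer.
unchanged = all(
    k_after[: len(k_before)] == k_before and v_after[: len(v_before)] == v_before
    for (k_before, v_before), (k_after, v_after) in zip(before, after)
)
print(unchanged)
```

The comparison is exact equality, not a tolerance: the earlier positions go through the identical sequence of floating-point operations in both runs. Remove the causal mask (let every position see all keys) and the check would fail, because the new token would leak into the earlier positions' hidden states.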
Caching them
We can shortcut this generation by recognizing that we're just re-calculating the same key and value vectors. The actual keys and values of historical tokens aren't going to change over time.
So instead of recomputing them, we just... don't. We cache them.
Here's what happens step by step:
Token 1: Compute Q₁, K₁, V₁. Do attention. Cache K₁, V₁.
Token 2: Compute Q₂, K₂, V₂. Look up cached K₁, V₁. Do attention against [K₁, K₂]. Cache K₂, V₂.
Token 3: Compute Q₃, K₃, V₃. Look up cached [K₁, V₁, K₂, V₂]. Do attention against [K₁, K₂, K₃]. Cache K₃, V₃.
...
Token n: Compute Qₙ, Kₙ, Vₙ. Look up all cached KV pairs. Do attention. Cache Kₙ, Vₙ.
This eliminates the redundant K/V projections, bringing that part from O(n²) to O(n), though attention scoring and aggregation remain O(n²) overall. We're trading memory for speed - storing those cached vectors in exchange for not recalculating them.
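The step-by-step schedule can be sketched as a loop that keeps two growing lists. Same caveats as any toy: a single layer and head, random weights, random stand-in input vectors, and a hypothetical `KVCache` class rather than any library's cache API. It ends with a sanity check that the cached result matches recomputing every K/V from scratch.

```python
import math
import random

def matvec(W, x):
    """Multiply matrix W (one row per output dim) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    """Scaled dot-product attention for a single query."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[i] for e, v in zip(exps, values))
            for i in range(len(values[0]))]

class KVCache:
    """Two growing lists: one key and one value vector appended per token."""
    def __init__(self):
        self.keys = []
        self.values = []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

def generate_cached(embeddings, W_q, W_k, W_v):
    cache = KVCache()
    kv_projections = 0
    outputs = []
    for x in embeddings:
        # Project only the newest token; every older K/V comes from the cache.
        cache.append(matvec(W_k, x), matvec(W_v, x))
        kv_projections += 1
        q = matvec(W_q, x)
        outputs.append(attend(q, cache.keys, cache.values))
    return outputs, kv_projections, cache

random.seed(0)
d, n = 8, 200
def rand_matrix():
    return [[random.gauss(0, 1) for _ in range(d)] for _ in range(d)]
W_q, W_k, W_v = rand_matrix(), rand_matrix(), rand_matrix()
embeddings = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
outputs, kv_projections, cache = generate_cached(embeddings, W_q, W_k, W_v)

# Sanity check: the last step matches recomputing every K/V from scratch.
full_keys = [matvec(W_k, x) for x in embeddings]
full_values = [matvec(W_v, x) for x in embeddings]
reference = attend(matvec(W_q, embeddings[-1]), full_keys, full_values)
max_diff = max(abs(a - b) for a, b in zip(outputs[-1], reference))
print(kv_projections, len(cache.keys), max_diff)
```

K/V projections drop to one per token, while `attend` still loops over the whole cache each step, which is exactly the part that stays O(n²).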
Picture your KV cache as two growing matrices. As you generate each token, you're essentially appending new columns:
K_cache = [K₁ | K₂ | K₃ | ... | Kₙ]
V_cache = [V₁ | V₂ | V₃ | ... | Vₙ]
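Each column is a vector of size head_dim, and there is one such pair of matrices for every head in every layer. Here's the arithmetic for a model with 32 layers, 32 heads, and head_dim 128 (roughly the shape of Llama 2 7B) generating 2000 tokens, with values stored in fp16:
assume: seq_len = 2000, layers = 32, heads = 32, head_dim = 128, 2 bytes per value (fp16)
total KV = 2 (K and V) × 2000 × 32 × 32 × 128 × 2 bytes = 1,048,576,000 bytes ≈ 1.05 GB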
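The same arithmetic as a tiny helper, a sketch assuming standard multi-head attention where every head stores its own K and V (`kv_cache_bytes` is a hypothetical name):

```python
def kv_cache_bytes(seq_len, layers, heads, head_dim, bytes_per_value=2):
    """Bytes to cache K and V for every token, layer, and head (one sequence)."""
    return 2 * seq_len * layers * heads * head_dim * bytes_per_value  # 2 = K and V

total = kv_cache_bytes(seq_len=2000, layers=32, heads=32, head_dim=128)
per_token = kv_cache_bytes(seq_len=1, layers=32, heads=32, head_dim=128)
print(total, f"{total / 1e9:.2f} GB", per_token)
```

Every token adds the same 524,288 bytes (512 KiB), so the cache grows linearly with sequence length. That figure is for a single sequence; serving a batch multiplies it again, since each sequence keeps its own cache.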
Conclusion
Like speculative decoding, the KV cache is one of those lossless optimizations. It just skips redundant math, so there are no accuracy tradeoffs to consider when adopting it.
The only consideration is GPU memory and whether we have enough headroom to cache these K/V vectors. In most settings this is a tradeoff we're more than willing to make. Indeed, without it models would be so unbearably slow that even if we had to swap these vectors out to disk, it would probably still be faster than redoing all the matrix multiplications.
There are additional optimizations that:
- Share KV values across heads (or groups of heads), which can cut KV size proportionally (see MQA & GQA)
- Page the cache in fixed-size blocks, like vLLM does, which reduces fragmentation and makes better use of the GPU's actual memory allocation
- Restrict attention to a recent window plus a few "attention sink" tokens, which keeps computation and memory bounded for longer sequence lengths
But these are not nearly as universal choices as the vanilla KV Cache is.
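To put rough numbers on the head-sharing and windowing ideas, here is a sketch of the cache-size formula with two extra knobs: how many heads actually store K/V, and an optional window of recent tokens plus a few sink tokens. The 8 KV heads and the 508-token window with 4 sinks are arbitrary illustration values, not taken from any particular model, and the function name is hypothetical.

```python
def kv_cache_bytes(seq_len, layers, kv_heads, head_dim, bytes_per_value=2,
                   window=None, sink_tokens=0):
    """KV cache size when only `kv_heads` heads store K/V.

    kv_heads == query heads     -> standard multi-head attention (MHA)
    1 < kv_heads < query heads  -> grouped-query attention (GQA)
    kv_heads == 1               -> multi-query attention (MQA)
    With a window, only the most recent `window` tokens plus `sink_tokens`
    early tokens are kept.
    """
    cached = seq_len if window is None else min(seq_len, window + sink_tokens)
    return 2 * cached * layers * kv_heads * head_dim * bytes_per_value

shape = dict(seq_len=2000, layers=32, head_dim=128)
mha = kv_cache_bytes(kv_heads=32, **shape)
gqa = kv_cache_bytes(kv_heads=8, **shape)  # 4 query heads share each K/V head
mqa = kv_cache_bytes(kv_heads=1, **shape)
windowed = kv_cache_bytes(kv_heads=32, window=508, sink_tokens=4, **shape)
for name, size in [("MHA", mha), ("GQA", gqa), ("MQA", mqa), ("window", windowed)]:
    print(f"{name:>6}: {size / 2**20:,.0f} MiB")
```

Head sharing shrinks the cache by a constant factor, while windowing caps it entirely, which is why the latter matters most as sequences get long.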
I find the most heavily adopted architecture and inference techniques tend to subscribe to this same philosophy. They are relatively unopinionated and feel like a free gain with no actual tradeoffs. Because they kind of are. When the field is moving so quickly, researchers are very hesitant to lock themselves into more opinionated constructs that are just going to be thrown away in a year's time. Math is much harder to throw away in the same way.
[1] It's funny to me that transformers got their start on relatively short-form content (sentences or paragraphs). Most datapoints were still well within the sequence length band that we could put into a GRU. At the time, the motivation was mostly the foundational math of not needing forget gates. I don't think anyone actually saw million-token context windows coming on the horizon.
[2] Or beam search, top-p, top-k, or your sampling scheme of choice here.
[3] Until pretty recently, this task of coreference resolution (working out who a pronoun refers to) was a pretty standard SOTA measurement of language modeling performance.
[4] And I'm talking classic information retrieval, like what Google was doing in 2005. Not the fancy ML semantic search techniques that we see today.
[5] It seems unintuitive that fusing vectors would mean anything. But you have to build model architectures with enough numerical connectivity that the values can affect one another, and it's the training optimizer that gives these things some meaning, even if it's uninterpretable to us. We're building the track of data flow. The optimizer chooses what goes on the rails.
[6] The exact number depends on how you count operations, but the quadratic growth is the key insight. If you're generating n tokens, you do 1 + 2 + 3 + ... + n operations, which equals n(n+1)/2. For n = 1000, that's 500,500 operations instead of just 1000.
[7] Otherwise this would be an awfully silly blog post.