深入 KV Cache 的运作过程

KV Cache 的工作主要发生在 Transformer 模型的 Decoder Block 中,特别是其多头自注意力(Multi-Head Self-Attention)层

整个推理过程通常分为两个阶段:预填充阶段 (Pre-fill)解码阶段 (Decoding)

1. 预填充阶段 (Pre-fill / Prompt Encoding)

这个阶段处理用户的整个输入 Prompt(比如 100 个 token)。

步骤详解:

  1. 输入与投影: 完整的输入序列 [t1,t2,,t100][t_1, t_2, \dots, t_{100}] 进入 Transformer 的每一层。

  2. 计算 Q,K,VQ, K, V 在自注意力层中,模型使用权重矩阵 WQ,WK,WVW_Q, W_K, W_V 对输入向量进行投影,一次性计算出所有 token 的 QQKKVV 矩阵:

    • Kfull=[K1,K2,,K100]K_{\text{full}} = [K_1, K_2, \dots, K_{100}]

    • Vfull=[V1,V2,,V100]V_{\text{full}} = [V_1, V_2, \dots, V_{100}]

  3. 自注意力计算: 模型计算完整的自注意力,生成 Prompt 的编码表示。

  4. 缓存 KKVV 关键步骤。 计算得到的

    KfullK_{\text{full}}

    VfullV_{\text{full}}

    矩阵被存储到 GPU 显存中的 KV Cache 区域。

特点: 这是一个高度并行化的过程(所有 token 同时计算),速度快,但计算量大(二次复杂度 O(L2)O(L^2))。

2. 解码阶段 (Decoding / Token Generation)

这个阶段是模型逐个生成新的输出 token。假设模型现在要生成第 101 个 token 。

步骤详解:

  1. 输入 Q: 模型的输入是上一步生成的最后一个 token

    t100t_{100}

    • 模型计算

      t101t_{101}

      对应的

      QnewQ_{\text{new}}

      向量。
  2. 获取 K和 V: 模型计算

    t101t_{101}

    对应的

    KnewK_{\text{new}}

    VnewV_{\text{new}}

    向量。

    • 模型从 KV Cache读取上一个阶段存储的所有

      KcacheK_{\text{cache}}

      VcacheV_{\text{cache}}

  3. 拼接 K 和 V:

    • 将新的

      KnewK_{\text{new}}

      向量追加到缓存的

      KcacheK_{\text{cache}}

      后面,形成完整的 K’ 矩阵:

      K=[Kcache,Knew]K' = [K_{\text{cache}}, K_{\text{new}}]

    • 同样,将

      VnewV_{\text{new}}

      追加到缓存的

      VcacheV_{\text{cache}}

      后面,形成完整的 V’ 矩阵:

      V=[Vcache,Vnew]V' = [V_{\text{cache}}, V_{\text{new}}]

  4. 注意力计算: 模型使用新的

    QnewQ_{\text{new}}

    与拼接后的 K’ 和 V’ 进行注意力计算:

    Attentionnew=Softmax(Qnew(K)Tdk)V\text{Attention}_{\text{new}} = \text{Softmax}\left(\frac{Q_{\text{new}} (K')^T}{\sqrt{d_k}}\right) V'

  5. 生成下一个 Token: 注意力输出经过后续的 Feed-Forward 层和 Softmax 预测,生成下一个 token

    t101t_{101}

  6. 更新缓存: KnewK_{\text{new}}VnewV_{\text{new}} 向量被永久保存并追加到 KV Cache 中,供下一个

    tokent102token t_{102}

    使用。

特点: 这是一个串行自回归的过程(一次只能生成一个 token)。最重要的是,**每次计算 K’ 和 V’ 时,不需要重新计算 Prompt 部分的 K/V。**计算量大大降低(线性复杂度 O(L)),但因为是串行的,总耗时依赖于生成长度。

总结:KV Cache 的本质和优势

维度 无 KV Cache 使用 KV Cache
计算 K/VK/V 每生成一个 token,都需要重新计算所有先前 token 的 K/V。 只计算当前新 token 的 K/V,并从缓存中获取历史 K/V。
时间复杂度 在 L长度的序列上,每次计算 K/V 的复杂度是 O(L^2)。 每次生成一个 token 的计算复杂度是 O(L)。
显存代价 极低(仅存储模型权重)。 ,需要存储所有 Transformer 层、所有注意力头的历史 K/V 向量。
推理速度 极慢(尤其在长序列上)。 显著加快。