# Summary # Cues # Notes ## KV-cache介绍 * 先来看看Attention的公式: $ Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V $ * 现在主流的 [Decoder-only](Decoder-only.md) LLM 的自回归生成模式,每一个iteration都会将上一次输出的token作为输入,用以预测next token。 * 但是在这个过程中,对于每一个Attention Block的计算,其中Q只涉及当前最新位置上的一列vector (假设batchsize=1)。而K,V的虽然涉及所有位置上的vector,但是其中也只有一列是来自于最新位置的(fresh),其余位置都是以往已经计算过的(stale)。因此(3) 提出将这每一步stale的计算缓存在GPU HBM上,用以在下一个iteration进行计算。 * 详细公式推导可以参考(1) * Q是对于当前token的编码,但是K、V是对于之前+当前token的编码矩阵 * 在Transformer架构中,只有self attention会进行token间交互,这里就会使用KV-cache:即记录下之前的kv矩阵 * 一般情况下,kv-cache常驻显存 ## KV-cache显存计算 $ \text{Size} = (\text{batch\_size} \times \text{seq\_len} \times \text{num\_layers} \times \text{num\_heads} \times \text{head\_dim}) \times 2 \times \text{precision} $ * **例子:OPT-30B** * `num_layers` = 48 * `num_heads` = 56 * `head_dim` = 128 * `precision` = 2 (fp16) * `batch_size` = 1 * `seq_len` = 2048 * **Size** = (1 * 2048 * 48 * 56 * 128) * 2 * 2 = 2.8 GB ## 具体例子:生成 "I love AI" 序列 假设我们有一个简化的模型,其中 `d_model = 4`(隐藏维度)。 ### Step 1: 输入 "I",生成 "love" **输入处理:** - 输入 token: "I" - 经过 embedding 和位置编码后得到向量:`h_1 =[1, 2, 3, 4]` **计算 Q, K, V:** ```Java Q_1 = W_Q × h_1 = [0.5, 1.0, 1.5, 2.0] (1×4 矩阵) K_1 = W_K × h_1 = [0.6, 1.2, 1.8, 2.4] (1×4 矩阵) V_1 = W_V × h_1 = [0.7, 1.4, 2.1, 2.8] (1×4 矩阵) ``` **Attention 计算:** ```Java Attention = softmax(Q_1 × K_1^T / √4) × V_1 = softmax([7.2]) × V_1 = 1.0 × V_1 = [0.7, 1.4, 2.1, 2.8] ``` **缓存 K_1 和 V_1** ✓ ### Step 2: 输入 "I love",生成 "AI" **新输入处理:** - 新 token: "love" - 编码后:`h_2 =[2, 3, 4, 5]` **只计算新位置的 Q, K, V:** ```Java Q_2 = W_Q × h_2 = [1.0, 1.5, 2.0, 2.5] (1×4 矩阵,只有当前位置) K_2 = W_K × h_2 = [1.2, 1.8, 2.4, 3.0] (1×4 矩阵,新位置) V_2 = W_V × h_2 = [1.4, 2.1, 2.8, 3.5] (1×4 矩阵,新位置) ``` **关键:使用 KV-cache!** ```Java K_full = [K_1] = [0.6, 1.2, 1.8, 2.4] (2×4 矩阵) [K_2] [1.2, 1.8, 2.4, 3.0] V_full = [V_1] = [0.7, 1.4, 2.1, 2.8] (2×4 矩阵) [V_2] [1.4, 2.1, 2.8, 3.5] ``` **Attention 计算:** ```Java Attention = softmax(Q_2 × K_full^T / √4) × V_full = softmax([1.0×0.6+1.5×1.2+2.0×1.8+2.5×2.4, # Q_2与K_1的点积 1.0×1.2+1.5×1.8+2.0×2.4+2.5×3.0]) # Q_2与K_2的点积 = softmax([12.0, 16.2]) × V_full = [0.01, 0.99] × V_full ≈ 0.01×[0.7,1.4,2.1,2.8] + 0.99×[1.4,2.1,2.8,3.5] ≈ [1.39, 2.09, 2.79, 3.49] ``` **更新缓存:** K_cache = K_full, V_cache = V_full ### Step 3: 输入 "I love AI",生成下一个 token **新计算:** ```Java Q_3 = W_Q × h_3 (1×4,只有最新位置) K_3 = W_K × h_3 (1×4,只有最新位置) V_3 = W_V × h_3 (1×4,只有最新位置) K_full = [K_1] (3×4 矩阵,前两行来自缓存) [K_2] [K_3] V_full = [V_1] (3×4 矩阵,前两行来自缓存) [V_2] [V_3] ``` ## 内存节省示例 **不使用 KV-cache:** - 每次都重新计算所有位置的 K, V - Step 3 需要计算 3×3 = 9 次矩阵乘法 **使用 KV-cache:** - 只计算新位置的 K, V - Step 3 只需要计算 1×3 = 3 次矩阵乘法 - 节省了 66% 的计算量! ## 关键要点 1. **Q 始终只有当前位置**:因为我们只需要为当前 token 计算注意力分数 2. **K, V 包含所有历史**:需要让当前 token 能够 "看到" 所有之前的 token 3. **缓存的价值**:避免重复计算已经处理过的 token 的 K, V 值 4. **内存常驻**:KV-cache 保存在 GPU 显存中,用空间换时间 这就是为什么说 "Q 是对当前 token 的编码,而 K、V 是对之前+当前 token 的编码矩阵"。