# Summary # Cues # Notes ## 🔢 **假设参数** ```python batch_size = 2 num_heads = 2 seq_len = 3 d_k = 4 ``` ## 📊 **第1个测试:合并格式(3D)** ### **输入矩阵Q的形状:** ```python # 合并格式:(batch*heads, seq, d_k) = (4, 3, 4) Q = [ # 第1个batch的第1个head [[1, 2, 3, 4], # token 0 [5, 6, 7, 8], # token 1 [9, 10, 11, 12]], # token 2 # 第1个batch的第2个head [[2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]], # 第2个batch的第1个head [[3, 4, 5, 6], [7, 8, 9, 10], [11, 12, 13, 14]], # 第2个batch的第2个head [[4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ] # K = Q (自注意力), V = Q ``` ### **注意力计算:** ```python # scores = Q @ K^T # 对每个 (batch*head) 独立计算 (3×4) @ (4×3) = (3×3) scores[0] = Q[0] @ Q[0].T = [[30, 70, 110], # (1×1+2×2+3×3+4×4), (1×5+2×6+3×7+4×8), ... [70, 174, 278], # 每行是一个query对所有key的分数 [110, 278, 446]] scores[1] = Q[1] @ Q[1].T = 类似计算... # 总共4个 3×3 的注意力矩阵 ``` ## 📊 **第2个测试:分离格式(4D)** ### **重排后的Q矩阵:** ```python # rearrange: "(batch head) seq d -> batch head seq d", head=2 # 从 (4, 3, 4) → (2, 2, 3, 4) Q_reshaped = [ [ # batch 0 [[1, 2, 3, 4], # head 0, token 0 [5, 6, 7, 8], # head 0, token 1 [9, 10, 11, 12]], # head 0, token 2 [[2, 3, 4, 5], # head 1, token 0 [6, 7, 8, 9], # head 1, token 1 [10, 11, 12, 13]] # head 1, token 2 ], [ # batch 1 [[3, 4, 5, 6], # head 0, token 0 [7, 8, 9, 10], # head 0, token 1 [11, 12, 13, 14]], # head 0, token 2 [[4, 5, 6, 7], # head 1, token 0 [8, 9, 10, 11], # head 1, token 1 [12, 13, 14, 15]] # head 1, token 2 ] ] ``` ### **注意力计算:** ```python # 现在对每个 (batch, head) 组合独立计算 # Q_reshaped[0,0] @ Q_reshaped[0,0].T = (3×4) @ (4×3) = (3×3) scores[0,0] = [[30, 70, 110], # batch0-head0的注意力分数 [70, 174, 278], # 与第1个测试的scores[0]完全相同! [110, 278, 446]] scores[0,1] = [[54, 122, 190], # batch0-head1的注意力分数 [122, 286, 450], # 与第1个测试的scores[1]完全相同! [190, 450, 710]] # 总共2×2 = 4个 3×3 的注意力矩阵 ``` ## 🎯 **关键对比:** | 方面 | 合并格式(3D) | 分离格式(4D) | |------|-------------|-------------| | **输入形状** | `(4, 3, 4)` | `(2, 2, 3, 4)` | | **循环结构** | 对4个样本循环 | 对2个batch×2个head循环 | | **内存布局** | 连续存储所有头 | 按batch分组存储 | | **计算结果** | `scores[0]` = batch0-head0 | `scores[0,0]` = batch0-head0 | ## 💻 **PyTorch中的实际计算:** ```python # 两种格式的矩阵乘法是等价的: # 格式1:批量处理 Q1.shape = (4, 3, 4) # batch*head维度被当作batch scores1 = torch.matmul(Q1, Q1.transpose(-1, -2)) # (4, 3, 3) # 格式2:显式多头 Q2.shape = (2, 2, 3, 4) # 明确的batch和head维度 scores2 = torch.matmul(Q2, Q2.transpose(-1, -2)) # (2, 2, 3, 3) # 数值完全相同,只是组织方式不同: assert torch.allclose(scores1.view(2, 2, 3, 3), scores2) ``` ## 🧠 **为什么测试两种格式?** 1. **验证维度无关性**:函数能处理任意前置维度 2. **确保数值一致性**:重排不影响计算结果 3. **覆盖实际使用**:两种格式在实践中都常见 **结论**:两种格式计算结果完全相同,只是数据的组织方式不同!🎯