# Summary
# Cues
# Notes
## 🔢 **假设参数**
```python
batch_size = 2
num_heads = 2
seq_len = 3
d_k = 4
```
## 📊 **第1个测试:合并格式(3D)**
### **输入矩阵Q的形状:**
```python
# 合并格式:(batch*heads, seq, d_k) = (4, 3, 4)
Q = [
# 第1个batch的第1个head
[[1, 2, 3, 4], # token 0
[5, 6, 7, 8], # token 1
[9, 10, 11, 12]], # token 2
# 第1个batch的第2个head
[[2, 3, 4, 5],
[6, 7, 8, 9],
[10, 11, 12, 13]],
# 第2个batch的第1个head
[[3, 4, 5, 6],
[7, 8, 9, 10],
[11, 12, 13, 14]],
# 第2个batch的第2个head
[[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15]]
]
# K = Q (自注意力), V = Q
```
### **注意力计算:**
```python
# scores = Q @ K^T
# 对每个 (batch*head) 独立计算 (3×4) @ (4×3) = (3×3)
scores[0] = Q[0] @ Q[0].T =
[[30, 70, 110], # (1×1+2×2+3×3+4×4), (1×5+2×6+3×7+4×8), ...
[70, 174, 278], # 每行是一个query对所有key的分数
[110, 278, 446]]
scores[1] = Q[1] @ Q[1].T = 类似计算...
# 总共4个 3×3 的注意力矩阵
```
## 📊 **第2个测试:分离格式(4D)**
### **重排后的Q矩阵:**
```python
# rearrange: "(batch head) seq d -> batch head seq d", head=2
# 从 (4, 3, 4) → (2, 2, 3, 4)
Q_reshaped = [
[ # batch 0
[[1, 2, 3, 4], # head 0, token 0
[5, 6, 7, 8], # head 0, token 1
[9, 10, 11, 12]], # head 0, token 2
[[2, 3, 4, 5], # head 1, token 0
[6, 7, 8, 9], # head 1, token 1
[10, 11, 12, 13]] # head 1, token 2
],
[ # batch 1
[[3, 4, 5, 6], # head 0, token 0
[7, 8, 9, 10], # head 0, token 1
[11, 12, 13, 14]], # head 0, token 2
[[4, 5, 6, 7], # head 1, token 0
[8, 9, 10, 11], # head 1, token 1
[12, 13, 14, 15]] # head 1, token 2
]
]
```
### **注意力计算:**
```python
# 现在对每个 (batch, head) 组合独立计算
# Q_reshaped[0,0] @ Q_reshaped[0,0].T = (3×4) @ (4×3) = (3×3)
scores[0,0] =
[[30, 70, 110], # batch0-head0的注意力分数
[70, 174, 278], # 与第1个测试的scores[0]完全相同!
[110, 278, 446]]
scores[0,1] =
[[54, 122, 190], # batch0-head1的注意力分数
[122, 286, 450], # 与第1个测试的scores[1]完全相同!
[190, 450, 710]]
# 总共2×2 = 4个 3×3 的注意力矩阵
```
## 🎯 **关键对比:**
| 方面 | 合并格式(3D) | 分离格式(4D) |
|------|-------------|-------------|
| **输入形状** | `(4, 3, 4)` | `(2, 2, 3, 4)` |
| **循环结构** | 对4个样本循环 | 对2个batch×2个head循环 |
| **内存布局** | 连续存储所有头 | 按batch分组存储 |
| **计算结果** | `scores[0]` = batch0-head0 | `scores[0,0]` = batch0-head0 |
## 💻 **PyTorch中的实际计算:**
```python
# 两种格式的矩阵乘法是等价的:
# 格式1:批量处理
Q1.shape = (4, 3, 4) # batch*head维度被当作batch
scores1 = torch.matmul(Q1, Q1.transpose(-1, -2)) # (4, 3, 3)
# 格式2:显式多头
Q2.shape = (2, 2, 3, 4) # 明确的batch和head维度
scores2 = torch.matmul(Q2, Q2.transpose(-1, -2)) # (2, 2, 3, 3)
# 数值完全相同,只是组织方式不同:
assert torch.allclose(scores1.view(2, 2, 3, 3), scores2)
```
## 🧠 **为什么测试两种格式?**
1. **验证维度无关性**:函数能处理任意前置维度
2. **确保数值一致性**:重排不影响计算结果
3. **覆盖实际使用**:两种格式在实践中都常见
**结论**:两种格式计算结果完全相同,只是数据的组织方式不同!🎯