# Summary
FlashAttention 的本质就是重写了 Attention 的 [GPU Kernel 内核](GPU%20Kernel%20内核.md)。
FlashAttention FlashAttention 能够加快注意力机制的运算速度,同时减少对内存的使用。如果您想使用 FlashAttention,请在启动训练时在训练配置文件中添加以下参数:
## Notes
让我用通俗的方式介绍 FlashAttention!
## 什么是 FlashAttention?
想象你在看一部电影,需要记住每个角色之间的关系。
**普通 Attention(慢)**:
- 把整部电影的所有角色关系写在一个巨大的表格上
- 每次查看都要跑到仓库去看这个大表格
- 来回跑很浪费时间
**FlashAttention(快)**:
- 把相关的角色关系记在小本子上(快速内存)
- 一边看电影一边计算,不用反复跑仓库
- 速度快 2-4 倍!
## 核心问题:内存瓶颈
### 标准 Attention 的计算
```python
# Q, K, V 形状都是 [N, d],N 是序列长度
S = Q @ K.T / sqrt(d) # N×N 的巨大矩阵!
P = softmax(S) # N×N 的巨大矩阵!
O = P @ V # 最终输出
```
当 N=8192(GPT-3 的长度)时:
- S 矩阵:8192×8192 = 6700万个数字
- 需要 256MB 内存(仅这一个矩阵!)
- 必须在慢速的 GPU 全局内存中
## FlashAttention 的聪明做法
### 1. **分块计算**
不要一次计算整个 N×N 矩阵,而是分成小块:
```Java
原始:[8192 × 8192] 的巨大矩阵
分块:[128 × 128] 的小块
一次只算一小块
小块能放进快速内存!
```
### 2. **重新安排计算顺序**
**标准方式**:
```Java
1. 算完整个 S 矩阵 → 存到慢内存
2. 读回来 → 算完整个 softmax → 存到慢内存
3. 读回来 → 算最终结果
```
**FlashAttention**:
```Java
对每个小块:
1. 算一小块 S
2. 立即算这块的 softmax
3. 立即算这块的输出
4. 只存最终结果
(中间结果都在快速内存里!)
```
## 用做饭来类比
### 普通 Attention = 自助餐模式
1. 把所有菜都做完,摆在餐台上(占用大量空间)
2. 客人来回跑着取菜(慢速内存访问)
3. 浪费时间在路上
### FlashAttention = 现点现做
1. 把食材分成小份
2. 每次只做一小份,立即端给客人
3. 厨房(快速内存)里快速完成
4. 不需要巨大的餐台
## 实际效果
### 速度提升
```python
# 标准 Attention
序列长度 1K:10ms
序列长度 8K:640ms(64倍慢!)
序列长度 16K:内存爆炸 💥
# FlashAttention
序列长度 1K:5ms
序列长度 8K:160ms(只有16倍慢)
序列长度 16K:640ms(能跑了!)
```
### 内存节省
- 标准:O(N²) 内存
- Flash:O(N) 内存
这意味着可以处理**更长的序列**!
## 为什么这么难实现?
### 挑战 1:Softmax 需要全局信息
```python
softmax(x) = exp(x) / sum(exp(x))
需要知道所有值的和!
```
FlashAttention 用了巧妙的数学技巧,分块计算还能得到正确结果。
### 挑战 2:要写自定义 CUDA 代码
不能用 PyTorch 的标准操作,必须手写 GPU 代码来精确控制内存使用。
## 实际应用
### 1. **更长的上下文**
- GPT-4:32K → 128K 上下文
- Claude:100K → 200K 上下文
- 都得益于 FlashAttention!
### 2. **更大的批次**
```python
# 以前
batch_size = 8 # 内存不够
# 现在
batch_size = 32 # 省下的内存可以处理更多样本
```
### 3. **更快的训练**
- 2-4 倍加速
- 相同时间能训练更多 epoch
## 核心思想总结
**FlashAttention = 空间换时间?不!是聪明地使用空间!**
1. **利用内存层次**:快内存虽小但快
2. **减少数据搬运**:计算和数据放在一起
3. **融合操作**:多个步骤一次完成
这就像是把"仓库式管理"改成"即时生产"(JIT),虽然实现复杂,但效果显著!
## 一句话总结
FlashAttention 通过巧妙的分块计算和内存管理,让 Transformer 能够高效处理超长序列,是大模型能够扩展到 100K+ 上下文的关键技术!