# Summary
Sharded Optimizer 是用**通信换内存**的经典设计!
总结一下就是每个 GPU 上有全量参数,但是只有部分的[优化器 optimizer](优化器%20optimizer.md)的状态,最终所有 gpu 上的和是 n 倍的参数和 1 倍的优化器状态
# Cues
这是类似于[[DeepSpeed]] ZeRO-1的优化器状态分片技术。
# Notes
## Sharded Optimizer 内存公式
```Java
传统 DDP(N个GPU):
- 参数:N × 模型参数
- 梯度:N × 模型参数
- 优化器状态:N × 优化器状态
总计:N × (参数 + 梯度 + 优化器状态)
Sharded Optimizer(N个GPU):
- 参数:N × 模型参数
- 梯度:N × 模型参数
- 优化器状态:1 × 优化器状态(分散存储)
总计:N × (参数 + 梯度) + 1 × 优化器状态
```
## Sharded Optimizer 示例
### 1. 背景:为什么需要 Sharded Optimizer?
像 Adam 优化器这样的优化器会为每个参数维护额外的状态(如 exp_avg 和 exp_avg_sq),导致优化器的内存消耗至少是模型大小的两倍。
让我们用一个具体例子:
### 2. 模型和优化器设置
```Java
简单模型:
- W¹: 2×2 矩阵 = 4个参数
- b¹: 2个参数
- W²: 1×2 矩阵 = 2个参数
- b²: 1个参数
总共:9个参数
```
**Adam 优化器状态(每个参数需要):**
- 参数值本身
- 梯度
- exp_avg(一阶矩估计)
- exp_avg_sq(二阶矩估计)
### 3. 内存对比:DDP vs Sharded Optimizer
#### 传统 DDP(2个GPU)
```Java
GPU 0 内存:
- 模型参数: 9个值
- 梯度: 9个值
- exp_avg: 9个值
- exp_avg_sq: 9个值
总计: 36个值
GPU 1 内存:
- 模型参数: 9个值
- 梯度: 9个值
- exp_avg: 9个值
- exp_avg_sq: 9个值
总计: 36个值
总内存使用: 72个值
```
#### Sharded Optimizer(2个GPU)
```Java
参数分片方案:
GPU 0 负责: W¹(4个参数) + b¹(2个参数) = 6个参数
GPU 1 负责: W²(2个参数) + b²(1个参数) = 3个参数
GPU 0 内存:
- 模型参数: 9个值(完整模型)
- 梯度: 9个值(完整梯度)
- exp_avg: 6个值(只存储负责的参数)
- exp_avg_sq: 6个值(只存储负责的参数)
总计: 30个值
GPU 1 内存:
- 模型参数: 9个值(完整模型)
- 梯度: 9个值(完整梯度)
- exp_avg: 3个值(只存储负责的参数)
- exp_avg_sq: 3个值(只存储负责的参数)
总计: 24个值
总内存使用: 54个值(节省25%)
```
### 4. 训练过程详解
让我们模拟一个训练步骤:
#### Step 1: 前向传播(与DDP相同)
```Java
GPU 0: 处理 batch_0 → 计算 loss_0
GPU 1: 处理 batch_1 → 计算 loss_1
```
#### Step 2: 反向传播 + All-Reduce(与DDP相同)
```Java
GPU 0 梯度: [g₁⁰, g₂⁰, ..., g₉⁰]
GPU 1 梯度: [g₁¹, g₂¹, ..., g₉¹]
All-Reduce后:
所有GPU的梯度: [(g₁⁰+g₁¹)/2, (g₂⁰+g₂¹)/2, ..., (g₉⁰+g₉¹)/2]
```
#### Step 3: 优化器更新(关键区别!)
**GPU 0 只更新 W¹ 和 b¹:**
```python
# 只更新负责的参数
for param in [W¹, b¹]:
# 使用本地存储的优化器状态
exp_avg[param] = β₁ * exp_avg[param] + (1-β₁) * grad[param]
exp_avg_sq[param] = β₂ * exp_avg_sq[param] + (1-β₂) * grad[param]²
# 计算更新
param_update = -lr * exp_avg[param] / (sqrt(exp_avg_sq[param]) + ε)
param = param + param_update
```
**GPU 1 只更新 W² 和 b²:**
```python
# 只更新负责的参数
for param in [W², b²]:
# 使用本地存储的优化器状态
exp_avg[param] = β₁ * exp_avg[param] + (1-β₁) * grad[param]
exp_avg_sq[param] = β₂ * exp_avg_sq[param] + (1-β₂) * grad[param]²
# 计算更新
param_update = -lr * exp_avg[param] / (sqrt(exp_avg_sq[param]) + ε)
param = param + param_update
```
#### Step 4: 参数广播
优化器的 step() 函数只更新其分片中的参数,然后将更新后的参数广播给所有其他 DDP 进程:
```Java
GPU 0 广播: 更新后的 W¹, b¹ → GPU 1
GPU 1 广播: 更新后的 W², b² → GPU 0
结果:所有GPU拥有相同的更新后模型
```
### 5. 具体数值示例
假设初始状态:
```Java
W¹ = [[0.5, -0.3], [0.2, 0.4]]
学习率 lr = 0.01
β₁ = 0.9, β₂ = 0.999
GPU 0 的优化器状态(只存储W¹的状态):
exp_avg[W¹] = [[0.01, -0.02], [0.03, 0.01]]
exp_avg_sq[W¹] = [[0.001, 0.002], [0.001, 0.003]]
梯度(All-Reduce后):
grad[W¹] = [[0.1, -0.05], [0.08, 0.06]]
```
更新计算:
```Java
exp_avg_new = 0.9 * [[0.01, -0.02], [0.03, 0.01]] + 0.1 * [[0.1, -0.05], [0.08, 0.06]]
= [[0.019, -0.023], [0.035, 0.015]]
exp_avg_sq_new = 0.999 * [[0.001, 0.002], [0.001, 0.003]] + 0.001 * [[0.01, 0.0025], [0.0064, 0.0036]]
= [[0.00101, 0.002003], [0.001006, 0.003004]]
W¹_new = W¹ - 0.01 * exp_avg_new / (sqrt(exp_avg_sq_new) + 1e-8)
```
### 6. 通信模式对比
```Java
DDP:
├─ All-Reduce 梯度
└─ 每个GPU独立更新所有参数
Sharded Optimizer:
├─ All-Reduce 梯度(相同)
├─ 每个GPU只更新分片参数
└─ Broadcast 更新后的参数分片
```
### 7. 优缺点总结
**优点:**
- 显著减少优化器状态内存(特别是对Adam这样的优化器)
- 对于大模型,内存节省可达2-3倍
**缺点:**
- 额外的参数广播通信
- 实现复杂度增加
这就是 Sharded Optimizer 的核心思想:通过分片优化器状态来换取内存效率!