# Summary Sharded Optimizer 是用**通信换内存**的经典设计! 总结一下就是每个 GPU 上有全量参数,但是只有部分的[优化器 optimizer](优化器%20optimizer.md)的状态,最终所有 gpu 上的和是 n 倍的参数和 1 倍的优化器状态 # Cues 这是类似于[[DeepSpeed]] ZeRO-1的优化器状态分片技术。 # Notes ## Sharded Optimizer 内存公式 ```Java 传统 DDP(N个GPU): - 参数:N × 模型参数 - 梯度:N × 模型参数 - 优化器状态:N × 优化器状态 总计:N × (参数 + 梯度 + 优化器状态) Sharded Optimizer(N个GPU): - 参数:N × 模型参数 - 梯度:N × 模型参数 - 优化器状态:1 × 优化器状态(分散存储) 总计:N × (参数 + 梯度) + 1 × 优化器状态 ``` ## Sharded Optimizer 示例 ### 1. 背景:为什么需要 Sharded Optimizer? 像 Adam 优化器这样的优化器会为每个参数维护额外的状态(如 exp_avg 和 exp_avg_sq),导致优化器的内存消耗至少是模型大小的两倍。 让我们用一个具体例子: ### 2. 模型和优化器设置 ```Java 简单模型: - W¹: 2×2 矩阵 = 4个参数 - b¹: 2个参数 - W²: 1×2 矩阵 = 2个参数 - b²: 1个参数 总共:9个参数 ``` **Adam 优化器状态(每个参数需要):** - 参数值本身 - 梯度 - exp_avg(一阶矩估计) - exp_avg_sq(二阶矩估计) ### 3. 内存对比:DDP vs Sharded Optimizer #### 传统 DDP(2个GPU) ```Java GPU 0 内存: - 模型参数: 9个值 - 梯度: 9个值 - exp_avg: 9个值 - exp_avg_sq: 9个值 总计: 36个值 GPU 1 内存: - 模型参数: 9个值 - 梯度: 9个值 - exp_avg: 9个值 - exp_avg_sq: 9个值 总计: 36个值 总内存使用: 72个值 ``` #### Sharded Optimizer(2个GPU) ```Java 参数分片方案: GPU 0 负责: W¹(4个参数) + b¹(2个参数) = 6个参数 GPU 1 负责: W²(2个参数) + b²(1个参数) = 3个参数 GPU 0 内存: - 模型参数: 9个值(完整模型) - 梯度: 9个值(完整梯度) - exp_avg: 6个值(只存储负责的参数) - exp_avg_sq: 6个值(只存储负责的参数) 总计: 30个值 GPU 1 内存: - 模型参数: 9个值(完整模型) - 梯度: 9个值(完整梯度) - exp_avg: 3个值(只存储负责的参数) - exp_avg_sq: 3个值(只存储负责的参数) 总计: 24个值 总内存使用: 54个值(节省25%) ``` ### 4. 训练过程详解 让我们模拟一个训练步骤: #### Step 1: 前向传播(与DDP相同) ```Java GPU 0: 处理 batch_0 → 计算 loss_0 GPU 1: 处理 batch_1 → 计算 loss_1 ``` #### Step 2: 反向传播 + All-Reduce(与DDP相同) ```Java GPU 0 梯度: [g₁⁰, g₂⁰, ..., g₉⁰] GPU 1 梯度: [g₁¹, g₂¹, ..., g₉¹] All-Reduce后: 所有GPU的梯度: [(g₁⁰+g₁¹)/2, (g₂⁰+g₂¹)/2, ..., (g₉⁰+g₉¹)/2] ``` #### Step 3: 优化器更新(关键区别!) **GPU 0 只更新 W¹ 和 b¹:** ```python # 只更新负责的参数 for param in [W¹, b¹]: # 使用本地存储的优化器状态 exp_avg[param] = β₁ * exp_avg[param] + (1-β₁) * grad[param] exp_avg_sq[param] = β₂ * exp_avg_sq[param] + (1-β₂) * grad[param]² # 计算更新 param_update = -lr * exp_avg[param] / (sqrt(exp_avg_sq[param]) + ε) param = param + param_update ``` **GPU 1 只更新 W² 和 b²:** ```python # 只更新负责的参数 for param in [W², b²]: # 使用本地存储的优化器状态 exp_avg[param] = β₁ * exp_avg[param] + (1-β₁) * grad[param] exp_avg_sq[param] = β₂ * exp_avg_sq[param] + (1-β₂) * grad[param]² # 计算更新 param_update = -lr * exp_avg[param] / (sqrt(exp_avg_sq[param]) + ε) param = param + param_update ``` #### Step 4: 参数广播 优化器的 step() 函数只更新其分片中的参数,然后将更新后的参数广播给所有其他 DDP 进程: ```Java GPU 0 广播: 更新后的 W¹, b¹ → GPU 1 GPU 1 广播: 更新后的 W², b² → GPU 0 结果:所有GPU拥有相同的更新后模型 ``` ### 5. 具体数值示例 假设初始状态: ```Java W¹ = [[0.5, -0.3], [0.2, 0.4]] 学习率 lr = 0.01 β₁ = 0.9, β₂ = 0.999 GPU 0 的优化器状态(只存储W¹的状态): exp_avg[W¹] = [[0.01, -0.02], [0.03, 0.01]] exp_avg_sq[W¹] = [[0.001, 0.002], [0.001, 0.003]] 梯度(All-Reduce后): grad[W¹] = [[0.1, -0.05], [0.08, 0.06]] ``` 更新计算: ```Java exp_avg_new = 0.9 * [[0.01, -0.02], [0.03, 0.01]] + 0.1 * [[0.1, -0.05], [0.08, 0.06]] = [[0.019, -0.023], [0.035, 0.015]] exp_avg_sq_new = 0.999 * [[0.001, 0.002], [0.001, 0.003]] + 0.001 * [[0.01, 0.0025], [0.0064, 0.0036]] = [[0.00101, 0.002003], [0.001006, 0.003004]] W¹_new = W¹ - 0.01 * exp_avg_new / (sqrt(exp_avg_sq_new) + 1e-8) ``` ### 6. 通信模式对比 ```Java DDP: ├─ All-Reduce 梯度 └─ 每个GPU独立更新所有参数 Sharded Optimizer: ├─ All-Reduce 梯度(相同) ├─ 每个GPU只更新分片参数 └─ Broadcast 更新后的参数分片 ``` ### 7. 优缺点总结 **优点:** - 显著减少优化器状态内存(特别是对Adam这样的优化器) - 对于大模型,内存节省可达2-3倍 **缺点:** - 额外的参数广播通信 - 实现复杂度增加 这就是 Sharded Optimizer 的核心思想:通过分片优化器状态来换取内存效率!