# Summary 想象你在训练一个识别猫的模型: - GPU 0 看到了2张猫的图片,觉得"耳朵特征的权重应该增大0.15" - GPU 1 看到了另外2张猫的图片,觉得"耳朵特征的权重应该增大0.12" - 最终决定:取个平均,增大0.135 DDP(Distributed Data Parallel)核心思想:把同一份模型复制到多块 GPU/多台机器上;每块 GPU 各算自己那份小批次样本的 前向 + 反向,然后用一次 all‑reduce 把梯度做平均,再各自执行 optimizer.step()。这样一轮迭代后,所有副本的参数仍保持一致。 在 **“两台/多台 GPU 同时算”** 的场景里,把“每个 W 矩阵里的每个位置下次要怎么变”改成“先让所有 GPU 讨论好再一起改”。 [[Pytorch 进程管理]] # Cues [[分布式训练策略]] # Notes --- ## 1 .它解决的核心问题 > **一句话**:每块 GPU 都各算各的数据、各算各的梯度,但最后 **一定要把梯度平均**,否则大家的参数会越跑越不一样。 - **单机训练**:只要把 `loss.backward()` 得到的梯度直接 `optimizer.step()` 就行。 - **多机/多 GPU**: 1. **先**:每张卡自己 forward / backward,得到一份独立的 `grad`; 2. **后**:把所有卡的梯度求平均(All‑Reduce)再用来更新参数。 `DDPIndividualParameters` 就是把这事儿“手搓”了一遍,让你体验 **梯度同步** 的全过程。 --- ## 2 .代码里到底做了啥? |阶段|类似“一起做作业”的比喻|关键代码|你要抓住的点| |---|---|---|---| |**初始化**|**班长发讲义**:Rank 0 把最新参数发给其他同学|`dist.broadcast(param.data, src=0)`|保证一开始大家手里的参数一模一样| |**注册钩子**|**每个人算完一道题马上报答案**|`param.register_hook(self._make_hook(param))`|反向传播时钩子触发,立刻启动梯度同步| |**异步 All‑Reduce**|**报答案+抄答案同时进行**|`handle = dist.all_reduce(grad_clone, async_op=True)`|先把通信句柄存起来,CPU 继续算下一层| |**等待通信完成**|**老师收作业前统一对答案**|`handle.wait(); param.grad.copy_(grad_clone/世界大小)`|所有通信结束后,再把平均梯度写回 `param.grad`| 这样就实现了“**计算-通信重叠**”:梯度一算完就扔到网卡里飞,GPU 继续算别的层,省时间。 --- ## 3 .跟官方 `torch.nn.parallel.DistributedDataParallel` 有何不同? |维度|**DDPIndividualParameters(作业手搓版)**|**PyTorch 原生 DDP(工业强度)**| |---|---|---| |同步粒度|**按参数**:每个权重张量单独 all‑reduce|**按桶(bucket)**:先把几个参数拼一起,减少小包通信开销| |实现行数|几十行教学代码,可读性优先|数千行 C++/CUDA + 复杂优化| |内存峰值|单参数克隆一份梯度|Bucket 化 + 翻转梯度等 trick,更省| |适用场景|学习 & 小模型实验|生产训练、大模型、梯度累积等| > **你学到的是思路**:**“钩子 + 异步 all‑reduce + handle.wait()”** 这套模板,后面无论改 Bucket 还是用 Reduce‑Scatter‑Gather,骨架都一样。 --- 我来用一个简单的例子,结合你已经理解的MLP知识,展示DDP(数据并行)的分布式训练过程。 ## 分布式数据并行(DDP)示例 ### 1. 网络结构(简化版) 我们使用一个更简单的网络: - 输入层:2个神经元 - 隐藏层:2个神经元 - 输出层:1个神经元 ```Java 输入层(2) → 隐藏层(2) → 输出层(1) ``` ### 2. 初始参数 **权重矩阵:** - $W^{(1)} = \begin{bmatrix} 0.5 & -0.3 \ 0.2 & 0.4 \end{bmatrix}$(2×2) - $W^{(2)} = \begin{bmatrix} 0.6 & -0.2 \end{bmatrix}$(1×2) **偏置:** - $b^{(1)} = \begin{bmatrix} 0.1 \ -0.1 \end{bmatrix}$ - $b^{(2)} = 0.2$ ### 3. DDP的关键:数据并行 假设我们有2个GPU(rank 0 和 rank 1),**每个GPU都有完整的模型副本**,但处理不同的数据。 **总数据集(4个样本):** ```Java X = [x₁, x₂, x₃, x₄] Y = [y₁, y₂, y₃, y₄] ``` **数据分配:** - GPU 0 (rank 0): 处理 x₁, x₂ - GPU 1 (rank 1): 处理 x₃, x₄ ### 4. 具体训练过程 让我们用具体数字模拟一个训练步骤: **输入数据:** ```Java GPU 0: x₁ = [1, 2]ᵀ, x₂ = [0, 1]ᵀ GPU 1: x₃ = [2, 0]ᵀ, x₄ = [1, 1]ᵀ 目标值: y = [1, 0, 1, 0] ``` #### Step 1: 前向传播(各GPU独立计算) **GPU 0 计算:** ```Java 对于 x₁ = [1, 2]ᵀ: z₁⁽¹⁾ = W⁽¹⁾x₁ + b⁽¹⁾ = [0.5×1 + (-0.3)×2 + 0.1, 0.2×1 + 0.4×2 + (-0.1)]ᵀ = [0.0, 0.9]ᵀ a₁⁽¹⁾ = σ(z₁⁽¹⁾) = [0.5, 0.71]ᵀ 最终输出: a₁⁽²⁾ = 0.62 损失: L₁ = (0.62 - 1)² = 0.144 ``` **GPU 1 计算:** ```Java 对于 x₃ = [2, 0]ᵀ: z₃⁽¹⁾ = W⁽¹⁾x₃ + b⁽¹⁾ = [1.1, 0.3]ᵀ a₃⁽¹⁾ = σ(z₃⁽¹⁾) = [0.75, 0.57]ᵀ 最终输出: a₃⁽²⁾ = 0.73 损失: L₃ = (0.73 - 1)² = 0.073 ``` #### Step 2: 反向传播(各GPU独立计算梯度) 每个GPU独立计算自己数据的[[梯度]]: **GPU 0 的梯度(基于2个样本的平均):** ```Java ∂L/∂W⁽²⁾|GPU0 = [-0.15, -0.18] ∂L/∂W⁽¹⁾|GPU0 = [[-0.08, -0.16], [-0.03, -0.06]] ``` **GPU 1 的梯度(基于2个样本的平均):** ```Java ∂L/∂W⁽²⁾|GPU1 = [-0.12, -0.14] ∂L/∂W⁽¹⁾|GPU1 = [[-0.10, -0.05], [-0.04, -0.02]] ``` #### Step 3: 梯度同步(DDP的核心) 这是DDP的关键步骤!使用**All-Reduce**操作: **Individual Parameters方式:** ```python # 每个参数单独通信 for param in [W⁽¹⁾, b⁽¹⁾, W⁽²⁾, b⁽²⁾]: # 异步all-reduce handle = dist.all_reduce(param.grad, async_op=True) handles.append(handle) # 等待所有通信完成 for handle in handles: handle.wait() ``` **Bucketed方式:** ```python # 将参数打包成桶 bucket1 = [W⁽¹⁾.grad, b⁽¹⁾.grad] # 第一个桶 bucket2 = [W⁽²⁾.grad, b⁽²⁾.grad] # 第二个桶 # 批量通信 dist.all_reduce(bucket1) dist.all_reduce(bucket2) ``` **All-Reduce结果(求和后除以GPU数量):** ```Java GPU 0 认为: [-0.15, -0.18] ←─┐ ├─→ 取平均 = [-0.135, -0.16] GPU 1 认为: [-0.12, -0.14] ←─┘ ∂L/∂W⁽²⁾|final = ([-0.15, -0.18] + [-0.12, -0.14]) / 2 = [-0.135, -0.16] ∂L/∂W⁽¹⁾|final = ([[-0.08, -0.16], [-0.03, -0.06]] + [[-0.10, -0.05], [-0.04, -0.02]]) / 2 = [[-0.09, -0.105], [-0.035, -0.04]] ``` #### Step 4: 参数更新(所有GPU更新相同) 学习率 η = 0.1,所有GPU执行相同的更新: ```Java W⁽²⁾_new = [0.6, -0.2] - 0.1 × [-0.135, -0.16] = [0.614, -0.184] W⁽¹⁾_new = [[0.5, -0.3], [0.2, 0.4]] - 0.1 × [[-0.09, -0.105], [-0.035, -0.04]] = [[0.509, -0.289], [0.204, 0.404]] ``` ### 5. DDP通信模式图解 ```Java ┌─────────────┐ ┌─────────────┐ │ GPU 0 │ │ GPU 1 │ │ ┌─────────┐ │ │ ┌─────────┐ │ │ │ Model │ │ │ │ Model │ │ │ │ (完整) │ │ │ │ (完整) │ │ │ └─────────┘ │ │ └─────────┘ │ │ │ │ │ │ Data: x₁,x₂ │ │ Data: x₃,x₄ │ └─────────────┘ └─────────────┘ ↓ ↓ 计算梯度 计算梯度 ↓ ↓ └───── All-Reduce ──┘ (梯度同步) ↓ 所有GPU获得相同梯度 ↓ 所有GPU更新为相同参数 ``` ### 6. 关键理解点 1. **模型不分割**:每个GPU都有完整模型副本 2. **数据分割**:不同GPU处理不同批次的数据 3. **梯度聚合**:通过All-Reduce确保所有GPU的梯度一致 4. **同步更新**:所有GPU最终得到相同的模型参数 这就是DDP的核心思想:通过数据并行加速训练,同时保证模型的一致性!