# Summary
想象你在训练一个识别猫的模型:
- GPU 0 看到了2张猫的图片,觉得"耳朵特征的权重应该增大0.15"
- GPU 1 看到了另外2张猫的图片,觉得"耳朵特征的权重应该增大0.12"
- 最终决定:取个平均,增大0.135
DDP(Distributed Data Parallel)核心思想:把同一份模型复制到多块 GPU/多台机器上;每块 GPU 各算自己那份小批次样本的 前向 + 反向,然后用一次 all‑reduce 把梯度做平均,再各自执行 optimizer.step()。这样一轮迭代后,所有副本的参数仍保持一致。
在 **“两台/多台 GPU 同时算”** 的场景里,把“每个 W 矩阵里的每个位置下次要怎么变”改成“先让所有 GPU 讨论好再一起改”。
[[Pytorch 进程管理]]
# Cues
[[分布式训练策略]]
# Notes
---
## 1 .它解决的核心问题
> **一句话**:每块 GPU 都各算各的数据、各算各的梯度,但最后 **一定要把梯度平均**,否则大家的参数会越跑越不一样。
- **单机训练**:只要把 `loss.backward()` 得到的梯度直接 `optimizer.step()` 就行。
- **多机/多 GPU**:
1. **先**:每张卡自己 forward / backward,得到一份独立的 `grad`;
2. **后**:把所有卡的梯度求平均(All‑Reduce)再用来更新参数。
`DDPIndividualParameters` 就是把这事儿“手搓”了一遍,让你体验 **梯度同步** 的全过程。
---
## 2 .代码里到底做了啥?
|阶段|类似“一起做作业”的比喻|关键代码|你要抓住的点|
|---|---|---|---|
|**初始化**|**班长发讲义**:Rank 0 把最新参数发给其他同学|`dist.broadcast(param.data, src=0)`|保证一开始大家手里的参数一模一样|
|**注册钩子**|**每个人算完一道题马上报答案**|`param.register_hook(self._make_hook(param))`|反向传播时钩子触发,立刻启动梯度同步|
|**异步 All‑Reduce**|**报答案+抄答案同时进行**|`handle = dist.all_reduce(grad_clone, async_op=True)`|先把通信句柄存起来,CPU 继续算下一层|
|**等待通信完成**|**老师收作业前统一对答案**|`handle.wait(); param.grad.copy_(grad_clone/世界大小)`|所有通信结束后,再把平均梯度写回 `param.grad`|
这样就实现了“**计算-通信重叠**”:梯度一算完就扔到网卡里飞,GPU 继续算别的层,省时间。
---
## 3 .跟官方 `torch.nn.parallel.DistributedDataParallel` 有何不同?
|维度|**DDPIndividualParameters(作业手搓版)**|**PyTorch 原生 DDP(工业强度)**|
|---|---|---|
|同步粒度|**按参数**:每个权重张量单独 all‑reduce|**按桶(bucket)**:先把几个参数拼一起,减少小包通信开销|
|实现行数|几十行教学代码,可读性优先|数千行 C++/CUDA + 复杂优化|
|内存峰值|单参数克隆一份梯度|Bucket 化 + 翻转梯度等 trick,更省|
|适用场景|学习 & 小模型实验|生产训练、大模型、梯度累积等|
> **你学到的是思路**:**“钩子 + 异步 all‑reduce + handle.wait()”** 这套模板,后面无论改 Bucket 还是用 Reduce‑Scatter‑Gather,骨架都一样。
---
我来用一个简单的例子,结合你已经理解的MLP知识,展示DDP(数据并行)的分布式训练过程。
## 分布式数据并行(DDP)示例
### 1. 网络结构(简化版)
我们使用一个更简单的网络:
- 输入层:2个神经元
- 隐藏层:2个神经元
- 输出层:1个神经元
```Java
输入层(2) → 隐藏层(2) → 输出层(1)
```
### 2. 初始参数
**权重矩阵:**
- $W^{(1)} = \begin{bmatrix} 0.5 & -0.3 \ 0.2 & 0.4 \end{bmatrix}$(2×2)
- $W^{(2)} = \begin{bmatrix} 0.6 & -0.2 \end{bmatrix}$(1×2)
**偏置:**
- $b^{(1)} = \begin{bmatrix} 0.1 \ -0.1 \end{bmatrix}$
- $b^{(2)} = 0.2$
### 3. DDP的关键:数据并行
假设我们有2个GPU(rank 0 和 rank 1),**每个GPU都有完整的模型副本**,但处理不同的数据。
**总数据集(4个样本):**
```Java
X = [x₁, x₂, x₃, x₄]
Y = [y₁, y₂, y₃, y₄]
```
**数据分配:**
- GPU 0 (rank 0): 处理 x₁, x₂
- GPU 1 (rank 1): 处理 x₃, x₄
### 4. 具体训练过程
让我们用具体数字模拟一个训练步骤:
**输入数据:**
```Java
GPU 0: x₁ = [1, 2]ᵀ, x₂ = [0, 1]ᵀ
GPU 1: x₃ = [2, 0]ᵀ, x₄ = [1, 1]ᵀ
目标值: y = [1, 0, 1, 0]
```
#### Step 1: 前向传播(各GPU独立计算)
**GPU 0 计算:**
```Java
对于 x₁ = [1, 2]ᵀ:
z₁⁽¹⁾ = W⁽¹⁾x₁ + b⁽¹⁾ = [0.5×1 + (-0.3)×2 + 0.1, 0.2×1 + 0.4×2 + (-0.1)]ᵀ
= [0.0, 0.9]ᵀ
a₁⁽¹⁾ = σ(z₁⁽¹⁾) = [0.5, 0.71]ᵀ
最终输出: a₁⁽²⁾ = 0.62
损失: L₁ = (0.62 - 1)² = 0.144
```
**GPU 1 计算:**
```Java
对于 x₃ = [2, 0]ᵀ:
z₃⁽¹⁾ = W⁽¹⁾x₃ + b⁽¹⁾ = [1.1, 0.3]ᵀ
a₃⁽¹⁾ = σ(z₃⁽¹⁾) = [0.75, 0.57]ᵀ
最终输出: a₃⁽²⁾ = 0.73
损失: L₃ = (0.73 - 1)² = 0.073
```
#### Step 2: 反向传播(各GPU独立计算梯度)
每个GPU独立计算自己数据的[[梯度]]:
**GPU 0 的梯度(基于2个样本的平均):**
```Java
∂L/∂W⁽²⁾|GPU0 = [-0.15, -0.18]
∂L/∂W⁽¹⁾|GPU0 = [[-0.08, -0.16], [-0.03, -0.06]]
```
**GPU 1 的梯度(基于2个样本的平均):**
```Java
∂L/∂W⁽²⁾|GPU1 = [-0.12, -0.14]
∂L/∂W⁽¹⁾|GPU1 = [[-0.10, -0.05], [-0.04, -0.02]]
```
#### Step 3: 梯度同步(DDP的核心)
这是DDP的关键步骤!使用**All-Reduce**操作:
**Individual Parameters方式:**
```python
# 每个参数单独通信
for param in [W⁽¹⁾, b⁽¹⁾, W⁽²⁾, b⁽²⁾]:
# 异步all-reduce
handle = dist.all_reduce(param.grad, async_op=True)
handles.append(handle)
# 等待所有通信完成
for handle in handles:
handle.wait()
```
**Bucketed方式:**
```python
# 将参数打包成桶
bucket1 = [W⁽¹⁾.grad, b⁽¹⁾.grad] # 第一个桶
bucket2 = [W⁽²⁾.grad, b⁽²⁾.grad] # 第二个桶
# 批量通信
dist.all_reduce(bucket1)
dist.all_reduce(bucket2)
```
**All-Reduce结果(求和后除以GPU数量):**
```Java
GPU 0 认为: [-0.15, -0.18] ←─┐
├─→ 取平均 = [-0.135, -0.16]
GPU 1 认为: [-0.12, -0.14] ←─┘
∂L/∂W⁽²⁾|final = ([-0.15, -0.18] + [-0.12, -0.14]) / 2 = [-0.135, -0.16]
∂L/∂W⁽¹⁾|final = ([[-0.08, -0.16], [-0.03, -0.06]] +
[[-0.10, -0.05], [-0.04, -0.02]]) / 2
= [[-0.09, -0.105], [-0.035, -0.04]]
```
#### Step 4: 参数更新(所有GPU更新相同)
学习率 η = 0.1,所有GPU执行相同的更新:
```Java
W⁽²⁾_new = [0.6, -0.2] - 0.1 × [-0.135, -0.16] = [0.614, -0.184]
W⁽¹⁾_new = [[0.5, -0.3], [0.2, 0.4]] - 0.1 × [[-0.09, -0.105], [-0.035, -0.04]]
= [[0.509, -0.289], [0.204, 0.404]]
```
### 5. DDP通信模式图解
```Java
┌─────────────┐ ┌─────────────┐
│ GPU 0 │ │ GPU 1 │
│ ┌─────────┐ │ │ ┌─────────┐ │
│ │ Model │ │ │ │ Model │ │
│ │ (完整) │ │ │ │ (完整) │ │
│ └─────────┘ │ │ └─────────┘ │
│ │ │ │
│ Data: x₁,x₂ │ │ Data: x₃,x₄ │
└─────────────┘ └─────────────┘
↓ ↓
计算梯度 计算梯度
↓ ↓
└───── All-Reduce ──┘
(梯度同步)
↓
所有GPU获得相同梯度
↓
所有GPU更新为相同参数
```
### 6. 关键理解点
1. **模型不分割**:每个GPU都有完整模型副本
2. **数据分割**:不同GPU处理不同批次的数据
3. **梯度聚合**:通过All-Reduce确保所有GPU的梯度一致
4. **同步更新**:所有GPU最终得到相同的模型参数
这就是DDP的核心思想:通过数据并行加速训练,同时保证模型的一致性!