# Summary
> 梯度裁剪中算范数是为了避免类似爬山中下个点到当前点的直线距离特别远的情况
梯度裁剪可以理解为:
- 将所有参数的梯度看作一个高维向量
- 如果这个向量的长度超过了设定的最大值
- 就按比例缩短这个向量,但保持方向不变
- 这样可以防止梯度爆炸,同时保持优化方向
这种方法在训练RNN、LSTM等容易出现梯度爆炸的模型时特别有用。
# Cues
# Notes
## 梯度裁剪的具体例子
假设我们有一个简单的神经网络,包含两个参数矩阵:
**参数1 (W1)**: 2×2 矩阵
```Java
W1 = [[1.0, 2.0],
[3.0, 4.0]]
```
**参数2 (W2)**: 2×1 矩阵
```Java
W2 = [[5.0],
[6.0]]
```
### 步骤1:计算梯度
假设反向传播后,我们得到的梯度是:
```Java
grad_W1 = [[3.0, -4.0],
[2.0, 1.0]]
grad_W2 = [[5.0],
[-3.0]]
```
### 步骤2:计算总的L2[[范数]]
```python
# 对于 W1 的梯度
norm_W1_squared = 3.0² + (-4.0)² + 2.0² + 1.0²
= 9.0 + 16.0 + 4.0 + 1.0
= 30.0
# 对于 W2 的梯度
norm_W2_squared = 5.0² + (-3.0)²
= 25.0 + 9.0
= 34.0
# 总的L2范数
total_norm = √(30.0 + 34.0) = √64.0 = 8.0
```
### 步骤3:计算裁剪因子
假设我们设置 `max_l2_norm = 5.0`:
```python
clip_factor = min(1.0, max_l2_norm / total_norm)
= min(1.0, 5.0 / 8.0)
= min(1.0, 0.625)
= 0.625
```
### 步骤4:应用裁剪
将所有梯度乘以裁剪因子:
```python
# 裁剪后的 W1 梯度
clipped_grad_W1 = grad_W1 * 0.625
= [[3.0 * 0.625, -4.0 * 0.625],
[2.0 * 0.625, 1.0 * 0.625]]
= [[1.875, -2.5],
[1.25, 0.625]]
# 裁剪后的 W2 梯度
clipped_grad_W2 = grad_W2 * 0.625
= [[5.0 * 0.625],
[-3.0 * 0.625]]
= [[3.125],
[-1.875]]
```
### 验证裁剪效果
让我们验证裁剪后的梯度L2范数确实等于5.0:
```python
# 裁剪后的总L2范数
new_norm_squared = 1.875² + (-2.5)² + 1.25² + 0.625² + 3.125² + (-1.875)²
= 3.516 + 6.25 + 1.563 + 0.391 + 9.766 + 3.516
= 25.0
new_total_norm = √25.0 = 5.0 ✓
```
## 代码中的关键点
1. **`1e-6` 的作用**:防止除零错误。如果所有梯度都是0,`total_norm`会是0,加上这个小值避免除零。
2. **`min(1.0,...)`**:确保只在梯度范数超过阈值时才裁剪。如果梯度范数已经小于阈值,`clip_factor`会是1.0,相当于不裁剪。
3. **`mul_` 原地操作**:直接修改梯度张量,节省内存。