# Summary
rmsnorm 是 Root Mean Square Normalization (均方根归一化) 的缩写
# Cues
# Notes
我来用一个具体的小例子演示RMSNorm的计算过程。
假设我们有一个输入向量 **$x =[2, 4, 6, 8]$**,维度为4。
## 计算步骤
**第1步:计算均方根(RMS)**
$
\begin{aligned}
\text{RMS} &= \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2} = \sqrt{\frac{1}{4} \times (2^2 + 4^2 + 6^2 + 8^2)} \\
&= \sqrt{\frac{1}{4} \times (4 + 16 + 36 + 64)} \\
&= \sqrt{\frac{1}{4} \times 120} \\
&= \sqrt{30} \\
&\approx 5.477
\end{aligned}
$
**第2步:归一化**
将每个元素除以RMS:
- $x_1' = 2 / 5.477 \approx 0.365$
- $x_2' = 4 / 5.477 \approx 0.730$
- $x_3' = 6 / 5.477 \approx 1.095$
- $x_4' = 8 / 5.477 \approx 1.461$
归一化后:**$x' =[0.365, 0.730, 1.095, 1.461]$**
**第3步:缩放(可选)**
假设缩放参数 $\gamma =[1.2, 0.8, 1.0, 1.5]$(通常是可学习的参数)
最终输出:
- $y_1 = 0.365 \times 1.2 = 0.438$
- $y_2 = 0.730 \times 0.8 = 0.584$
- $y_3 = 1.095 \times 1.0 = 1.095$
- $y_4 = 1.461 \times 1.5 = 2.192$
最终结果:**$y =[0.438, 0.584, 1.095, 2.192]$**
## 与LayerNorm的区别
如果用LayerNorm处理同样的数据:
1. 先计算均值:$\mu = (2+4+6+8)/4 = 5$
2. 计算方差:$\sigma^2 =[(2-5)^2 + (4-5)^2 + (6-5)^2 + (8-5)^2]/4 = 5$
3. 标准化:$x' = (x - \mu) / \sqrt{\sigma^2 + \epsilon}$
可以看到,RMSNorm跳过了计算均值和减去均值的步骤,直接使用均方根进行归一化,计算更简单高效。
---
我来将代码与之前的计算过程对应起来。让我用同样的数据 **$x =[2, 4, 6, 8]$** 来解释每行代码:
## 代码逐行解析
**1. `x_sq = x.pow(2)`**
计算每个元素的平方
$x_{sq} =[2^2, 4^2, 6^2, 8^2]=[4, 16, 36, 64]$
**2. `mean_sq = x_sq.mean(-1, keepdim=True)`**
计算平方的均值(沿最后一个维度)
$mean_{sq} = (4 + 16 + 36 + 64) / 4 = 120 / 4 = 30$
**3. `r_rms = torch.rsqrt(mean_sq + eps)`**
计算倒数平方根 (reciprocal square root)
$r_{rms} = 1 / \sqrt{30 + 1e-5} \approx 1 / \sqrt{30} \approx 1 / 5.477 \approx 0.1826$
**4. `y = x * r_rms * g`**
归一化并缩放
假设 $g =[1.2, 0.8, 1.0, 1.5]$
$y_1 = 2 \times 0.1826 \times 1.2 = 0.438$
$y_2 = 4 \times 0.1826 \times 0.8 = 0.584$
$y_3 = 6 \times 0.1826 \times 1.0 = 1.095$
$y_4 = 8 \times 0.1826 \times 1.5 = 2.192$
## 关键点对应
1. **RMS计算优化**:
- 手动计算:$RMS = \sqrt{\text{mean}(x^2)}$,然后 $x' = x / RMS$
- 代码优化:直接计算 $r_{rms} = 1/\sqrt{\text{mean}(x^2)}$,然后 $y = x \times r_{rms}$
2. **$eps$的作用**:
- 防止除零错误
- 提高数值稳定性
3. **保存中间变量**:
```python
ctx.save_for_backward(x, g, r_rms)
```
这是为反向传播保存必要的中间结果。
## 完整示例验证:完美!代码的计算结果与我们手动计算的结果一致
总结一下代码的巧妙之处:
- 使用 `torch.rsqrt()` 直接计算倒数平方根,避免了先开方再求倒数
- 使用 `keepdim=True` 保持维度,便于后续的广播运算
- 整个计算过程简洁高效,只需要一次平方、一次均值、一次倒数平方根和一次乘法运算