## 1. 单变量情形
### 1.1 公式形式
最常见、最基础的**链式法则**公式是针对**单变量**函数的复合:
$y = f\bigl(g(x)\bigr)$
要对 $y$ 关于 $x$ 求导数,即 $\frac{dy}{dx}$,**链式法则**告诉我们:
$\frac{dy}{dx} = f'\bigl(g(x)\bigr) \times g'(x)$
含义是:
> 当你把$x$ 先输入到 $g$ 函数里,再把 $g(x)$ 的结果输入到 $f$ 函数里,如果你想知道 **$x$ 的微小变化对输出 $y$ 有什么影响**,可以把它拆分为"$g$ 对 $x$ 的影响" 和 "$f$ 对 $g$ 的影响" 两部分,再把它们**相乘**。
### 1.2 示例 1:$y=(x^2+1)^3$
这里可以把
$\underbrace{g(x)}_{\text{内函数}} = x^2 + 1, \quad \underbrace{f(u)}_{\text{外函数}} = u^3 \quad (\text{其中 }u=g(x))$
于是
$y = f\bigl(g(x)\bigr) = (x^2 + 1)^3$
- **先求** $g'(x)$:
$g'(x) = \frac{d}{dx}(x^2 + 1) = 2x$
- **再求** $f'(u)$:
$f'(u) = \frac{d}{du}(u^3) = 3u^2$
在具体求导时,要把 $u$ 替换成 $g(x)$,也就是 $x^2+1$。
**最后用链式法则**:
$\frac{dy}{dx} = f'(g(x)) \cdot g'(x) = 3 \,\bigl(x^2 + 1\bigr)^2 \,\times\, 2x = 6x\,\bigl(x^2 + 1\bigr)^2$
### 1.3 示例 2:$y=\sqrt{1 + x^2}$
可以写成
$y = (1 + x^2)^{\tfrac12}$
看作
$g(x) = 1 + x^2,\quad f(u) = u^{\tfrac12}$
- $g'(x) = 2x$
- $f'(u) = \frac12 u^{-\tfrac12} = \frac{1}{2\sqrt{u}}$
所以
$\frac{dy}{dx} = \frac{1}{2\sqrt{g(x)}} \,\times\, 2x = \frac{2x}{2\sqrt{1 + x^2}} = \frac{x}{\sqrt{1 + x^2}}$
这个例子也常见于导数的练习题。
## 2. 多变量情形
在神经网络里(以及更普遍的场景),我们往往遇到**多变量**的复合函数。比如:
$z = f\bigl(u,v\bigr)$ 而 $u = g(x,y), \quad v = h(x,y)$
这个时候,如果要对 $z$ 关于 $x$ 求偏导数,就要用到**多元链式法则**:
$\frac{\partial z}{\partial x} = \frac{\partial f}{\partial u}\cdot\frac{\partial u}{\partial x} + \frac{\partial f}{\partial v}\cdot\frac{\partial v}{\partial x}$
> 这是因为 $z$ 既会随 $u$ 的变化而变化,也会随 $v$ 的变化而变化;而 $u$ 和 $v$ 又都依赖于 $x$。
## 2.1 示例:$z=f(g(x,y), h(x,y))$
让我们构造一个稍微具体一点的例子:
$z = (u + v)^2, \quad u = xy, \quad v = x + y$
你可以认为
$f(u,v) = (u+v)^2$
$g(x,y)=x y,\quad h(x,y)=x + y$
要对 $z$ 关于 $x$ 和 $y$ 分别求偏导,就得同时考虑 $u$ 和 $v$ 对 $x$ 的影响。
1. **$\frac{\partial z}{\partial x}$**:
- 先算 $\frac{\partial f}{\partial u} = \frac{\partial}{\partial u} (u+v)^2 = 2(u+v)$
- 再算 $\frac{\partial u}{\partial x} = \frac{\partial}{\partial x}(xy) = y$
- 还有 $\frac{\partial f}{\partial v} = 2(u+v)$(同理,因为对 $v$ 的偏导也一样)
- 再算 $\frac{\partial v}{\partial x} = \frac{\partial}{\partial x}(x + y) = 1$
由多元链式法则:
$\frac{\partial z}{\partial x} = \underbrace{2(u+v)}_{\frac{\partial f}{\partial u}}\times \underbrace{y}_{\frac{\partial u}{\partial x}} + \underbrace{2(u+v)}_{\frac{\partial f}{\partial v}}\times \underbrace{1}_{\frac{\partial v}{\partial x}} = 2(u+v)\cdot y + 2(u+v)\cdot 1 = 2(u+v)(y + 1)$
最后再把 $u=xy, v=x+y$ 带回去:
$\frac{\partial z}{\partial x} = 2\bigl(xy + x + y\bigr)\,(y + 1)$
2. **$\frac{\partial z}{\partial y}$**:
可以做类似步骤:
- $\frac{\partial u}{\partial y} = x$
- $\frac{\partial v}{\partial y} = 1$
$\frac{\partial z}{\partial y} = 2(u+v)\cdot x + 2(u+v)\cdot 1 = 2(u+v)(x + 1)$
再替换 $u,v$:
$\frac{\partial z}{\partial y} = 2(xy + x + y)\,(x + 1)$
这就是"多元链式法则"在一个小示例中的运用。
## 3. 为什么在神经网络中很重要?
在神经网络里,常见情形是:
$\hat{y} = a^{(L)}$ 而 $a^{(\ell)} = f^{(\ell)}\bigl(W^{(\ell)}\,a^{(\ell-1)} + b^{(\ell)}\bigr)$
其中 $\ell = 1,2,\dots,L$ 表示不同层级;$a^{(\ell-1)}$ 表示上一层的输出等。
如果再加上一个损失函数 $J(y, \hat{y})$,想要求 $\frac{\partial J}{\partial W^{(\ell)}}$ 或 $\frac{\partial J}{\partial b^{(\ell)}}$,就需要**不断应用链式法则**,才能把损失 $J$ 一直"传回"到某个权重/偏置上。具体展开时,你会看到我们一层层反推,所以叫**"反向传播"**(backpropagation)。但**底层逻辑**就是:对多层复合函数求导时,一次次用到链式法则**。
## 4. 总结与建议
1. **单变量链式法则**是最基础的:$\frac{d}{dx}f(g(x))=f'(g(x))g'(x)$。先从这些经典例子($(x^2+1)^3$、$\sqrt{1+x^2}$等)入手,把它记熟并理解透。
2. **多元复合函数**时,要先搞清楚哪个变量依赖于谁,再用"对每个中间量的偏导相乘,然后相加"的方式来求最终对 $x$ 或者 $y$ 的偏导。这就是**多元链式法则**。
3. **在神经网络中**,它会体现为:
- 输出层对损失的偏导 $\times$(输出层激活函数的导数)
- $\times$(输出层对上一层输出的依赖)
- $\times$(上一层激活函数的导数)
- … 一直乘到你所关心的某个权重或偏置的那一层。
> 所以,**掌握链式法则的本质**("复合函数求导时,**分段求导再相乘**")就足以支撑你理解反向传播中那看似复杂的一连串乘积运算。
### 参考学习顺序
1. **复习高中或大学微积分**里关于复合函数求导、偏导数、梯度的章节。
2. 做一些简单单变量、多变量例题,确保对**链式法则**和**偏导**都熟练。
3. 再回到**神经网络**:把每层输出当成一个中间变量,在纸上清晰地写出"这一层对上一层的依赖关系",然后用链式法则"连乘"求导。
这样一来,你就能慢慢理解:反向传播其实就是在做一大串链式法则的乘法运算。尤其是看到公式里"$\sigma'(z)
quot;、"$(a^{(2)} - y)quot; 等等,都只是一小块一小块的局部导数。每一层计算完之后,再往前一层乘,这就是**"反向"**思想的由来,也是为什么它能高效地计算深层网络的梯度。