## 1. 单变量情形 ### 1.1 公式形式 最常见、最基础的**链式法则**公式是针对**单变量**函数的复合: $y = f\bigl(g(x)\bigr)$ 要对 $y$ 关于 $x$ 求导数,即 $\frac{dy}{dx}$,**链式法则**告诉我们: $\frac{dy}{dx} = f'\bigl(g(x)\bigr) \times g'(x)$ 含义是: > 当你把$x$ 先输入到 $g$ 函数里,再把 $g(x)$ 的结果输入到 $f$ 函数里,如果你想知道 **$x$ 的微小变化对输出 $y$ 有什么影响**,可以把它拆分为"$g$ 对 $x$ 的影响" 和 "$f$ 对 $g$ 的影响" 两部分,再把它们**相乘**。 ### 1.2 示例 1:$y=(x^2+1)^3$ 这里可以把 $\underbrace{g(x)}_{\text{内函数}} = x^2 + 1, \quad \underbrace{f(u)}_{\text{外函数}} = u^3 \quad (\text{其中 }u=g(x))$ 于是 $y = f\bigl(g(x)\bigr) = (x^2 + 1)^3$ - **先求** $g'(x)$: $g'(x) = \frac{d}{dx}(x^2 + 1) = 2x$ - **再求** $f'(u)$: $f'(u) = \frac{d}{du}(u^3) = 3u^2$ 在具体求导时,要把 $u$ 替换成 $g(x)$,也就是 $x^2+1$。 **最后用链式法则**: $\frac{dy}{dx} = f'(g(x)) \cdot g'(x) = 3 \,\bigl(x^2 + 1\bigr)^2 \,\times\, 2x = 6x\,\bigl(x^2 + 1\bigr)^2$ ### 1.3 示例 2:$y=\sqrt{1 + x^2}$ 可以写成 $y = (1 + x^2)^{\tfrac12}$ 看作 $g(x) = 1 + x^2,\quad f(u) = u^{\tfrac12}$ - $g'(x) = 2x$ - $f'(u) = \frac12 u^{-\tfrac12} = \frac{1}{2\sqrt{u}}$ 所以 $\frac{dy}{dx} = \frac{1}{2\sqrt{g(x)}} \,\times\, 2x = \frac{2x}{2\sqrt{1 + x^2}} = \frac{x}{\sqrt{1 + x^2}}$ 这个例子也常见于导数的练习题。 ## 2. 多变量情形 在神经网络里(以及更普遍的场景),我们往往遇到**多变量**的复合函数。比如: $z = f\bigl(u,v\bigr)$ 而 $u = g(x,y), \quad v = h(x,y)$ 这个时候,如果要对 $z$ 关于 $x$ 求偏导数,就要用到**多元链式法则**: $\frac{\partial z}{\partial x} = \frac{\partial f}{\partial u}\cdot\frac{\partial u}{\partial x} + \frac{\partial f}{\partial v}\cdot\frac{\partial v}{\partial x}$ > 这是因为 $z$ 既会随 $u$ 的变化而变化,也会随 $v$ 的变化而变化;而 $u$ 和 $v$ 又都依赖于 $x$。 ## 2.1 示例:$z=f(g(x,y), h(x,y))$ 让我们构造一个稍微具体一点的例子: $z = (u + v)^2, \quad u = xy, \quad v = x + y$ 你可以认为 $f(u,v) = (u+v)^2$ $g(x,y)=x y,\quad h(x,y)=x + y$ 要对 $z$ 关于 $x$ 和 $y$ 分别求偏导,就得同时考虑 $u$ 和 $v$ 对 $x$ 的影响。 1. **$\frac{\partial z}{\partial x}$**: - 先算 $\frac{\partial f}{\partial u} = \frac{\partial}{\partial u} (u+v)^2 = 2(u+v)$ - 再算 $\frac{\partial u}{\partial x} = \frac{\partial}{\partial x}(xy) = y$ - 还有 $\frac{\partial f}{\partial v} = 2(u+v)$(同理,因为对 $v$ 的偏导也一样) - 再算 $\frac{\partial v}{\partial x} = \frac{\partial}{\partial x}(x + y) = 1$ 由多元链式法则: $\frac{\partial z}{\partial x} = \underbrace{2(u+v)}_{\frac{\partial f}{\partial u}}\times \underbrace{y}_{\frac{\partial u}{\partial x}} + \underbrace{2(u+v)}_{\frac{\partial f}{\partial v}}\times \underbrace{1}_{\frac{\partial v}{\partial x}} = 2(u+v)\cdot y + 2(u+v)\cdot 1 = 2(u+v)(y + 1)$ 最后再把 $u=xy, v=x+y$ 带回去: $\frac{\partial z}{\partial x} = 2\bigl(xy + x + y\bigr)\,(y + 1)$ 2. **$\frac{\partial z}{\partial y}$**: 可以做类似步骤: - $\frac{\partial u}{\partial y} = x$ - $\frac{\partial v}{\partial y} = 1$ $\frac{\partial z}{\partial y} = 2(u+v)\cdot x + 2(u+v)\cdot 1 = 2(u+v)(x + 1)$ 再替换 $u,v$: $\frac{\partial z}{\partial y} = 2(xy + x + y)\,(x + 1)$ 这就是"多元链式法则"在一个小示例中的运用。 ## 3. 为什么在神经网络中很重要? 在神经网络里,常见情形是: $\hat{y} = a^{(L)}$ 而 $a^{(\ell)} = f^{(\ell)}\bigl(W^{(\ell)}\,a^{(\ell-1)} + b^{(\ell)}\bigr)$ 其中 $\ell = 1,2,\dots,L$ 表示不同层级;$a^{(\ell-1)}$ 表示上一层的输出等。 如果再加上一个损失函数 $J(y, \hat{y})$,想要求 $\frac{\partial J}{\partial W^{(\ell)}}$ 或 $\frac{\partial J}{\partial b^{(\ell)}}$,就需要**不断应用链式法则**,才能把损失 $J$ 一直"传回"到某个权重/偏置上。具体展开时,你会看到我们一层层反推,所以叫**"反向传播"**(backpropagation)。但**底层逻辑**就是:对多层复合函数求导时,一次次用到链式法则**。 ## 4. 总结与建议 1. **单变量链式法则**是最基础的:$\frac{d}{dx}f(g(x))=f'(g(x))g'(x)$。先从这些经典例子($(x^2+1)^3$、$\sqrt{1+x^2}$等)入手,把它记熟并理解透。 2. **多元复合函数**时,要先搞清楚哪个变量依赖于谁,再用"对每个中间量的偏导相乘,然后相加"的方式来求最终对 $x$ 或者 $y$ 的偏导。这就是**多元链式法则**。 3. **在神经网络中**,它会体现为: - 输出层对损失的偏导 $\times$(输出层激活函数的导数) - $\times$(输出层对上一层输出的依赖) - $\times$(上一层激活函数的导数) - … 一直乘到你所关心的某个权重或偏置的那一层。 > 所以,**掌握链式法则的本质**("复合函数求导时,**分段求导再相乘**")就足以支撑你理解反向传播中那看似复杂的一连串乘积运算。 ### 参考学习顺序 1. **复习高中或大学微积分**里关于复合函数求导、偏导数、梯度的章节。 2. 做一些简单单变量、多变量例题,确保对**链式法则**和**偏导**都熟练。 3. 再回到**神经网络**:把每层输出当成一个中间变量,在纸上清晰地写出"这一层对上一层的依赖关系",然后用链式法则"连乘"求导。 这样一来,你就能慢慢理解:反向传播其实就是在做一大串链式法则的乘法运算。尤其是看到公式里"$\sigma'(z)quot;、"$(a^{(2)} - y)quot; 等等,都只是一小块一小块的局部导数。每一层计算完之后,再往前一层乘,这就是**"反向"**思想的由来,也是为什么它能高效地计算深层网络的梯度。