前言

本文隶属于专栏《机器学习数学通关指南》,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢!

本专栏目录结构和参考文献请见《机器学习数学通关指南》


正文

在这里插入图片描述

📚 一、定义与公式

链式法则(Chain Rule)是计算复合函数导数的核心规则,在机器学习尤其是神经网络中扮演着至关重要的角色:

  • 数学表述:如果 y = f ( u ) y = f(u) y=f(u),其中 u = g ( x ) u = g(x) u=g(x),则复合函数 y = f ( g ( x ) ) y = f(g(x)) y=f(g(x)) 的导数为:
    d y d x = d y d u ⋅ d u d x \frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} dxdy=dudydxdu

  • 物理意义:描述"函数嵌套"时的变化率传递。外部函数对中间变量的导数( d y d u \frac{dy}{du} dudy),与中间变量对自变量的导数( d u d x \frac{du}{dx} dxdu)相乘。

在机器学习中,链式法则是理解神经网络训练过程中误差反向传递机制的数学基础,它解释了梯度如何通过网络层层传递。


💡 二、核心作用

  1. 分解复杂函数
    将多层嵌套函数(如 sin ⁡ ( e 2 x ) \sin(e^{2x}) sin(e2x))分解为简单函数的导数乘积,避免直接计算整体极限。

  2. 兼容其他求导法则
    常与乘积法则等组合使用,例如 3 e 2 x ⋅ sin ⁡ x 3e^{2x} \cdot \sin x 3e2xsinx 的导数计算:

    • 先用链式法则求 e 2 x e^{2x} e2x 的导数(外层函数 e u e^u eu,内层 u = 2 x u=2x u=2x,导数 2 e 2 x 2e^{2x} 2e2x
    • 再用乘积法则组合结果 2 e 2 x sin ⁡ x + e 2 x cos ⁡ x 2e^{2x}\sin x + e^{2x}\cos x 2e2xsinx+e2xcosx
  3. 支撑神经网络反向传播
    在深度学习中,链式法则是反向传播算法的理论基础,通过它我们可以计算复杂网络中各参数对损失函数的影响程度。


⚙️ 三、应用步骤

具体操作流程:

  1. 识别复合结构
    明确函数的内外层关系。例如函数 e 2 x e^{2x} e2x 中,外层是 e u e^u eu,内层是 u = 2 x u=2x u=2x

  2. 逐层求导

    • 先对外层函数求导: d y d u = e u = e 2 x \frac{dy}{du} = e^u = e^{2x} dudy=eu=e2x
    • 再对内层函数求导: d u d x = 2 \frac{du}{dx} = 2 dxdu=2
  3. 乘积合成结果
    d y d x = e 2 x ⋅ 2 = 2 e 2 x \frac{dy}{dx} = e^{2x} \cdot 2 = 2e^{2x} dxdy=e2x2=2e2x


🧠 四、链式法则在机器学习中的应用

4.1 神经网络反向传播

反向传播算法(Backpropagation)是深度学习的核心,它利用链式法则计算损失函数对各层参数的梯度:

  1. 前向传播:计算神经网络的输出值
  2. 计算损失:比较输出与目标值的差异
  3. 反向传播误差:利用链式法则,从输出层向输入层逐层计算梯度

4.2 多元函数的链式法则

在机器学习中,我们经常处理多元函数的情况。对于函数 z = f ( x , y ) z = f(x, y) z=f(x,y),其中 x = g ( t ) x = g(t) x=g(t) y = h ( t ) y = h(t) y=h(t),可以使用链式法则计算 d z d t \frac{dz}{dt} dtdz

d z d t = ∂ z ∂ x ⋅ d x d t + ∂ z ∂ y ⋅ d y d t \frac{dz}{dt} = \frac{\partial z}{\partial x} \cdot \frac{dx}{dt} + \frac{\partial z}{\partial y} \cdot \frac{dy}{dt} dtdz=xzdtdx+yzdtdy

这在处理神经网络中同时依赖多个输入的节点时非常有用。

4.3 实际计算示例

假设有一个简单的神经网络层: y = σ ( w x + b ) y = \sigma(wx + b) y=σ(wx+b),其中 σ \sigma σ 是激活函数,计算损失函数 L L L 对权重 w w w 的梯度:

∂ L ∂ w = ∂ L ∂ y ⋅ ∂ y ∂ σ ⋅ ∂ σ ∂ w = ∂ L ∂ y ⋅ σ ′ ( w x + b ) ⋅ x \frac{\partial L}{\partial w} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial \sigma} \cdot \frac{\partial \sigma}{\partial w} = \frac{\partial L}{\partial y} \cdot \sigma'(wx + b) \cdot x wL=yLσywσ=yLσ(wx+b)x

这正是通过链式法则将复杂的梯度计算分解为单个简单步骤的过程。


⚠️ 五、注意事项

  1. 可导性要求
    链条中的每一层函数需在对应点可导(如内层函数 u = 2 x u = 2x u=2x 需可导)。

  2. 嵌套扩展性
    支持多重复合(如 y = f ( g ( h ( x ) ) ) y = f(g(h(x))) y=f(g(h(x)))):
    d y d x = d f d g ⋅ d g d h ⋅ d h d x \frac{dy}{dx} = \frac{df}{dg} \cdot \frac{dg}{dh} \cdot \frac{dh}{dx} dxdy=dgdfdhdgdxdh

  3. 计算图理解
    在复杂神经网络中,链式法则可以通过计算图(computational graph)来直观理解,每个节点代表一个操作,边表示数据流动和梯度传递路径。

  4. 梯度消失/爆炸问题
    链式法则在深度网络中连续应用可能导致梯度消失或爆炸问题,这也是为什么选择合适的激活函数和初始化方法很重要。


🎓 六、实践应用

6.1 Python代码实现简单反向传播

# 简单神经元的前向传播与反向传播实现
import numpy as np

# 前向传播
def forward(x, w, b):
    # 线性组合
    z = np.dot(x, w) + b
    # sigmoid激活函数
    a = 1 / (1 + np.exp(-z))
    return a

# 通过链式法则计算梯度
def backward(x, y, a):
    # 损失函数对输出的梯度
    dL_da = -(y/a - (1-y)/(1-a))
    # sigmoid函数的导数
    da_dz = a * (1-a)
    # 链式法则:组合梯度
    dL_dz = dL_da * da_dz
    # 权重的梯度
    dL_dw = x * dL_dz
    # 偏置的梯度
    dL_db = dL_dz
    
    return dL_dw, dL_db

这个简单例子展示了如何使用链式法则实现神经网络中的梯度计算,是反向传播算法的核心思想。

6.2 CNN中的链式法则应用

在卷积神经网络(CNN)中,链式法则的应用更为复杂,因为需要处理多维张量和特殊操作(如卷积、池化):

# CNN中反向传播的概念示例
def cnn_backward(dL_dout, out, inputs, filters, stride):
    # 输出层梯度已知: dL_dout
    
    # 激活函数梯度(假设ReLU)
    dout_dz = (out > 0).astype(float)
    
    # 链式法则: 损失对激活前值的梯度
    dL_dz = dL_dout * dout_dz
    
    # 链式法则: 计算损失对卷积核的梯度
    dL_dfilters = convolve(inputs, dL_dz)
    
    # 链式法则: 计算损失对输入的梯度(用于传递到前一层)
    dL_dinputs = full_conv(dL_dz, filters)
    
    return dL_dinputs, dL_dfilters

这展示了链式法则在复杂网络结构中的应用方式。


📌 总结

链式法则的本质是传递变化率,它适用于任何复合函数。其工程价值在于将复杂问题分解为局部可计算的部分,是微积分工具箱中的瑰宝,也是深度学习中反向传播算法的基石。

在机器学习中,链式法则不仅是一个数学概念,更是连接理论与实践的桥梁。掌握链式法则对理解神经网络的学习机制、设计优化算法以及解决梯度问题都有着不可替代的作用。

通过将复杂的梯度计算分解为简单且可计算的步骤,链式法则使得训练大规模深度神经网络成为可能,奠定了现代机器学习和深度学习的理论基础。

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐