自动微分(Automatic Differentiation,下面简称 AD)是用来计算偏导的一种手段,在深度学习框架中广泛使用(如 Pytorh, Tensorflow)。最近想学习这些框架的实现,先从 AD 入手,框架的具体实现比较复杂,我们主要是理解 AD 的思想并做个简单的实现。
本篇只介绍算法的基础知识,实现部分请参考 实现篇 。
AD 能用来求偏导 值 的。
例如有一个 的函数(函数有 2
个输入, 1
个输出): ,对于 、 的偏导分别计为 和 。通常我们不关心偏导的解析式,只关心具体某个 , 取值下偏导 和 的值。
另外注意在神经网络在使用“梯度下降”学习时,我们关心的是“参数 ”的偏导。而不是“输入 ”的偏导。假设有 这样的神经网络,损失函数是 ,现在给了一个样本标签对 ,我们要计算的是 和 。在对号入座时要牢记这点。
为什么用 AD?
求偏导有很多做法,例如 symbolic differentiation 使用“符号计算” 得到准确的偏导解析式,但对于复杂的函数,偏导解析式会特别复杂,占用大量内存且计算慢,并且通常应用也不需要解析式;再比如 numerical differentiation 通过引入很小的位移 ,计算 得到偏导,这种方法编码容易,但受 float 误差影响大,且计算慢(有几个输入就要算几次 )。
AD 认为所有的计算最终都可以拆解成基础操作(如加减乘除, exp
, log
, sin
,cos
等基本函数)的组合。然后通过 链式法则 逐步计算偏导。这样使用方只需要正常组合基础操作,就能自动计算偏导,且不受 float 误差的影响,还可以复用一些中间结果来减少计算量(等价于动态规划)。
链式法则回顾
AD 的数学基础就是 链式法则(chain rule) :
对于函数 ,如果有子函数 ,满足 ,则求偏导有如下关系:
上述两种写法是一致的。另外如果涉及多个变量,例如 ,而 ,则有:
上面的式子叫 multivariable case :多变量的链式法则。也可以认为是 Total Derivative 全微分的链式法则。
AD 其实就是链式法则的具体实现。它有两种模式:前向模式(Forward accumulation)和反向模式(Reverse accumulation),我们只考虑反向模式。那么具体是怎么工作的呢?考虑下面的复杂函数 1
上述公式中,我们用了一些子函数来简化整个函数,画成图如下左图:
于是为了求偏导 与 的值,我们可以先定义中间值 ,根据链式法则,有
于是计算时需要先“前向”计算一次,得到 的值,之后再“后向”计算 的值(参考上右图),最终得到的 就是我们要计算的结果。而需要先“前向”计算一次,是因为后向计算时会用到前向的值,例如 就需要用到前向的 。
注意图里 的计算依赖了链式法则中多变量的情况,等于它所有后继节点偏导(即图中的 )的和。当计算图中存在 指向 的箭头时,我们记 为 从 方向对 的偏导,则公式可以扩充如下:
多输出情形
多输出的情况偏理论,跳过也影响不大。神经网络的输出,在训练时最终都会接入损失函数,得到 loss
值,一般都是一个标量,可以认为神经网络的学习总是单输出的。
在多输出的情况下,链式法则依然生效。
刚才都假设函数是 ,即 n
个输入, 1
个输出。考虑 m
个输出,即 的情况。假设输入是 ,而输出是 。此时我们要计算的偏导就不是 n
个值了,而是一个 m×n
的矩阵 2 ,每个元素 。这个矩阵一般称为 Jacobian Matrix :
其中 代表 对于所有输入的偏导(行向量)的转置。
考虑函数 , ,而函数 是二者的组合: ,则有
此时 中的每个元素:
可以看到和 的结果是一致的。不过这些性质其实都是链式法则的内容,这里也只是扩充视野。
小结
AD 把复杂的函数看成是许多小函数的组合,再利用链式法则来计算偏导。它有不同的模式,其中“后向模式”在计算偏导时先“前向”计算得到一些中间结果,之后再“反向”计算偏导。从工程的视角看,由于中间的偏导可以重复利用,能减少许多计算量。深度学习的反向传播算法(BP)是 AD 的一种特例。
所以回过头来,什么是 AD?AD 就是利用链式法则算偏导的一种实现。
参考
- A Review of automatic differentiation and its efficient implementation 一篇综述,对 AD “是什么”、“为什么”的描述比较清晰
- What is Automatic Differentiation? Youtube 视频,回过头来看它介绍了 AD 的各个方面,但第一次直接看还是比较懵的,视频也有对应的综述论文,也是比较好的补充材料
- Lecture 4 - Automatic Differentiation 一个 DL 的课程,前面的内容和其它材料差不多,最后通过扩展计算图来计算 AD 的方式对理解一些框架的具体实现很有帮助