伴随方法:线性方程的伴随方程(Adjoint Equation)
伴随方法是 Neural-ODE 中十分重要的一个方法,它让一个计算量复杂到基本无法求解的问题变得有可能。在神经网络中嵌套线性方程或者非线性方程也会遇到同样的问题,这篇文章从最简单的例子线性方程中的网络参数求解中,表达一下伴随方法的思想以及一些公式的推导。
假设现在有一个线性系统 A x = b \mathbf{A}\boldsymbol{x}=\boldsymbol{b}A x = b ,其中矩阵 A \mathbf{A}A 和 b \boldsymbol{b}b 都是参数 θ \thetaθ 的函数,那么线性系统可以表示为 A ( θ ) x = b ( θ ) \mathbf{A}(\theta)\boldsymbol{x}=\boldsymbol{b}(\theta)A ( θ ) x = b ( θ ) 。在机器学习领域,A ( θ ) \mathbf{A}(\theta)A ( θ ) 和 b ( θ ) \boldsymbol{b}(\theta)b ( θ ) 可以看做是神经网络,θ \thetaθ 是神经网络的参数,那么自然而然地,我们的目标就是想要求得损失函数关于网络参数 θ \thetaθ 的导数,然后利用梯度下降以及优化算法来训练网络。
对于一个线性方程,有许多的方法来求解得到 x \boldsymbol{x}x ,假设 x \boldsymbol{x}x 会作为模型最后的预测结果,那么最终它会输入到一个损失函数 J ( x ) J(\boldsymbol{x})J ( x ) 中,可能会有真实标签与其对应。因此,我们最终要求的就是损失函数关于参数的导数 d J / d θ {\text{d}J}/{\text{d}\theta}d J / d θ 。
因为 A ( θ ) \mathbf{A}(\theta)A ( θ ) 和 b ( θ ) \boldsymbol{b}(\theta)b ( θ ) 都是由 θ \thetaθ 决定的,因此 x \boldsymbol{x}x 实际上也是 θ \thetaθ 的隐式函数,所以可以写成 x ( θ ) \boldsymbol{x}(\theta)x ( θ ) 。我们假设参数 θ \thetaθ 的维度为 P PP ,即 θ ∈ R P \theta\in\mathbb{R}^{P}θ ∈ R P ,其他的矩阵以及向量的维度分别为 A ( θ ) ∈ R N × N \mathbf{A}(\theta)\in\mathbb{R}^{N\times N}A ( θ ) ∈ R N × N ,x ( θ ) ∈ R N \boldsymbol{x}(\theta)\in\mathbb{R}^Nx ( θ ) ∈ R N ,( θ ) ∈ R N \boldsymbol(\theta)\in\mathbb{R}^N( θ ) ∈ R N 。有得时候损失函数也会是 θ \thetaθ 的函数,因此具体地写出来损失函数就是 J ( x ( θ ) ; θ ) J(\boldsymbol{x}(\theta);\theta)J ( x ( θ ) ; θ ) .
注意:为了方便各种符号的简化,下面继续表示这些变量的时候,会省略后面的 θ \thetaθ ,但是读者应该记住这些变量依旧是 θ \thetaθ 的函数,在求导的时候要一直考虑这一项。
我们想要得到的是 d J / d θ \text{d}J/\text{d}\thetad J / d θ ,要注意的是这里表达的是全微分,因此有:
d J d θ ⏟ R 1 × P = ∂ J ∂ θ ⏟ R 1 × P + ∂ J ∂ x ⏟ R 1 × N × d x d θ ⏟ R N × P , (1) \underbrace{\frac{\text{d}J}{\text{d}\theta}}_{\mathbb{R}^{1\times P}} = \underbrace{\frac{\partial J}{\partial \theta}}_{\mathbb{R}^{1\times P}} + \underbrace{\frac{\partial J}{\partial \boldsymbol{x}}}_{\mathbb{R}^{1\times N}} \times \underbrace{\frac{\text{d}\boldsymbol{x}}{\text{d}\theta}}_{\mathbb{R}^{N\times P}}\tag{1},R 1 × P d θ d J = R 1 × P ∂ θ ∂ J + R 1 × N ∂ x ∂ J × R N × P d θ d x , ( 1 )
在每一个变量的下面都标上了各自的维度。因为 x \boldsymbol{x}x 和 θ \thetaθ 都是一个向量,因此 d x / d θ \text{d}\boldsymbol{x}/\text{d}\thetad x / d θ 是一个雅可比矩阵,在这式子当中,d x / d θ \text{d}\boldsymbol{x}/\text{d}\thetad x / d θ 是最难求的。
我们对于线性系统 A x = b \mathbf{A}\boldsymbol{x}=\boldsymbol{b}A x = b 的两端,都对 θ \thetaθ 进行求导,可以得到:
d d θ ( A x ) = d d θ ( b ) \frac{\text{d}}{\text{d}\theta}(\mathbf{A}\boldsymbol{x}) = \frac{\text{d}}{\text{d}\theta}(\boldsymbol{b})d θ d ( A x ) = d θ d ( b )
d A d θ x + A d x d θ ⏟ target = d b d θ \frac{\text{d} \mathbf{A}}{\text{d}\theta}\boldsymbol{x}+\mathbf{A} \underbrace{\frac{\text{d}\boldsymbol{x}}{\text{d}\theta}}_{\text{target}} = \frac{\text{d}\boldsymbol{b}}{\text{d}\theta}d θ d A x + A target d θ d x = d θ d b
我们的目标是求出 d x / d θ {\text{d}\boldsymbol{x}}/{\text{d}\theta}d x / d θ 这一项,对其进行简单的变换:
A d x d θ = d b d θ − d A d θ x , (移项) \mathbf{A}\frac{\text{d}\boldsymbol{x}}{\text{d}\theta} = \frac{\text{d}\boldsymbol{b}}{\text{d}\theta}-\frac{\text{d}\mathbf{A}}{\text{d}\theta}\boldsymbol{x},\quad\text{(移项)}A d θ d x = d θ d b − d θ d A x , (移项)
方程两边同时左乘 A \mathbf{A}A 的逆,得到:
d x d θ ⏟ R N × P = A − 1 ⏟ R N × N ( d b d θ ⏟ R N × P − d A d θ ⏟ R N × N × P x ⏟ R N ) , (2) \underbrace{\frac{\text{d}\boldsymbol{x}}{\text{d}\theta}}_{\mathbb{R}^{N\times P}} = \underbrace{\mathbf{A}^{-1}}_{\mathbb{R}^{N\times N}} \left( \underbrace{\frac{\text{d}\boldsymbol{b}}{\text{d}\theta}}_{\mathbb{R}^{N\times P}} - \underbrace{\frac{\text{d}\mathbf{A}}{\text{d}\theta}}_{\mathbb{R}^{N\times N\times P}} \underbrace{\boldsymbol{x}}_{\mathbb{R}^{N}} \right)\tag{2},R N × P d θ d x = R N × N A − 1 ⎝ ⎜ ⎛ R N × P d θ d b − R N × N × P d θ d A R N x ⎠ ⎟ ⎞ , ( 2 )
同样的,我们在变量下面标上对应的维度。要注意的是,这里 d A / d θ \text{d}\mathbf{A}/\text{d}\thetad A / d θ 和 x \boldsymbol{x}x 的维度是不匹配的,但是我们不拘泥于这里,我们关注的点在于如果要通过最直接的方式去求解 d x / d θ {\text{d}\boldsymbol{x}}/{\text{d}\theta}d x / d θ 所需要的时间是有多大。这里只需要记住,无论如何,括号里面最终得到的矩阵维度为 N × P N\times PN × P 的大小。同时也不用去过度的关注矩阵 A \mathbf{A}A 要如何求逆(因为这里是一个神经网络的输出,所以求逆会使得问题变得更为复杂),因为在后面会发现其实没有必要对 A \mathbf{A}A 求逆。
将式子 (2) 与线性方程 A x = b \mathbf{A}\boldsymbol{x}=\boldsymbol{b}A x = b 进行对比可以发现,其实这就是由 P PP 个线性方程组成的更大的线性方程。求解一个线性方程可以用 LU 分解 或者 QR 分解 ,它们的时间复杂度为 O ( N 3 ) \mathcal{O}(N^3)O ( N 3 ) ,时间花费太过于大,对于神经网络来说,参数一多基本无法求解。因此,我们要使用另外一种更为高效的方法 —— 伴随方法,来求解这个问题。
伴随方法(Adjoint Method)
我们观察 (1) 式子以及 (2) 式,会发现实际上 (1) 式的最后一项就是我们想要求的「目标」,那么我们可以将 (2) 代入到 (1) 式中,得到 (3) 式:
d J d θ ⏟ R 1 × P = ∂ J ∂ θ + ∂ J ∂ x ⏟ R 1 × N A − 1 ( d b d θ − d A d θ x ) ⏟ R N × P , (3) \underbrace{\frac{\text{d}J}{\text{d}\theta}}_{\mathbb{R}^{1\times P}} = \frac{\partial J}{\partial \theta} + \underbrace{\frac{\partial J}{\partial \boldsymbol{x}}}_{\mathbb{R}^{1\times N}} \underbrace{\mathbf{A}^{-1}\left( \frac{\text{d}\boldsymbol{b}}{\text{d}\theta} - \frac{\text{d}\mathbf{A}}{\text{d}\theta}\boldsymbol{x}\right)}_{\mathbb{R}^{N\times P}}\tag{3},R 1 × P d θ d J = ∂ θ ∂ J + R 1 × N ∂ x ∂ J R N × P A − 1 ( d θ d b − d θ d A x ) , ( 3 )
我们发现最后括号里面的那一整块维度是 N × P N\times PN × P 的,而我们最终需要的只是一个 1 × P 1\times P1 × P 的向量,这说明,实际上我们不需要额外求解 P PP 个线性方程,而只需要额外求解 1 个线性方程就能行了。
我们重新把 (3) 式分块来看:
d J d θ = ∂ J ∂ θ + ( ∂ J ∂ x A − 1 ) ⏟ λ ⊤ ( d b d θ − d A d θ x ) , (4) \frac{\text{d}J}{\text{d}\theta} = \frac{\partial J}{\partial \theta} + \underbrace{\left( \frac{\partial J}{\partial \boldsymbol{x}} \mathbf{A}^{-1}\right)}_{\lambda^\top} \left( \frac{\text{d}\boldsymbol{b}}{\text{d}\theta} - \frac{\text{d}\mathbf{A}}{\text{d}\theta}\boldsymbol{x} \right)\tag{4},d θ d J = ∂ θ ∂ J + λ ⊤ ( ∂ x ∂ J A − 1 ) ( d θ d b − d θ d A x ) , ( 4 )
我们令 λ ⊤ = ∂ J ∂ x A − 1 \lambda^\top = \frac{\partial J}{\partial \boldsymbol{x}} \mathbf{A}^{-1}λ ⊤ = ∂ x ∂ J A − 1 ,称 λ ∈ R N \lambda\in\mathbb{R}^Nλ ∈ R N 为伴随变量(adjoint variable),然后对这个方程进行如下变换:
λ ⊤ A = ∂ J ∂ x , (两边右乘 A ) \lambda^\top \mathbf{A} = \frac{\partial J}{\partial \boldsymbol{x}},\quad\text{(两边右乘 $\mathbf{A}$)} λ ⊤ A = ∂ x ∂ J , ( 两边右乘 A )
( λ ⊤ A ) ⊤ = ( ∂ J ∂ x ) ⊤ , (两边进行转置) \left( \lambda^\top \mathbf{A} \right)^\top = \left( \frac{\partial J}{\partial \boldsymbol{x}} \right)^\top,\quad\text{(两边进行转置)}( λ ⊤ A ) ⊤ = ( ∂ x ∂ J ) ⊤ , ( 两边进行转置)
最后我们得到 (5) 式:
A ⊤ ⏟ R N × N λ ⏟ R N = ( ∂ J ∂ x ) ⊤ ⏟ R N (5) \underbrace{\mathbf{A}^\top}_{\mathbb{R}^{N\times N}} \underbrace{\lambda}_{\mathbb{R}^{N}} = \underbrace{\left( \frac{\partial J}{\partial \boldsymbol{x}} \right)^\top}_{\mathbb{R}^{N}}\tag{5}R N × N A ⊤ R N λ = R N ( ∂ x ∂ J ) ⊤ ( 5 )
观察 (5) 式不难发现,这其实与 A x = b \mathbf{A}\boldsymbol{x}=\boldsymbol{b}A x = b 的形式是完全一样的,而且我们不用计算矩阵 A \mathbf{A}A 的逆,而是直接用它的转置,关于 ∂ J ∂ x \frac{\partial J}{\partial \boldsymbol{x}}∂ x ∂ J 这一项,利用自动微分可以很简单地计算出来。
这种求解方法就很好地规避了求逆,并且使得问题的维度大大地减小了。对于伴随方法,可以通过以下三步来计算:
第一步:前向求解 A x = b \mathbf{A}\boldsymbol{x}=\boldsymbol{b}A x = b ,得到 x \boldsymbol{x}x 的解;
第二步:后向求解伴随方程 A ⊤ λ = ( ∂ J ∂ x ) ⊤ \mathbf{A}^\top \lambda = \left( \frac{\partial J}{\partial \boldsymbol{x}} \right)^\topA ⊤ λ = ( ∂ x ∂ J ) ⊤ ,得到伴随变量 λ \lambdaλ ;
第三步:代回原式:
d J d θ = ∂ J ∂ θ ⏟ may be zero in many problems + λ ⊤ ( d b d θ − d A d θ x ) \frac{\text{d}J}{\text{d}\theta} = \underbrace{\frac{\partial J}{\partial \theta}}_{\text{may be zero in many problems}} + \lambda^\top \left( \frac{\text{d}\boldsymbol{b}}{\text{d}\theta} - \frac{\text{d}\mathbf{A}}{\text{d}\theta} \boldsymbol{x} \right)d θ d J = may be zero in many problems ∂ θ ∂ J + λ ⊤ ( d θ d b − d θ d A x )
利用这样的伴随方法,只需要求解两个线性系统就可以得到 d J d θ \frac{\text{d}J}{\text{d}\theta}d θ d J 。而对于 ∂ J ∂ x , ∂ J ∂ θ , d b d θ , d A d θ \frac{\partial J}{\partial \boldsymbol{x}}, \frac{\partial J}{\partial\theta}, \frac{\text{d}\boldsymbol{b}}{\text{d}\theta}, \frac{\text{d}\mathbf{A}}{\text{d}\theta}∂ x ∂ J , ∂ θ ∂ J , d θ d b , d θ d A ,这几个矩阵利用自动微分可以更为简单地求得。
参考:
[1] Machine Learning & Simulation. Adjoint Equation of a Linear System of Equations - by implicit derivative. YouTube