栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

伴随方法:线性方程的伴随方程(Adjoint Equation)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

伴随方法:线性方程的伴随方程(Adjoint Equation)

伴随方法:线性方程的伴随方程(Adjoint Equation)

伴随方法是 Neural-ODE 中十分重要的一个方法,它让一个计算量复杂到基本无法求解的问题变得有可能。在神经网络中嵌套线性方程或者非线性方程也会遇到同样的问题,这篇文章从最简单的例子线性方程中的网络参数求解中,表达一下伴随方法的思想以及一些公式的推导。

假设现在有一个线性系统 A x = b mathbf{A}boldsymbol{x}=boldsymbol{b} Ax=b,其中矩阵 A mathbf{A} A 和 b boldsymbol{b} b 都是参数 θ theta θ 的函数,那么线性系统可以表示为 A ( θ ) x = b ( θ ) mathbf{A}(theta)boldsymbol{x}=boldsymbol{b}(theta) A(θ)x=b(θ)。在机器学习领域, A ( θ ) mathbf{A}(theta) A(θ) 和 b ( θ ) boldsymbol{b}(theta) b(θ) 可以看做是神经网络, θ theta θ 是神经网络的参数,那么自然而然地,我们的目标就是想要求得损失函数关于网络参数 θ theta θ 的导数,然后利用梯度下降以及优化算法来训练网络。

对于一个线性方程,有许多的方法来求解得到 x boldsymbol{x} x,假设 x boldsymbol{x} x 会作为模型最后的预测结果,那么最终它会输入到一个损失函数 J ( x ) J(boldsymbol{x}) J(x) 中,可能会有真实标签与其对应。因此,我们最终要求的就是损失函数关于参数的导数 d J / d θ {text{d}J}/{text{d}theta} dJ/dθ。

因为 A ( θ ) mathbf{A}(theta) A(θ) 和 b ( θ ) boldsymbol{b}(theta) b(θ) 都是由 θ theta θ 决定的,因此 x boldsymbol{x} x 实际上也是 θ theta θ 的隐式函数,所以可以写成 x ( θ ) boldsymbol{x}(theta) x(θ)。我们假设参数 θ theta θ 的维度为 P P P,即 θ ∈ R P thetainmathbb{R}^{P} θ∈RP,其他的矩阵以及向量的维度分别为 A ( θ ) ∈ R N × N mathbf{A}(theta)inmathbb{R}^{Ntimes N} A(θ)∈RN×N, x ( θ ) ∈ R N boldsymbol{x}(theta)inmathbb{R}^N x(θ)∈RN, ( θ ) ∈ R N boldsymbol(theta)inmathbb{R}^N (θ)∈RN。有得时候损失函数也会是 θ theta θ 的函数,因此具体地写出来损失函数就是 J ( x ( θ ) ; θ ) J(boldsymbol{x}(theta);theta) J(x(θ);θ).

注意:为了方便各种符号的简化,下面继续表示这些变量的时候,会省略后面的 θ theta θ,但是读者应该记住这些变量依旧是 θ theta θ 的函数,在求导的时候要一直考虑这一项。

我们想要得到的是 d J / d θ text{d}J/text{d}theta dJ/dθ,要注意的是这里表达的是全微分,因此有:
d J d θ ⏟ R 1 × P = ∂ J ∂ θ ⏟ R 1 × P + ∂ J ∂ x ⏟ R 1 × N × d x d θ ⏟ R N × P , (1) underbrace{frac{text{d}J}{text{d}theta}}_{mathbb{R}^{1times P}} = underbrace{frac{partial J} {partial theta}}_{mathbb{R}^{1times P}} + underbrace{frac{partial J}{partial boldsymbol{x}}}_{mathbb{R}^{1times N}} times underbrace{frac{text{d}boldsymbol{x}}{text{d}theta}}_{mathbb{R}^{Ntimes P}}tag{1}, R1×P dθdJ​​​=R1×P ∂θ∂J​​​+R1×N ∂x∂J​​​×RN×P dθdx​​​,(1)
在每一个变量的下面都标上了各自的维度。因为 x boldsymbol{x} x 和 θ theta θ 都是一个向量,因此 d x / d θ text{d}boldsymbol{x}/text{d}theta dx/dθ 是一个雅可比矩阵,在这式子当中, d x / d θ text{d}boldsymbol{x}/text{d}theta dx/dθ 是最难求的。

我们对于线性系统 A x = b mathbf{A}boldsymbol{x}=boldsymbol{b} Ax=b 的两端,都对 θ theta θ 进行求导,可以得到:
d d θ ( A x ) = d d θ ( b ) frac{text{d}}{text{d}theta}(mathbf{A}boldsymbol{x}) = frac{text{d}}{text{d}theta}(boldsymbol{b}) dθd​(Ax)=dθd​(b)

d A d θ x + A d x d θ ⏟ target = d b d θ frac{text{d} mathbf{A}}{text{d}theta}boldsymbol{x}+mathbf{A} underbrace{frac{text{d}boldsymbol{x}}{text{d}theta}}_{text{target}} = frac{text{d}boldsymbol{b}}{text{d}theta} dθdA​x+Atarget dθdx​​​=dθdb​

我们的目标是求出 d x / d θ {text{d}boldsymbol{x}}/{text{d}theta} dx/dθ 这一项,对其进行简单的变换:
A d x d θ = d b d θ − d A d θ x , (移项) mathbf{A}frac{text{d}boldsymbol{x}}{text{d}theta} = frac{text{d}boldsymbol{b}}{text{d}theta}-frac{text{d}mathbf{A}}{text{d}theta}boldsymbol{x},quadtext{(移项)} Adθdx​=dθdb​−dθdA​x,(移项)
方程两边同时左乘 A mathbf{A} A 的逆,得到:
d x d θ ⏟ R N × P = A − 1 ⏟ R N × N ( d b d θ ⏟ R N × P − d A d θ ⏟ R N × N × P x ⏟ R N ) , (2) underbrace{frac{text{d}boldsymbol{x}}{text{d}theta}}_{mathbb{R}^{Ntimes P}} = underbrace{mathbf{A}^{-1}}_{mathbb{R}^{Ntimes N}} left( underbrace{frac{text{d}boldsymbol{b}}{text{d}theta}}_{mathbb{R}^{Ntimes P}} - underbrace{frac{text{d}mathbf{A}}{text{d}theta}}_{mathbb{R}^{Ntimes Ntimes P}} underbrace{boldsymbol{x}}_{mathbb{R}^{N}} right)tag{2}, RN×P dθdx​​​=RN×N A−1​​⎝⎜⎛​RN×P dθdb​​​−RN×N×P dθdA​​​RN x​​⎠⎟⎞​,(2)
同样的,我们在变量下面标上对应的维度。要注意的是,这里 d A / d θ text{d}mathbf{A}/text{d}theta dA/dθ 和 x boldsymbol{x} x 的维度是不匹配的,但是我们不拘泥于这里,我们关注的点在于如果要通过最直接的方式去求解 d x / d θ {text{d}boldsymbol{x}}/{text{d}theta} dx/dθ 所需要的时间是有多大。这里只需要记住,无论如何,括号里面最终得到的矩阵维度为 N × P Ntimes P N×P 的大小。同时也不用去过度的关注矩阵 A mathbf{A} A 要如何求逆(因为这里是一个神经网络的输出,所以求逆会使得问题变得更为复杂),因为在后面会发现其实没有必要对 A mathbf{A} A 求逆。

将式子 (2) 与线性方程 A x = b mathbf{A}boldsymbol{x}=boldsymbol{b} Ax=b 进行对比可以发现,其实这就是由 P P P 个线性方程组成的更大的线性方程。求解一个线性方程可以用 LU 分解或者 QR 分解,它们的时间复杂度为 O ( N 3 ) mathcal{O}(N^3) O(N3),时间花费太过于大,对于神经网络来说,参数一多基本无法求解。因此,我们要使用另外一种更为高效的方法 —— 伴随方法,来求解这个问题。

伴随方法(Adjoint Method)

我们观察 (1) 式子以及 (2) 式,会发现实际上 (1) 式的最后一项就是我们想要求的「目标」,那么我们可以将 (2) 代入到 (1) 式中,得到 (3) 式:
d J d θ ⏟ R 1 × P = ∂ J ∂ θ + ∂ J ∂ x ⏟ R 1 × N A − 1 ( d b d θ − d A d θ x ) ⏟ R N × P , (3) underbrace{frac{text{d}J}{text{d}theta}}_{mathbb{R}^{1times P}} = frac{partial J}{partial theta} + underbrace{frac{partial J}{partial boldsymbol{x}}}_{mathbb{R}^{1times N}} underbrace{mathbf{A}^{-1}left( frac{text{d}boldsymbol{b}}{text{d}theta} - frac{text{d}mathbf{A}}{text{d}theta}boldsymbol{x}right)}_{mathbb{R}^{Ntimes P}}tag{3}, R1×P dθdJ​​​=∂θ∂J​+R1×N ∂x∂J​​​RN×P A−1(dθdb​−dθdA​x)​​,(3)
我们发现最后括号里面的那一整块维度是 N × P Ntimes P N×P 的,而我们最终需要的只是一个 1 × P 1times P 1×P 的向量,这说明,实际上我们不需要额外求解 P P P 个线性方程,而只需要额外求解 1 个线性方程就能行了。

我们重新把 (3) 式分块来看:
d J d θ = ∂ J ∂ θ + ( ∂ J ∂ x A − 1 ) ⏟ λ ⊤ ( d b d θ − d A d θ x ) , (4) frac{text{d}J}{text{d}theta} = frac{partial J}{partial theta} + underbrace{left( frac{partial J}{partial boldsymbol{x}} mathbf{A}^{-1}right)}_{lambda^top} left( frac{text{d}boldsymbol{b}}{text{d}theta} - frac{text{d}mathbf{A}}{text{d}theta}boldsymbol{x} right)tag{4}, dθdJ​=∂θ∂J​+λ⊤ (∂x∂J​A−1)​​(dθdb​−dθdA​x),(4)
我们令 λ ⊤ = ∂ J ∂ x A − 1 lambda^top = frac{partial J}{partial boldsymbol{x}} mathbf{A}^{-1} λ⊤=∂x∂J​A−1,称 λ ∈ R N lambdainmathbb{R}^N λ∈RN 为伴随变量(adjoint variable),然后对这个方程进行如下变换:
λ ⊤ A = ∂ J ∂ x , (两边右乘  A ) lambda^top mathbf{A} = frac{partial J}{partial boldsymbol{x}},quadtext{(两边右乘 $mathbf{A}$)} λ⊤A=∂x∂J​,(两边右乘 A)

( λ ⊤ A ) ⊤ = ( ∂ J ∂ x ) ⊤ , (两边进行转置) left( lambda^top mathbf{A} right)^top = left( frac{partial J}{partial boldsymbol{x}} right)^top,quadtext{(两边进行转置)} (λ⊤A)⊤=(∂x∂J​)⊤,(两边进行转置)
最后我们得到 (5) 式:
A ⊤ ⏟ R N × N λ ⏟ R N = ( ∂ J ∂ x ) ⊤ ⏟ R N (5) underbrace{mathbf{A}^top}_{mathbb{R}^{Ntimes N}} underbrace{lambda}_{mathbb{R}^{N}} = underbrace{left( frac{partial J}{partial boldsymbol{x}} right)^top}_{mathbb{R}^{N}}tag{5} RN×N A⊤​​RN λ​​=RN (∂x∂J​)⊤​​(5)
观察 (5) 式不难发现,这其实与 A x = b mathbf{A}boldsymbol{x}=boldsymbol{b} Ax=b 的形式是完全一样的,而且我们不用计算矩阵 A mathbf{A} A 的逆,而是直接用它的转置,关于 ∂ J ∂ x frac{partial J}{partial boldsymbol{x}} ∂x∂J​ 这一项,利用自动微分可以很简单地计算出来。

这种求解方法就很好地规避了求逆,并且使得问题的维度大大地减小了。对于伴随方法,可以通过以下三步来计算:

第一步:前向求解 A x = b mathbf{A}boldsymbol{x}=boldsymbol{b} Ax=b,得到 x boldsymbol{x} x 的解;

第二步:后向求解伴随方程 A ⊤ λ = ( ∂ J ∂ x ) ⊤ mathbf{A}^top lambda = left( frac{partial J}{partial boldsymbol{x}} right)^top A⊤λ=(∂x∂J​)⊤,得到伴随变量 λ lambda λ;

第三步:代回原式:
d J d θ = ∂ J ∂ θ ⏟ may be zero in many problems + λ ⊤ ( d b d θ − d A d θ x ) frac{text{d}J}{text{d}theta} = underbrace{frac{partial J}{partial theta}}_{text{may be zero in many problems}} + lambda^top left( frac{text{d}boldsymbol{b}}{text{d}theta} - frac{text{d}mathbf{A}}{text{d}theta} boldsymbol{x} right) dθdJ​=may be zero in many problems ∂θ∂J​​​+λ⊤(dθdb​−dθdA​x)
利用这样的伴随方法,只需要求解两个线性系统就可以得到 d J d θ frac{text{d}J}{text{d}theta} dθdJ​。而对于 ∂ J ∂ x , ∂ J ∂ θ , d b d θ , d A d θ frac{partial J}{partial boldsymbol{x}}, frac{partial J}{partialtheta}, frac{text{d}boldsymbol{b}}{text{d}theta}, frac{text{d}mathbf{A}}{text{d}theta} ∂x∂J​,∂θ∂J​,dθdb​,dθdA​,这几个矩阵利用自动微分可以更为简单地求得。


参考:

[1] Machine Learning & Simulation. Adjoint Equation of a Linear System of Equations - by implicit derivative. YouTube

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/886663.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号