numpy-rnn: from formula derivation to code implementation

Note: this post focuses on the backpropagation formulas. RNN principles, history and so on are already covered well in other articles.

Forward propagation formulas

$$\mathrm{net}^t = Ux^t + Wh^{t-1} + b$$

$$h^t = \tanh(\mathrm{net}^t)$$

$$o^t = Vh^t + c$$

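Code

In the code, `self.Wx`, `self.Wh`, `self.Wy`, `self.Bh` and `self.By` correspond to U, W, V, b and c in the formulas.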
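```python
import numpy as np

# One forward step; softmax is assumed to be defined elsewhere in the author's code
net = np.dot(self.Wx, xt) + np.dot(self.Wh, h_prev) + self.Bh  # net^t = U x^t + W h^{t-1} + b
h = np.tanh(net)                                                # h^t = tanh(net^t)
o = np.dot(self.Wy, h) + self.By                                # o^t = V h^t + c
y_hat = softmax(o)
```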
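A self-contained sketch of the same forward step run over a whole sequence, so the shapes used later in backpropagation are concrete. It assumes column-vector inputs, a zero initial state $h^0$ and a numerically stable softmax; the names `softmax`, `rnn_forward` and the parameter dictionary layout are illustrative, not taken from the author's demo.

```python
import numpy as np

def softmax(z):
    """Column-wise softmax; subtracting the max avoids overflow in exp."""
    e = np.exp(z - np.max(z, axis=0, keepdims=True))
    return e / np.sum(e, axis=0, keepdims=True)

def rnn_forward(params, xs, h0):
    """Apply net^t = U x^t + W h^{t-1} + b, h^t = tanh(net^t), o^t = V h^t + c
    to a list of column vectors xs. Returns hidden states and softmax outputs."""
    U, W, V, b, c = (params[k] for k in ("U", "W", "V", "b", "c"))
    h_prev = h0
    hs, y_hats = [], []
    for x in xs:
        net = U @ x + W @ h_prev + b
        h = np.tanh(net)
        o = V @ h + c
        hs.append(h)
        y_hats.append(softmax(o))
        h_prev = h
    return hs, y_hats

rng = np.random.default_rng(0)
input_dim, hidden_dim, output_dim, T = 4, 5, 3, 6
params = {
    "U": rng.normal(scale=0.1, size=(hidden_dim, input_dim)),
    "W": rng.normal(scale=0.1, size=(hidden_dim, hidden_dim)),
    "V": rng.normal(scale=0.1, size=(output_dim, hidden_dim)),
    "b": np.zeros((hidden_dim, 1)),
    "c": np.zeros((output_dim, 1)),
}
xs = [rng.normal(size=(input_dim, 1)) for _ in range(T)]
hs, y_hats = rnn_forward(params, xs, np.zeros((hidden_dim, 1)))
print(len(hs), hs[0].shape, y_hats[0].shape)  # 6 (5, 1) (3, 1)
```

Keeping every vector as an explicit `(n, 1)` column makes the outer products in the gradients (`delta @ h.T`) unambiguous.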
Backpropagation

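L denotes the loss. Assuming a softmax output with a cross-entropy loss (which is what the `y_hat[t] - y[t]` term in the code below corresponds to), $\frac{\partial L}{\partial o^t} = \hat{y}^t - y^t$.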

By the chain rule:

$$\frac{\partial L}{\partial W} = \frac{\partial L}{\partial o^t}\frac{\partial o^t}{\partial h^t}\frac{\partial h^t}{\partial W}$$

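The first two factors can be differentiated directly. The last one is the tricky part, because $h^t$ depends on W both directly and through $h^{t-1}$, which itself depends on W: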

$$\frac{\partial h^t}{\partial W}$$

Take $t = 3$ as an example.

$h^0$ is the initial state and does not depend on W, so the recursion stops there and only $h^0$ remains.

Derivation

$$
\begin{aligned}
\frac{\partial h^3}{\partial W} &= \frac{\partial h^3}{\partial \mathrm{net}^3}\left(h^2 + W\frac{\partial h^2}{\partial W}\right) \\
&= \frac{\partial h^3}{\partial \mathrm{net}^3}\left(h^2 + W\frac{\partial h^2}{\partial \mathrm{net}^2}\left(h^1 + W\frac{\partial h^1}{\partial W}\right)\right) \\
&= \frac{\partial h^3}{\partial \mathrm{net}^3}\left(h^2 + W\frac{\partial h^2}{\partial \mathrm{net}^2}\left(h^1 + W\frac{\partial h^1}{\partial \mathrm{net}^1}h^0\right)\right) \\
&= \frac{\partial h^3}{\partial \mathrm{net}^3}h^2W^0 \\
&\quad + \frac{\partial h^3}{\partial \mathrm{net}^3}\frac{\partial h^2}{\partial \mathrm{net}^2}h^1W^1 \\
&\quad + \frac{\partial h^3}{\partial \mathrm{net}^3}\frac{\partial h^2}{\partial \mathrm{net}^2}\frac{\partial h^1}{\partial \mathrm{net}^1}h^0W^2
\end{aligned}
$$
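A quick way to convince yourself the three-term expansion is right is a scalar check (hidden size 1, so the factors commute as the derivation assumes). All numbers below are arbitrary assumptions; the expansion is compared with a central finite difference of $h^3$ with respect to W, using $\partial h^t/\partial \mathrm{net}^t = 1 - (h^t)^2$ for tanh.

```python
import numpy as np

def scalar_states(U, W, b, xs, h0):
    """Scalar RNN: returns [h^0, h^1, ..., h^T] with h^t = tanh(U x^t + W h^{t-1} + b)."""
    hs = [h0]
    for x in xs:
        hs.append(np.tanh(U * x + W * hs[-1] + b))
    return hs

U, W, b, h0 = 0.7, -0.4, 0.1, 0.3
xs = [0.5, -1.2, 0.8]                                   # x^1, x^2, x^3
h = scalar_states(U, W, b, xs, h0)
d = [None] + [1 - h[t] ** 2 for t in range(1, 4)]       # d[t] = dh^t/dnet^t

# The three terms of the expansion for t = 3
expansion = (d[3] * h[2] * W ** 0
             + d[3] * d[2] * h[1] * W ** 1
             + d[3] * d[2] * d[1] * h[0] * W ** 2)

eps = 1e-6
numeric = (scalar_states(U, W + eps, b, xs, h0)[3]
           - scalar_states(U, W - eps, b, xs, h0)[3]) / (2 * eps)
print(abs(expansion - numeric) < 1e-8)  # True
```

A nonzero $h^0$ is used on purpose: it is a constant, so it still contributes through the last term even though it does not depend on W.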

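Generalizing from this special case (the derivation treats every quantity as a scalar; in matrix form the order of the factors matters):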

$$\frac{\partial L}{\partial W} = \sum_{k=1}^{t}\frac{\partial L}{\partial o^t}\frac{\partial o^t}{\partial h^t}\left(\prod_{j=k}^{t}\frac{\partial h^j}{\partial \mathrm{net}^j}\right)W^{t-k}h^{k-1}$$

Similarly:

$$\frac{\partial L}{\partial U} = \sum_{k=1}^{t}\frac{\partial L}{\partial o^t}\frac{\partial o^t}{\partial h^t}\left(\prod_{j=k}^{t}\frac{\partial h^j}{\partial \mathrm{net}^j}\right)W^{t-k}x^{k}$$

$$\frac{\partial L}{\partial b} = \sum_{k=1}^{t}\frac{\partial L}{\partial o^t}\frac{\partial o^t}{\partial h^t}\left(\prod_{j=k}^{t}\frac{\partial h^j}{\partial \mathrm{net}^j}\right)W^{t-k}$$
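The three sums can be checked the same way for an arbitrary t. This scalar sketch assumes a squared-error loss $L = \tfrac12(o^t - \text{target})^2$ on the last step only (so $\partial L/\partial o^t = o^t - \text{target}$ and $\partial o^t/\partial h^t = V$); the loss, the numbers and the helper names `loss_at_t`, `formula_grads` and `numeric` are illustrative assumptions.

```python
import numpy as np

def loss_at_t(U, W, b, V, c, xs, h0, target):
    """Scalar RNN run up to t = len(xs); squared error on o^t only."""
    h = h0
    for x in xs:
        h = np.tanh(U * x + W * h + b)
    o = V * h + c
    return 0.5 * (o - target) ** 2

def formula_grads(U, W, b, V, c, xs, h0, target):
    """dL/dW, dL/dU, dL/db from the sum-over-k formulas (k is 1-based, t = len(xs))."""
    t = len(xs)
    hs = [h0]
    for x in xs:
        hs.append(np.tanh(U * x + W * hs[-1] + b))
    dL_do, do_dh = V * hs[t] + c - target, V
    gW = gU = gb = 0.0
    for k in range(1, t + 1):
        prod = np.prod([1 - hs[j] ** 2 for j in range(k, t + 1)])  # prod_{j=k}^t dh^j/dnet^j
        common = dL_do * do_dh * prod * W ** (t - k)
        gW += common * hs[k - 1]      # times h^{k-1}
        gU += common * xs[k - 1]      # times x^k (xs[k-1] holds x^k)
        gb += common                  # times 1
    return gW, gU, gb

args = dict(U=0.6, W=-0.5, b=0.2, V=1.3, c=-0.1)
xs, h0, target = [0.4, -0.9, 1.1, 0.2], 0.0, 0.5

def numeric(name, eps=1e-6):
    """Central finite difference of the loss with respect to one parameter."""
    plus, minus = dict(args), dict(args)
    plus[name] += eps
    minus[name] -= eps
    return (loss_at_t(**plus, xs=xs, h0=h0, target=target)
            - loss_at_t(**minus, xs=xs, h0=h0, target=target)) / (2 * eps)

gW, gU, gb = formula_grads(**args, xs=xs, h0=h0, target=target)
errors = {name: abs(g - numeric(name)) for name, g in (("W", gW), ("U", gU), ("b", gb))}
print(errors)
```

With $h^0 = 0$ the $k = 1$ term of $\partial L/\partial W$ vanishes, which matches the observation above that the recursion ends at $h^0$.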

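V and c are easy, because their gradients do not flow back through time; we just sum the per-step terms over all T time steps: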

$$\frac{\partial L}{\partial V} = \sum_{t=1}^{T}\frac{\partial L}{\partial o^t}\left(h^t\right)^{\top}$$

$$\frac{\partial L}{\partial c} = \sum_{t=1}^{T}\frac{\partial L}{\partial o^t}$$
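Two small sanity checks for this step, assuming a softmax output with cross-entropy loss (the loss the code's `y_hat[t] - y[t]` term implies): first that $\partial L/\partial o^t = \hat{y}^t - y^t$, then the accumulation of $\partial L/\partial V$ and $\partial L/\partial c$ over time steps. The shapes and random data are made up for illustration.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / np.sum(e)

def cross_entropy(o, y):
    """L = -sum(y * log softmax(o)) for a one-hot column vector y."""
    return -float(np.sum(y * np.log(softmax(o))))

# 1) dL/do = y_hat - y, checked with central differences
o = np.array([[1.0], [-0.5], [2.0]])
y = np.array([[0.0], [1.0], [0.0]])
analytic = softmax(o) - y
eps = 1e-6
numeric = np.zeros_like(o)
for i in range(o.shape[0]):
    step = np.zeros_like(o)
    step[i] = eps
    numeric[i] = (cross_entropy(o + step, y) - cross_entropy(o - step, y)) / (2 * eps)
print(np.allclose(analytic, numeric, atol=1e-8))  # True

# 2) Accumulate dL/dV and dL/dc over T = 5 steps
rng = np.random.default_rng(1)
hs = [np.tanh(rng.normal(size=(4, 1))) for _ in range(5)]
ys = [np.eye(3)[:, [i % 3]] for i in range(5)]   # one-hot targets as (3, 1) columns
V, c = rng.normal(size=(3, 4)), np.zeros((3, 1))
dV, dc = np.zeros_like(V), np.zeros_like(c)
for h, yt in zip(hs, ys):
    dy = softmax(V @ h + c) - yt   # dL/do^t
    dV += dy @ h.T                 # (3, 1) @ (1, 4) -> (3, 4)
    dc += dy
```

Because both $\hat{y}^t$ and the one-hot $y^t$ sum to 1, the entries of each $\hat{y}^t - y^t$ sum to zero, so the entries of the accumulated `dc` do too.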

Code
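```python
import numpy as np

# Accumulate gradients over all output steps t (0-based: h[t] is the hidden state after
# step t, y_hat[t] / y[t] the prediction and one-hot target, xt[t] the input; all column
# vectors). The *_gradient arrays are assumed to start at zero.
for t in range(T):
    dy = y_hat[t] - y[t]  # dL/do^t for softmax + cross-entropy

    # V and c: no dependence on earlier time steps
    V_gradient += np.dot(dy, h[t].T)
    C_gradient += dy

    # delta appears in the W, U and b formulas, so compute it first: delta = dL^t/dnet^t
    delta = (1 - np.square(h[t])) * np.dot(self.Wy.T, dy)
    for k in range(t, -1, -1):
        # previous hidden state; the initial state is assumed to be zero
        h_prev = h[k - 1] if k > 0 else np.zeros_like(h[0])
        W_gradient += np.dot(delta, h_prev.T)   # outer product, same shape as Wh
        U_gradient += np.dot(delta, xt[k].T)    # same shape as Wx
        B_gradient += delta
        if k > 0:
            # step back one time step: dL^t/dnet^{k-1}
            delta = (1 - np.square(h[k - 1])) * np.dot(self.Wh.T, delta)
```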
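A self-contained sketch that puts the pieces together in matrix form and checks every gradient against central finite differences. It assumes a loss summed over all time steps (cross-entropy on a softmax at each step), a zero initial state and column-vector inputs; `forward`, `loss` and `bptt` are illustrative names, not the author's rnn_demo. For each output step t it walks back through k = t, ..., 1 using $\delta^{k-1} = (1 - h^{k-1}\odot h^{k-1})\odot(W^{\top}\delta^k)$, the matrix counterpart of the product of tanh derivatives and powers of W. This costs O(T²); the usual BPTT merges all output steps into a single O(T) backward pass.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / np.sum(e)

def forward(p, xs, h0):
    """Return hidden states [h^0, ..., h^T] and softmax outputs [y_hat^1, ..., y_hat^T]."""
    hs, y_hats = [h0], []
    for x in xs:
        hs.append(np.tanh(p["U"] @ x + p["W"] @ hs[-1] + p["b"]))
        y_hats.append(softmax(p["V"] @ hs[-1] + p["c"]))
    return hs, y_hats

def loss(p, xs, ys, h0):
    """Total cross-entropy summed over time steps (ys are one-hot columns)."""
    _, y_hats = forward(p, xs, h0)
    return -sum(float(np.sum(y * np.log(yh))) for y, yh in zip(ys, y_hats))

def bptt(p, xs, ys, h0):
    hs, y_hats = forward(p, xs, h0)
    g = {name: np.zeros_like(value) for name, value in p.items()}
    for t in range(1, len(xs) + 1):                   # 1-based steps; hs[t] is h^t
        dy = y_hats[t - 1] - ys[t - 1]                # dL/do^t
        g["V"] += dy @ hs[t].T
        g["c"] += dy
        delta = (1 - hs[t] ** 2) * (p["V"].T @ dy)    # dL^t/dnet^t
        for k in range(t, 0, -1):
            g["W"] += delta @ hs[k - 1].T             # delta^k (h^{k-1})^T
            g["U"] += delta @ xs[k - 1].T             # delta^k (x^k)^T
            g["b"] += delta
            if k > 1:
                delta = (1 - hs[k - 1] ** 2) * (p["W"].T @ delta)  # dL^t/dnet^{k-1}
    return g

rng = np.random.default_rng(0)
n_in, n_h, n_out, T = 3, 4, 3, 5
p = {"U": rng.normal(scale=0.5, size=(n_h, n_in)),
     "W": rng.normal(scale=0.5, size=(n_h, n_h)),
     "V": rng.normal(scale=0.5, size=(n_out, n_h)),
     "b": rng.normal(scale=0.1, size=(n_h, 1)),
     "c": rng.normal(scale=0.1, size=(n_out, 1))}
xs = [rng.normal(size=(n_in, 1)) for _ in range(T)]
ys = [np.eye(n_out)[:, [rng.integers(n_out)]] for _ in range(T)]
h0 = np.zeros((n_h, 1))

grads = bptt(p, xs, ys, h0)
eps = 1e-6
max_err = 0.0
for name, param in p.items():
    for idx in np.ndindex(param.shape):
        old = param[idx]
        param[idx] = old + eps
        plus = loss(p, xs, ys, h0)
        param[idx] = old - eps
        minus = loss(p, xs, ys, h0)
        param[idx] = old                              # restore the parameter
        max_err = max(max_err, abs((plus - minus) / (2 * eps) - grads[name][idx]))
print(max_err < 1e-6)  # True
```

A finite-difference check like this is the quickest way to catch the shape and indexing slips mentioned below.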

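The delta term is the factor shared by the W, U and b formulas, i.e. the summand for a given k. It equals $\frac{\partial L}{\partial \mathrm{net}^k}$:

$$\delta^k = \frac{\partial L}{\partial o^t}\frac{\partial o^t}{\partial h^t}\left(\prod_{j=k}^{t}\frac{\partial h^j}{\partial \mathrm{net}^j}\right)W^{t-k}$$

so that $\frac{\partial L}{\partial W} = \sum_{k=1}^{t}\delta^k h^{k-1}$, $\frac{\partial L}{\partial U} = \sum_{k=1}^{t}\delta^k x^k$ and $\frac{\partial L}{\partial b} = \sum_{k=1}^{t}\delta^k$.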
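The product above is written as if every quantity were a scalar. A sketch of the matrix form for the loss at step t, assuming column vectors, elementwise product $\odot$, $\partial o^t/\partial h^t = V$, and writing $\delta^k$ for $\partial L/\partial \mathrm{net}^k$: the tanh derivatives and the factors of W interleave instead of separating into a product times $W^{t-k}$, and each matrix appears transposed.

$$
\begin{aligned}
\delta^{t} &= \left(\mathbf{1} - h^{t}\odot h^{t}\right)\odot\left(V^{\top}\frac{\partial L}{\partial o^{t}}\right),\\
\delta^{k-1} &= \left(\mathbf{1} - h^{k-1}\odot h^{k-1}\right)\odot\left(W^{\top}\delta^{k}\right),\qquad k = t, \dots, 2,\\
\frac{\partial L}{\partial W} &= \sum_{k=1}^{t}\delta^{k}\left(h^{k-1}\right)^{\top},\qquad
\frac{\partial L}{\partial U} = \sum_{k=1}^{t}\delta^{k}\left(x^{k}\right)^{\top},\qquad
\frac{\partial L}{\partial b} = \sum_{k=1}^{t}\delta^{k}.
\end{aligned}
$$

In the scalar case the transposes and $\odot$ disappear and this reduces to the formulas above.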

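When writing the code, pay close attention to its structure, to missing variables and to array shapes. When the code misbehaves, the cause is usually a small oversight.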
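One concrete shape pitfall, as a small made-up example: with `(n, 1)` column vectors, `delta * h_prev` is an elementwise product that stays `(n, 1)`, and adding it to an `(n, n)` gradient silently broadcasts it across the columns, whereas a weight gradient needs the outer product `delta @ h_prev.T`. A shape assertion (the helper `check_grad_shapes` is hypothetical) catches some mistakes but not this one, which is why a numerical gradient check is still worth running.

```python
import numpy as np

def check_grad_shapes(params, grads):
    """Hypothetical helper: every gradient must have its parameter's shape."""
    for name, value in params.items():
        assert grads[name].shape == value.shape, (name, grads[name].shape, value.shape)

n = 3
delta = np.array([[0.1], [-0.2], [0.3]])    # dL/dnet^k as a (3, 1) column
h_prev = np.array([[0.5], [0.4], [-0.6]])   # h^{k-1} as a (3, 1) column

wrong = np.zeros((n, n))
wrong += delta * h_prev                     # (3,1) * (3,1) -> (3,1), broadcast across columns
right = np.zeros((n, n))
right += delta @ h_prev.T                   # outer product: (3,1) @ (1,3) -> (3,3)

check_grad_shapes({"W": np.zeros((n, n))}, {"W": wrong})  # passes: shapes cannot see this bug
print(np.allclose(wrong, right))                          # False
print(np.allclose(right, np.outer(delta, h_prev)))        # True
```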

Here is my rnn_demo: rnn

Reposting notice: this article was reposted from www.mshxw.com
Article URL: https://www.mshxw.com/it/580626.html