Note: this post focuses on the backpropagation formulas for an RNN; the basic principles, history, and so on are already well covered in other articles.
**Forward propagation**

$$net^t = Ux^t + Wh^{t-1} + b$$

$$h^t = \tanh(net^t)$$

$$o^t = Vh^t + c$$
**Code**

```python
net = np.dot(self.Wx, xt) + np.dot(self.Wh, h_prev) + self.Bh  # net^t = U x^t + W h^{t-1} + b
h = np.tanh(net)                                               # h^t = tanh(net^t)
o = np.dot(self.Wy, h) + self.By                               # o^t = V h^t + c
y_hat = softmax(o)                                             # output distribution
```
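Here `self.Wx`, `self.Wh`, `self.Wy`, `self.Bh`, `self.By` play the roles of U, W, V, b, c. For reference, a minimal self-contained version of one forward step; the sizes, the `softmax` helper, and the random initialization are illustrative assumptions, not the demo's actual setup:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))   # shift by the max for numerical stability
    return e / e.sum()

rng = np.random.default_rng(0)
in_dim, hid_dim, out_dim = 3, 4, 2                 # assumed toy sizes
U = rng.standard_normal((hid_dim, in_dim))
W = rng.standard_normal((hid_dim, hid_dim))
V = rng.standard_normal((out_dim, hid_dim))
b, c = np.zeros((hid_dim, 1)), np.zeros((out_dim, 1))

xt = rng.standard_normal((in_dim, 1))              # x^t
h_prev = np.zeros((hid_dim, 1))                    # h^{t-1}

net = np.dot(U, xt) + np.dot(W, h_prev) + b        # net^t
h = np.tanh(net)                                   # h^t
o = np.dot(V, h) + c                               # o^t
y_hat = softmax(o)                                 # predicted distribution, sums to 1
```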
**Backpropagation**

Let L denote the loss.
By the chain rule:
$$\frac{\partial L}{\partial W} = \frac{\partial L}{\partial O}\frac{\partial O}{\partial h^t}\frac{\partial h^t}{\partial W}$$
The first two factors can be differentiated directly; the main trouble is the last one:
$$\frac{\partial h^t}{\partial W}$$
Assume t = 3. Two things to note: h^t depends on W both directly (through the Wh^{t-1} term) and indirectly (because h^{t-1} itself depends on W), so the derivative recurses; and the initial state h^0 does not depend on W, so the recursion bottoms out there, with only h^0 remaining in the final term.
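Stating that recursion once in general form (this is just the product rule applied to h^t = tanh(Ux^t + Wh^{t-1} + b), in the same scalar-style notation as the rest of the derivation):

$$\frac{\partial h^t}{\partial W} = \frac{\partial h^t}{\partial net^t}\left(h^{t-1} + W\,\frac{\partial h^{t-1}}{\partial W}\right)$$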
**Derivation**

$$\frac{\partial h^3}{\partial W} = \frac{\partial h^3}{\partial net^3}\left(h^2 + W\frac{\partial h^2}{\partial W}\right)$$

$$= \frac{\partial h^3}{\partial net^3}\left(h^2 + W\frac{\partial h^2}{\partial net^2}\left(h^1 + W\frac{\partial h^1}{\partial W}\right)\right)$$

$$= \frac{\partial h^3}{\partial net^3}\left(h^2 + W\frac{\partial h^2}{\partial net^2}\left(h^1 + W\frac{\partial h^1}{\partial net^1}\,h^0\right)\right)$$

$$= \frac{\partial h^3}{\partial net^3}h^2W^0 + \frac{\partial h^3}{\partial net^3}\frac{\partial h^2}{\partial net^2}h^1W^1 + \frac{\partial h^3}{\partial net^3}\frac{\partial h^2}{\partial net^2}\frac{\partial h^1}{\partial net^1}h^0W^2$$
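This expansion is easy to get wrong, so it is worth checking symbolically. A quick sympy sketch in the scalar case (the scalar setup is an assumption made only to verify the algebra):

```python
import sympy as sp

U, W, b, x1, x2, x3, h0 = sp.symbols('U W b x1 x2 x3 h0')

# Unrolled scalar RNN for t = 3: h^t = tanh(U*x^t + W*h^{t-1} + b)
h1 = sp.tanh(U * x1 + W * h0 + b)
h2 = sp.tanh(U * x2 + W * h1 + b)
h3 = sp.tanh(U * x3 + W * h2 + b)

# Left-hand side: differentiate directly
direct = sp.diff(h3, W)

# Right-hand side: the three-term expansion; d tanh(net)/d net = 1 - h^2
d3, d2, d1 = 1 - h3**2, 1 - h2**2, 1 - h1**2
expanded = d3 * h2 + d3 * d2 * h1 * W + d3 * d2 * d1 * h0 * W**2

print(sp.simplify(direct - expanded))   # prints 0
```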
Generalizing from this special case:
$$\frac{\partial L}{\partial W} = \sum_{k=1}^{t}\frac{\partial L}{\partial O}\frac{\partial O}{\partial h^t}\left(\prod_{j=k}^{t}\frac{\partial h^j}{\partial net^j}\right)W^{t-k}h^{k-1}$$
Similarly:
$$\frac{\partial L}{\partial U} = \sum_{k=1}^{t}\frac{\partial L}{\partial O}\frac{\partial O}{\partial h^t}\left(\prod_{j=k}^{t}\frac{\partial h^j}{\partial net^j}\right)W^{t-k}x^{k}$$

$$\frac{\partial L}{\partial b} = \sum_{k=1}^{t}\frac{\partial L}{\partial O}\frac{\partial O}{\partial h^t}\left(\prod_{j=k}^{t}\frac{\partial h^j}{\partial net^j}\right)W^{t-k}$$
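These summed formulas are easy to sanity-check numerically on a scalar RNN before trusting them in code. A minimal sketch of such a check for the W gradient; the toy parameter values, the sequence length T = 4, and the squared-error loss on the final output are all assumptions made purely for illustration:

```python
import numpy as np

np.random.seed(0)
T = 4
U, W, b, V, c = 0.3, 0.5, 0.1, 0.8, -0.2   # toy scalar parameters (assumptions)
x = np.random.randn(T)
h0 = 0.0

def forward(W_):
    hs = [h0]                               # hs[k] = h^k, hs[0] is the initial state
    for t in range(T):
        hs.append(np.tanh(U * x[t] + W_ * hs[-1] + b))
    o = V * hs[-1] + c
    return 0.5 * (o - 1.0) ** 2, hs, o      # squared-error loss on the last output

loss, hs, o = forward(W)

# dL/dW = sum_k dL/dO * dO/dh^t * (prod_{j=k}^{t} dh^j/dnet^j) * W^{t-k} * h^{k-1}
dL_dO, dO_dh = o - 1.0, V
grad = sum(
    dL_dO * dO_dh
    * np.prod([1 - hs[j] ** 2 for j in range(k, T + 1)])   # product of tanh derivatives
    * W ** (T - k) * hs[k - 1]
    for k in range(1, T + 1)
)

# Compare against a centered finite difference
eps = 1e-6
numeric = (forward(W + eps)[0] - forward(W - eps)[0]) / (2 * eps)
print(grad, numeric)   # the two numbers should agree to ~1e-9
```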
V and c are easy to handle because they do not involve the time recursion:
$$\frac{\partial L}{\partial V} = \sum_{k=1}^{t}\frac{\partial L}{\partial O}\,(h^k)^T$$

$$\frac{\partial L}{\partial c} = \sum_{k=1}^{t}\frac{\partial L}{\partial O}$$
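One detail worth making explicit before the code: it takes ∂L/∂O = ŷ − y, which is the standard gradient of cross-entropy loss through a softmax output,

$$\frac{\partial L}{\partial o^t} = \hat{y}^t - y^t,$$

and that is why the term `y_hat[t] - y[t]` keeps appearing below.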
**Code**

```python
from functools import reduce

for t in range(T):
    # I derived the formula myself and found the pattern: the delta term
    # shows up in the W, U and b gradients alike, so compute it once first.
    # D_tanh: product of tanh derivatives, prod_j (1 - (h^j)^2)
    D_tanh = reduce(lambda acc, j: acc * (1 - np.square(h[j])), range(T - 1, t - 1, -1), 1)
    # W^{T-t-1} is a matrix power of the square hidden-to-hidden weight
    delta = np.linalg.matrix_power(self.Wh, T - t - 1).dot(D_tanh) * np.dot(self.Wy.T, y_hat[t] - y[t])
    # W: pairs with h^{t-1} per the formula (assumes h[-1] holds the initial state)
    W_gradient += np.dot(delta, h[t - 1].T)
    # U
    U_gradient += np.dot(delta, xt[t].T)
    # b
    B_gradient += delta
    # V
    V_gradient += np.dot(y_hat[t] - y[t], h[t].T)
    # c
    C_gradient += y_hat[t] - y[t]
```
The delta term is exactly:
$$\sum_{k=1}^{t}\frac{\partial L}{\partial O}\frac{\partial O}{\partial h^t}\left(\prod_{j=k}^{t}\frac{\partial h^j}{\partial net^j}\right)W^{t-k}$$
When writing the code, pay close attention to the structure, to whether any variable is missing, and to the array shapes; when the code misbehaves, most of the time it is some small detail like these that slipped through~
My rnn_demo is attached here: rnn