栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

【手撕LSTM】LSTM的numpy实现

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【手撕LSTM】LSTM的numpy实现

文章目录
    • LSTM原理图
      • 便于程序实现的公式(简化版公式)
      • 关于“门”
        • 遗忘门
        • 输入门
        • 更新memory
        • 输出门
      • LSTM单元

详细理论部分参考我博文(2020李宏毅)机器学习-Recurrent Neural Network

LSTM原理图

公式
F t = σ ( X t W x f + H t − 1 W h f + b f ) mathbf{F}_{t}=sigmaleft(mathbf{X}_{t} mathbf{W}_{x f}+mathbf{H}_{t-1} mathbf{W}_{h f}+mathbf{b}_{f}right) Ft​=σ(Xt​Wxf​+Ht−1​Whf​+bf​)
I t = σ ( X t W x i + H t − 1 W h i + b i ) mathbf{I}_{t}=sigmaleft(mathbf{X}_{t} mathbf{W}_{x i}+mathbf{H}_{t-1} mathbf{W}_{h i}+mathbf{b}_{i}right) It​=σ(Xt​Wxi​+Ht−1​Whi​+bi​)
C ~ t = tanh ⁡ ( X t W x c + H t − 1 W h c + b c ) tilde{mathbf{C}}_{t}=tanh left(mathbf{X}_{t} mathbf{W}_{x c}+mathbf{H}_{t-1} mathbf{W}_{h c}+mathbf{b}_{c}right) C~t​=tanh(Xt​Wxc​+Ht−1​Whc​+bc​)
C t = F t ⊙ C t − 1 + I t ⊙ C ~ t mathbf{C}_{t}=mathbf{F}_{t} odot mathbf{C}_{t-1}+mathbf{I}_{t} odot tilde{mathbf{C}}_{t} Ct​=Ft​⊙Ct−1​+It​⊙C~t​
O t = σ ( X t W x o + H t − 1 W h o + b o ) mathbf{O}_{t}=sigmaleft(mathbf{X}_{t} mathbf{W}_{x o}+mathbf{H}_{t-1} mathbf{W}_{h o}+mathbf{b}_{o}right) Ot​=σ(Xt​Wxo​+Ht−1​Who​+bo​)
H t = O t ⊙ tanh ⁡ ( C t ) mathbf{H}_{t}=mathbf{O}_{t} odot tanh left(mathbf{C}_{t}right) Ht​=Ot​⊙tanh(Ct​)

便于程序实现的公式(简化版公式)

F t = σ ( W f [ H t − 1 , X t ] + b f ) mathbf{F}_{t}=sigmaleft(mathbf{W}_{f}[mathbf{H}_{t-1},mathbf{X}_{t}] + mathbf{b}_{f}right) Ft​=σ(Wf​[Ht−1​,Xt​]+bf​)
I t = σ ( W i [ H t − 1 , X t ] + b i ) mathbf{I}_{t}=sigmaleft(mathbf{W}_{i}[mathbf{H}_{t-1},mathbf{X}_{t}]+mathbf{b}_{i}right) It​=σ(Wi​[Ht−1​,Xt​]+bi​)
C ~ t = tanh ⁡ ( W c [ H t − 1 , X t ] + b c ) tilde{mathbf{C}}_{t}=tanh left(mathbf{W}_{c}[mathbf{H}_{t-1},mathbf{X}_{t}] +mathbf{b}_{c}right) C~t​=tanh(Wc​[Ht−1​,Xt​]+bc​)
C t = F t ⊙ C t − 1 + I t ⊙ C ~ t mathbf{C}_{t}=mathbf{F}_{t} odot mathbf{C}_{t-1}+mathbf{I}_{t} odot tilde{mathbf{C}}_{t} Ct​=Ft​⊙Ct−1​+It​⊙C~t​
O t = σ ( W o [ H t − 1 , X t ] + b o ) mathbf{O}_{t}=sigmaleft(mathbf{W}_{o}[mathbf{H}_{t-1},mathbf{X}_{t}]+mathbf{b}_{o}right) Ot​=σ(Wo​[Ht−1​,Xt​]+bo​)
H t = O t ⊙ tanh ⁡ ( C t ) mathbf{H}_{t}=mathbf{O}_{t} odot tanh left(mathbf{C}_{t}right) Ht​=Ot​⊙tanh(Ct​)

关于“门” 遗忘门

在LSTM中,遗忘门可以实现操作:
F t = σ ( W f [ H t − 1 , X t ] + b f ) mathbf{F}_{t}=sigmaleft(mathbf{W}_{f}[mathbf{H}_{t-1},mathbf{X}_{t}] + mathbf{b}_{f}right) Ft​=σ(Wf​[Ht−1​,Xt​]+bf​)
在这里, W f W_f Wf​是控制遗忘门行为的权重。将 [ H t − 1 , X t ] [mathbf{H}_{t-1},mathbf{X}_{t}] [Ht−1​,Xt​]连接起来,然后乘以 W f W_f Wf​。上面的等式使得向量 F t mathbf{F}_{t} Ft​的值介于0到1之间。该遗忘门向量将逐元素乘以先前的单元状态 C t − 1 mathbf{C}_{t-1} Ct−1​。因此,如果 F t mathbf{F}_{t} Ft​的其中一个值为0(或接近于0),则表示LSTM应该移除 C t − 1 mathbf{C}_{t-1} Ct−1​中的一部分信息,如果其中一个值为1,则它将保留信息。

输入门

输入门的公式:
I t = σ ( W i [ H t − 1 , X t ] + b i ) mathbf{I}_{t}=sigmaleft(mathbf{W}_{i}[mathbf{H}_{t-1},mathbf{X}_{t}]+mathbf{b}_{i}right) It​=σ(Wi​[Ht−1​,Xt​]+bi​)
类似于遗忘门,在这里 I t mathbf{I}_{t} It​也是值为0到1之间的向量。这将与 C ~ t tilde{mathbf{C}}_{t} C~t​逐元素相乘以计算 C t mathbf{C}_{t} Ct​。

更新memory

新的输入向量:
C ~ t = tanh ⁡ ( W c [ H t − 1 , X t ] + b c ) tilde{mathbf{C}}_{t}=tanh left(mathbf{W}_{c}[mathbf{H}_{t-1},mathbf{X}_{t}] +mathbf{b}_{c}right) C~t​=tanh(Wc​[Ht−1​,Xt​]+bc​)
最后,新的memory状态为:
C t = F t ⊙ C t − 1 + I t ⊙ C ~ t mathbf{C}_{t}=mathbf{F}_{t} odot mathbf{C}_{t-1}+mathbf{I}_{t} odot tilde{mathbf{C}}_{t} Ct​=Ft​⊙Ct−1​+It​⊙C~t​

输出门

为了确定接下来将使用哪些输出,使用以下两个公式:

O t = σ ( W o [ H t − 1 , X t ] + b o ) mathbf{O}_{t}=sigmaleft(mathbf{W}_{o}[mathbf{H}_{t-1},mathbf{X}_{t}]+mathbf{b}_{o}right) Ot​=σ(Wo​[Ht−1​,Xt​]+bo​)
H t = O t ⊙ tanh ⁡ ( C t ) mathbf{H}_{t}=mathbf{O}_{t} odot tanh left(mathbf{C}_{t}right) Ht​=Ot​⊙tanh(Ct​)

LSTM单元

实现上图中描述的LSTM单元。

说明:

  1. 将 H t − 1 mathbf{H}_{t-1} Ht−1​和 X t mathbf{X}_{t} Xt​连接在一个矩阵中: c o n c a t = [ H t − 1 X t ] concat = begin{bmatrix} mathbf{H}_{t-1} \ mathbf{X}_{t}end{bmatrix} concat=[Ht−1​Xt​​]
  2. 计算以上公式,使用sigmoid()和np.tanh()。
  3. 计算预测 y ⟨ t ⟩ y^{langle t rangle} y⟨t⟩,使用softmax()。
  4. 预测 y ^ hat y y^​公式为 y ^ = s o f t m a x ( W y H t + b y ) hat y=softmax(W_yH_t+b_y) y^​=softmax(Wy​Ht​+by​)
import numpy as np
def sigmoid(x):
    return 1/(1+np.exp(-x))

def softmax(x):
    e_x = np.exp(x-np.max(x))# 防溢出
    return e_x/e_x.sum(axis=0)
def LSTM_CELL_Forward(xt,h_prev,C_prev,parameters):
    """
    Arguments:
    xt:时间步“t”处输入的数据 shape(n_x,m)
    h_prev:时间步“t-1”的隐藏状态 shape(n_h,m)
    C_prev:时间步“t-1”的memory状态 shape(n_h,m)
    parameters
        Wf 遗忘门的权重矩阵 shape(n_h,n_h+n_x)
        bf 遗忘门的偏置 shape(n_h,1)
        Wi 输入门的权重矩阵 shape(n_h,n_h+n_x)
        bi 输入门的偏置 shape(n_h,1)
        Wc 第一个“tanh”的权重矩阵 shape(n_h,n_h+n_x)
        bc 第一个“tanh”的偏差 shape(n_h,1)
        Wo 输出门的权重矩阵 shape(n_h,n_h+n_x)
        bo 输出门的偏置 shape(n_h,1)
        Wy 将隐藏状态与输出关联的权重矩阵 shape(n_y,n_h)
        by 隐藏状态与输出相关的偏置 shape(n_y,1)
    Returns:
    h_next -- 下一个隐藏状态 shape(n_h,m)
    c_next -- 下一个memory状态 shape(n_h,m)
    yt_pred -- 时间步长“t”的预测 shape(n_y,m)
    """
    # 获取参数字典中各个参数
    Wf = parameters["Wf"]
    bf = parameters["bf"]
    Wi = parameters["Wi"]
    bi = parameters["bi"]
    Wc = parameters["Wc"]
    bc = parameters["bc"]
    Wo = parameters["Wo"]
    bo = parameters["bo"]
    Wy = parameters["Wy"]
    by = parameters["by"]
    
    # 获取 xt 和 Wy 的维度参数
    n_x, m = xt.shape
    n_y, n_h = Wy.shape
    
    #拼接 h_prev 和 xt
    concat = np.zeros((n_x+n_h,m))
    concat[: n_h, :] = h_prev
    concat[n_h :, :] = xt
    
    # 计算遗忘门、输入门、记忆细胞候选值、下一时间步的记忆细胞、输出门和下一时间步的隐状态值
    ft = sigmoid(np.dot(Wf,concat)+bf)
    it = sigmoid(np.dot(Wi,concat)+bi)
    cct = np.tanh(np.dot(Wc,concat)+bc)
    c_next = ft*c_prev + it*cct
    ot = sigmoid(np.dot(Wo,concat)+bo)
    h_next = ot*np.tanh(c_next)
    
    # LSTM单元的计算预测
    yt_pred = softmax(np.dot(Wy, h_next) + by)
    
    return h_next,c_next,yt_pred
np.random.seed(1)
xt = np.random.randn(3,10)
h_prev = np.random.randn(5,10)
c_prev = np.random.randn(5,10)
Wf = np.random.randn(5, 5+3)
bf = np.random.randn(5,1)
Wi = np.random.randn(5, 5+3)
bi = np.random.randn(5,1)
Wo = np.random.randn(5, 5+3)
bo = np.random.randn(5,1)
Wc = np.random.randn(5, 5+3)
bc = np.random.randn(5,1)
Wy = np.random.randn(2,5)
by = np.random.randn(2,1)

parameters = {"Wf": Wf, "Wi": Wi, "Wo": Wo, "Wc": Wc, "Wy": Wy, "bf": bf, "bi": bi, "bo": bo, "bc": bc, "by": by}

h_next, c_next, yt = LSTM_CELL_Forward(xt, h_prev, c_prev, parameters)
print("a_next[4] = ", h_next[4])
print("a_next.shape = ", c_next.shape)
print("c_next[2] = ", c_next[2])
print("c_next.shape = ", c_next.shape)
print("yt[1] =", yt[1])
print("yt.shape = ", yt.shape)
a_next[4] =  [-0.66408471  0.0036921   0.02088357  0.22834167 -0.85575339  0.00138482
  0.76566531  0.34631421 -0.00215674  0.43827275]
a_next.shape =  (5, 10)
c_next[2] =  [ 0.63267805  1.00570849  0.35504474  0.20690913 -1.64566718  0.11832942
  0.76449811 -0.0981561  -0.74348425 -0.26810932]
c_next.shape =  (5, 10)
yt[1] = [0.79913913 0.15986619 0.22412122 0.15606108 0.97057211 0.31146381
 0.00943007 0.12666353 0.39380172 0.07828381]
yt.shape =  (2, 10)

预期输出:
a_next[4] = [-0.66408471 0.0036921 0.02088357 0.22834167 -0.85575339 0.00138482
0.76566531 0.34631421 -0.00215674 0.43827275]

a_next.shape = (5, 10)

c_next[2] = [ 0.63267805 1.00570849 0.35504474 0.20690913 -1.64566718 0.11832942
0.76449811 -0.0981561 -0.74348425 -0.26810932]

c_next.shape = (5, 10)

yt[1] = [0.79913913 0.15986619 0.22412122 0.15606108 0.97057211 0.31146381
0.00943007 0.12666353 0.39380172 0.07828381]

yt.shape = (2, 10)

参考
https://zh-v2.d2l.ai/chapter_recurrent-modern/lstm.html
https://www.heywhale.com/mw/project/6174b96ef7e7c300175739cc

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/347990.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号