栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

Tensorflow,在RNN中保存状态的最佳方法?

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Tensorflow,在RNN中保存状态的最佳方法?

这是

state_is_tuple=True
通过定义状态变量来更新LSTM初始状态的代码。它还支持多层。

我们定义了两个函数-
一个用于获取具有初始零状态的状态变量,另一个用于返回操作的函数,可以传递给该函数以

session.run
用LSTM的最后一个隐藏状态更新状态变量。

def get_state_variables(batch_size, cell):    # For each layer, get the initial state and make a variable out of it    # to enable updating its value.    state_variables = []    for state_c, state_h in cell.zero_state(batch_size, tf.float32):        state_variables.append(tf.contrib.rnn.LSTMStateTuple( tf.Variable(state_c, trainable=False), tf.Variable(state_h, trainable=False)))    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state    return tuple(state_variables)def get_state_update_op(state_variables, new_states):    # Add an operation to update the train states with the last state tensors    update_ops = []    for state_variable, new_state in zip(state_variables, new_states):        # Assign the new state to the state variables on this layer        update_ops.extend([state_variable[0].assign(new_state[0]),     state_variable[1].assign(new_state[1])])    # Return a tuple in order to combine all update_ops into a single operation.    # The tuple's actual value should not be used.    return tf.tuple(update_ops)

我们可以用它来更新每批LSTM的状态。请注意,我

tf.nn.dynamic_rnn
用于展开:

data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))cell_layer = tf.contrib.rnn.GRUCell(256)cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers)# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.states = get_state_variables(batch_size, cell)# Unroll the LSTMoutputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)# Add an operation to update the train states with the last state tensors.update_op = get_state_update_op(states, new_states)sess = tf.Session()sess.run(tf.global_variables_initializer())sess.run([outputs, update_op], {data: ...})

该答案的主要区别在于,

state_is_tuple=True
使LSTM的状态成为包含两个变量(单元状态和隐藏状态)而不是单个变量的LSTMStateTuple。然后,使用多层可以使LSTM的状态成为LSTMStateTuples的元组-
每层一个。

重置为零

使用训练有素的模型进行预测/解码时,您可能需要将状态重置为零。然后,您可以使用此功能:

def get_state_reset_op(state_variables, cell, batch_size):    # Return an operation to set each variable in a list of LSTMStateTuples to zero    zero_states = cell.zero_state(batch_size, tf.float32)    return get_state_update_op(state_variables, zero_states)

例如上面的例子:

reset_state_op = get_state_reset_op(state, cell, max_batch_size)# Reset the state to zero before feeding inputsess.run([reset_state_op])sess.run([outputs, update_op], {data: ...})


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/637782.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号