栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

了解Tensorflow LSTM输入形状

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

了解Tensorflow LSTM输入形状

该的文件

tf.nn.dynamic_rnn
中指出:

inputs
:RNN输入。如果是
time_major == False
(默认值),则它必须是shape:的张量
[batch_size,max_time, ...]
,或此类元素的嵌套元组。

在您的情况下,这意味着输入的形状应为

[batch_size, 10,2]
。无需一次训练所有4000个序列,而是
batch_size
在每次训练迭代中仅使用其中的许多序列。类似于以下内容的东西应该起作用(为清楚起见,添加了重新塑形):

batch_size = 32# batch_size sequences of length 10 with 2 values for each timestepinput = get_batch(X, batch_size).reshape([batch_size, 10, 2])# Create LSTM cell with state size 256. Could also use GRUCell, ...# Note: state_is_tuple=False is deprecated;# the option might be completely removed in the futurecell = tf.nn.rnn_cell.LSTMCell(256, state_is_tuple=True)outputs, state = tf.nn.dynamic_rnn(cell,  input,  sequence_length=[10]*batch_size,  dtype=tf.float32)

从文档开始,

outputs
将具有形状
[batch_size,10,256]
,即每个时间步长一个256输出。
state
将是一个形状的元组
[batch_size,256]
。您可以据此预测最终值(每个序列一个):

predictions = tf.contrib.layers.fully_connected(state.h,    num_outputs=1,    activation_fn=None)loss = get_loss(get_batch(Y).reshape([batch_size, 1]), predictions)

形状为

outputs
和的数字256
state
cell.output_size
和决定。
cell.state_size
。在创建
LSTMCell
上述内容时,它们是相同的。另请参阅LSTMCell文档。



转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/647273.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号