栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

将numpy数组传递给张量流队列

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

将numpy数组传递给张量流队列

您可以创建另一个队列,将数据排队到该队列中,

num_epoch
关闭它,然后将其连接到
batch
。为了节省内存,您可以使此队列变小,并并行将项目放入队列中。各个时期之间会有一些混淆。为了完全避免混淆,您可以在下面的代码中加上
num_epochs=1
和调用它的
num_epochs
时间。

tf.reset_default_graph()data = np.array([1, 2, 3, 4])num_epochs = 5queue1_input = tf.placeholder(tf.int32)queue1 = tf.FIFOQueue(capacity=10, dtypes=[tf.int32], shapes=[()])def create_session():    config = tf.ConfigProto()    config.operation_timeout_in_ms=20000    return tf.InteractiveSession(config=config)enqueue_op = queue1.enqueue_many(queue1_input)close_op = queue1.close()dequeue_op = queue1.dequeue()batch = tf.train.shuffle_batch([dequeue_op], batch_size=4, capacity=5, min_after_dequeue=4)sess = create_session()def fill_queue():    for i in range(num_epochs):        sess.run(enqueue_op, feed_dict={queue1_input: data})    sess.run(close_op)fill_thread = threading.Thread(target=fill_queue, args=())fill_thread.start()# read the data from queue shuffledtf.train.start_queue_runners()try:    while True:        print batch.eval()except tf.errors.OutOfRangeError:    print "Done"

顺便说一句,

enqueue_many
当队列不足以将整个numpy数据集加载到队列中时,以上模式将挂起。您可以通过按以下方式分块加载数据来使自己有更大的灵活性来拥有较小的队列。

tf.reset_default_graph()data = np.array([1, 2, 3, 4])queue1_capacity = 2num_epochs = 2queue1_input = tf.placeholder(tf.int32)queue1 = tf.FIFOQueue(capacity=queue1_capacity, dtypes=[tf.int32], shapes=[()])enqueue_op = queue1.enqueue_many(queue1_input)close_op = queue1.close()dequeue_op = queue1.dequeue()def dequeue():    try:        while True: print sess.run(dequeue_op)    except:        returndef enqueue():    for i in range(num_epochs):        start_pos = 0        while start_pos < len(data): end_pos = start_pos+queue1_capacity data_chunk = data[start_pos: end_pos] sess.run(enqueue_op, feed_dict={queue1_input: data_chunk}) start_pos += queue1_capacity    sess.run(close_op)sess = create_session()enqueue_thread = threading.Thread(target=enqueue, args=())enqueue_thread.start()dequeue_thread = threading.Thread(target=dequeue, args=())dequeue_thread.start()


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/634123.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号