您可以创建另一个队列,把数据排入该队列 num_epochs 次,然后关闭它,再将其连接到
batch。为了节省内存,您可以把这个队列设得很小,并在并行线程中向队列放入数据。不过各个 epoch 之间会有一些数据混杂。若要完全避免混杂,您可以把下面代码中的
num_epochs 设为 1,并调用它
num_epochs 次。
tf.reset_default_graph()data = np.array([1, 2, 3, 4])num_epochs = 5queue1_input = tf.placeholder(tf.int32)queue1 = tf.FIFOQueue(capacity=10, dtypes=[tf.int32], shapes=[()])def create_session(): config = tf.ConfigProto() config.operation_timeout_in_ms=20000 return tf.InteractiveSession(config=config)enqueue_op = queue1.enqueue_many(queue1_input)close_op = queue1.close()dequeue_op = queue1.dequeue()batch = tf.train.shuffle_batch([dequeue_op], batch_size=4, capacity=5, min_after_dequeue=4)sess = create_session()def fill_queue(): for i in range(num_epochs): sess.run(enqueue_op, feed_dict={queue1_input: data}) sess.run(close_op)fill_thread = threading.Thread(target=fill_queue, args=())fill_thread.start()# read the data from queue shuffledtf.train.start_queue_runners()try: while True: print batch.eval()except tf.errors.OutOfRangeError: print "Done"顺便说一句,
当队列容量不足以一次装下整个 numpy 数据集时,上面这种 enqueue_many 的写法会挂起。通过按下面的方式分块加载数据,您可以在使用更小队列的同时获得更大的灵活性。
tf.reset_default_graph()data = np.array([1, 2, 3, 4])queue1_capacity = 2num_epochs = 2queue1_input = tf.placeholder(tf.int32)queue1 = tf.FIFOQueue(capacity=queue1_capacity, dtypes=[tf.int32], shapes=[()])enqueue_op = queue1.enqueue_many(queue1_input)close_op = queue1.close()dequeue_op = queue1.dequeue()def dequeue(): try: while True: print sess.run(dequeue_op) except: returndef enqueue(): for i in range(num_epochs): start_pos = 0 while start_pos < len(data): end_pos = start_pos+queue1_capacity data_chunk = data[start_pos: end_pos] sess.run(enqueue_op, feed_dict={queue1_input: data_chunk}) start_pos += queue1_capacity sess.run(close_op)sess = create_session()enqueue_thread = threading.Thread(target=enqueue, args=())enqueue_thread.start()dequeue_thread = threading.Thread(target=dequeue, args=())dequeue_thread.start()


