栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

使用tf.contrib.data.parallel_interleave并行化tf.from_generator

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

使用tf.contrib.data.parallel_interleave并行化tf.from_generator

在我看来,发电机不必要地使您的生活变得复杂。这就是我实现您的输入管道的方式:

def parse_file_tf(filename):    return tf.py_func(parse_file, [filename], [tf.float32, tf.float32])# version with mapfiles = tf.data.Dataset.from_tensor_slices(files_to_process)dataset = files.map(parse_file_tf, num_parallel_calls=N)dataset = dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))dataset = dataset.batch(batch_size).shuffle(shuffle_size).prefetch(2)it = dataset.make_one_shot_iterator()

为了测试它,我定义一个虚拟对象

parse_file
为:

i=0def parse_file(f):    global i    i += 1    return np.asarray([i]*i, dtype=np.float32), np.asarray([i]*i, dtype=np.float32) # mimicks variable-length examples_x, examples_y

我输入了一个基本循环,该循环显示了迭代器返回的内容:

sess = tf.Session()try:    while True:        x, y = it.get_next()        vx, vy = sess.run([x,y])        print(vx)        print(vy)except tf.errors.OutOfRangeError:    passsess.close()

运行上面的代码可以打印:

[2. 3. 2. 1. 3. 3.][2. 3. 2. 1. 3. 3.]

管道说明

本质上,我将并行化问题留给

map
,可以在其中传递应运行的线程数。无需生成器迭代范围和那些额外的复杂性。

我选择map
over

parallel_interleave
是因为map要求您为
Dataset
返回的每个项生成一个实例,在您的情况下,这实际上没有任何意义,因为在运行时已将所有值加载到内存中
parse_file

parallel_interleave
如果您缓慢地生成值(例如,通过应用
tf.data.TFRecordDataset
到文件名列表)会很有意义,但是如果您的数据集适合内存,请使用
map


关于

tf.py_func
限制,它们不会影响您训练有素的网络,只会影响输入管道。理想情况下,您将为培训和网络的最终使用使用不同的管道。您只需要注意后者的局限性,而对于培训(除非您使用分布式培训和/或在机器之间移动培训进行非常具体的操作),则可以相当安全地进行。


带发电机的版本

如果您的JSON文件很大,并且其内容无法容纳在内存中,则可以使用生成器,但与您最初使用的方法略有不同。这个想法是,生成器遍历JSON文件并

yield
一次记录一个记录。然后,生成器必须是您的
parse_file
功能。例如,假设您具有以下
parse_file
生成器:

i = 3def parse_file(filename):    global i    i += 1    ctr = 0    while ctr < i:        yield ctr, ctr

在这种情况下,管道如下所示:

def wrap_generator(filename):    return tf.data.Dataset.from_generator(parse_file(filename), [tf.int32, tf.int32])files = tf.data.Dataset.from_tensor_slices(files_to_process)dataset = files.apply(tf.contrib.data.parallel_interleave(wrap_generator, cycle_length=N))dataset = dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))dataset = dataset.shuffle(shuffle_size).batch(batch_size).prefetch(2)it = dataset.make_one_shot_iterator()

请注意,此处需要使用,

parallel_interleave
因为我们将生成器转换
Dataset
为从中提取值的实例。其余的保持不变。

将其馈送到与上述相同的示例循环中:

[6. 5. 4. 4. 6. 5. 6. 6. 5. 4. 6. 4. 5. 5. 6.][6. 5. 4. 4. 6. 5. 6. 6. 5. 4. 6. 4. 5. 5. 6.]


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/394981.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号