栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

tf.data.Dataset.interleave()原理详解(全参数)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

tf.data.Dataset.interleave()原理详解(全参数)

1. 基础运行流程

  最近学习tensorflow,对于这个函数tf.data.Dataset.interleave()始终有点晕乎,即使搞明白了,用不了多久又忘了,在网上查了查,发现很少有人能把这个函数讲清楚。趁着现在还明白,记录下来——备忘+助友。interleave()是Dataset的类方法,所以interleave是作用在一个Dataset上的。

语法:

interleave(
    map_func,
    cycle_length=AUTOTUNE,
    block_length=1,
    num_parallel_calls=None
)

解释:

  1. 假定我们现在有一个Dataset——A
  2. 从该A中取出cycle_length个element,然后对这些element apply map_func,得到cycle_length个新的Dataset对象。
  3. 然后从这些新生成的Dataset对象中取数据,取数逻辑为轮流从每个对象(注意,这里不是先取完一个对象再取另一个)里面取数据,每次取block_length个数据
  4. 当这些新生成的某个Dataset的对象取尽时,从原Dataset中再取cycle_length个element,然后apply map_func,以此类推。

举例:

a = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
b=a.interleave(lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
            cycle_length=2, block_length=4) 
for item in b:
    print(item.numpy(),end=', ')

输出结果:

1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 4, 4, 5, 5, 5, 5, 5, 5,

上面程序的图示,看示意图可能更清晰:

其中map_func在这里是重复6次-repeat(6)。

常见case:
dataset里面存储文件名, 将所有文件读取出来,产生一个大数据集。

  以上内容参考自「倚剑天客」的原创文章,做了一些改动。版权声明:本文为CSDN博主「倚剑天客」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。原文链接:https://blog.csdn.net/menghuanshen/article/details/104240189

2. 性能提升

  该方法有时会比较占用资源,为了提升性能,可能需要另外两个参数的参与
num_parallel_calls只会影响同时处理的线程数量并不会影响最终结果,但要求要小于等于cycle_length。

例程:

import tensorflow as tf

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
dataset = dataset.interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
    cycle_length=4, block_length=4, num_parallel_calls = 2, deterministic = False)
print(list(dataset.as_numpy_iterator()))

打印:
  num_parallel_calls不论设置为1、2、3、4,输出都是下面结果

[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 5, 5, 5, 5]

原理:

  1. 假设num_parallel_calls = 2,此时先从dataset(即[ 1, 2, 3, 4, 5 ])中取cycle_length=4个元素,即为[1, 2, 3, 4]
  2. 然后因为num_parallel_calls = 2,所以先对[1, 2]应用 lambda x: tf.data.Dataset.from_tensors(x).repeat(6)(也就是重复6次),得到
[1, 1, 1, 1
  2, 2, 2, 2]
  1. 然后再对[3, 4]进行同样的处理得到
[3, 3, 3, 3
  4, 4, 4, 4]
  1. 之后将
[1, 1, 1, 1
  2, 2, 2, 2
  3, 3, 3, 3
  4, 4, 4, 4]
  1. 再轮流取 block_length=4个数据,重复基础运行流程中的循环

  可以看到,上述num_parallel_calls只是对并行处理量进行了拆分,对运行结果并没有造成影响.
  deterministic参数会对结果造成影响,他影响数据处理时候的顺序(自己总结,不一定对)。即若deterministic=False,在并行处理时可能并不一定是按照[1、2、3、4]的顺序,有可能是[2、3、4、1]之类的。
例程:

import tensorflow as tf

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
dataset = dataset.interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
    cycle_length=4, block_length=4, num_parallel_calls = 2, deterministic = False)
print(list(dataset.as_numpy_iterator()))

打印:

[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 5, 5, 5, 5]

[2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 1, 1, 5, 5, 5, 5, 5, 5]

更详细内容可以参考官方文档

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/886518.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号