栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

深入浅出TensorFlow2函数——tf.data.Dataset.shuffle

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

深入浅出TensorFlow2函数——tf.data.Dataset.shuffle

分类目录:《深入浅出TensorFlow2函数》总目录


函数:

shuffle(buffer_size, seed=None, reshuffle_each_iteration=None, name=None)

该函数可以随机洗牌此数据集的元素。此数据集使用buffer_size的元素填充缓冲区,然后从该缓冲区中随机采样元素,用新元素替换所选元素。为了实现完美的洗牌,需要缓冲区大小大于或等于数据集的完整大小。

例如,如果您的数据集包含10000个元素,但buffer_size设置为1000,则shuffle最初将仅从缓冲区中的前1000个元素中选择一个随机元素。一旦选择一个元素,其在缓冲区中的空间将被下一个(即1001个)元素替换,从而保持1000个元素的缓冲区。而reshuffle_each_iteration控制每次迭代的洗牌顺序是否应该不同。

在TensorFlow2.X中,tf.data.Dataset对象是Python的iterables,所以我们也可以用Python的循环遍历:

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 2, 0]

参数:

参数意义
buffer_size[tf.int64 /tf.Tensor]表示新数据集将从此数据集中采样的元素数。
seed[可选,tf.int64 /tf.Tensor]表示将用于创建分布的随机种子。
reshuffle_each_iteration[可选,tf.bool]如果为True,则表示每次迭代数据集时都应伪随机地重新洗牌,默认为True。
name[可选]tf.data操作的名称

返回值:

返回值意义
Dataset一个tf.data.Dataset的数据集。

函数实现:

  def shuffle(self,
              buffer_size,
              seed=None,
              reshuffle_each_iteration=None,
              name=None):
    """Randomly shuffles the elements of this dataset.
    This dataset fills a buffer with `buffer_size` elements, then randomly
    samples elements from this buffer, replacing the selected elements with new
    elements. For perfect shuffling, a buffer size greater than or equal to the
    full size of the dataset is required.
    For instance, if your dataset contains 10,000 elements but `buffer_size` is
    set to 1,000, then `shuffle` will initially select a random element from
    only the first 1,000 elements in the buffer. once an element is selected,
    its space in the buffer is replaced by the next (i.e. 1,001-st) element,
    maintaining the 1,000 element buffer.
    `reshuffle_each_iteration` controls whether the shuffle order should be
    different for each epoch. In TF 1.X, the idiomatic way to create epochs
    was through the `repeat` transformation:
    ```python
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
    dataset = dataset.repeat(2)
    # [1, 0, 2, 1, 2, 0]
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
    dataset = dataset.repeat(2)
    # [1, 0, 2, 1, 0, 2]
    ```
    In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it
    possible to also create epochs through Python iteration:
    ```python
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
    list(dataset.as_numpy_iterator())
    # [1, 0, 2]
    list(dataset.as_numpy_iterator())
    # [1, 2, 0]
    ```
    ```python
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
    list(dataset.as_numpy_iterator())
    # [1, 0, 2]
    list(dataset.as_numpy_iterator())
    # [1, 0, 2]
    ```
    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements from this dataset from which the new dataset will sample.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
        seed that will be used to create the distribution. See
        `tf.random.set_seed` for behavior.
      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
        that the dataset should be pseudorandomly reshuffled each time it is
        iterated over. (Defaults to `True`.)
      name: (Optional.) A name for the tf.data operation.
    Returns:
      Dataset: A `Dataset`.
    """
    return ShuffleDataset(
        self, buffer_size, seed, reshuffle_each_iteration, name=name)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/656627.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号