栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

谁能举一个小小的例子来解释tf.random.categorical的参数?

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

谁能举一个小小的例子来解释tf.random.categorical的参数?

如您所述,

tf.random.categorical
有两个参数:

  • logits
    ,具有形状的2D浮动张量
    [batch_size, num_classes]
  • num_samples
    ,一个整数标量。

输出是带有shape的2D整数张量

[batch_size, num_samples]

logits
张量(
logits[0, :]
logits[1,:]
…)的每个“行”代表不同分类分布的事件概率。但是,该函数并不期望实际的概率值,而是期望的非标准化对数概率;因此实际的概率为
softmax(logits[0,:])
softmax(logits[1,:])
等等。这样做的好处是,您基本上可以给出任何实际值作为输入(例如,神经网络的输出),并且它们将是有效的。同样,使用对数使用特定的概率值或比例也很简单。例如,两个
[log(0.1),log(0.3), log(0.6)]
[log(1), log(3),log(6)]
表示相同的概率,其中第二类的概率是第一个类的三倍,但仅是第三类的一半。

对于每一行(非标准化对数)概率,您可以

num_samples
从分布中获取样本。每个样本都是介于
0
和之间的整数
num_classes -1
,根据给定的概率得出。因此,结果是二维张量的形状
[batch_size, num_samples]
与每个分布的采样整数。

编辑:函数的一个小例子。

import tensorflow as tfwith tf.Graph().as_default(), tf.Session() as sess:    tf.random.set_random_seed(123)    logits = tf.log([[1., 1., 1., 1.],          [0., 1., 2., 3.]])    num_samples = 30    cat = tf.random.categorical(logits, num_samples)    print(sess.run(cat))    # [[3 3 1 1 0 3 3 0 2 3 1 3 3 3 1 1 0 2 2 0 3 1 3 0 1 1 0 1 3 3]    #  [2 2 3 3 2 3 3 3 2 2 3 3 2 2 2 1 3 3 3 2 3 2 2 1 3 3 3 3 3 2]]

在这种情况下,结果是一个包含两行30列的数组。第一行中的值是从分类分布中抽样的,其中每个类别(

[0, 1, 2,3]
)具有相同的概率。在第二行中,该类别
3
是最可能的类别,并且该类别
0
几乎没有被采样的可能性。



转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/634279.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号