#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 2021/10/21 21:51
# @Author : XXX
# @File : test_multinomial.py
"""
记:
logits: 为网络最后一层的输出,无激活函数
p = softmax(logits): 概率
log_p = log_softmax(logits) = log(p): log概率
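结论:
1. tf.multinomial 与 tf.distributions.Categorical(logits=...) 传入 logits 或 log_p 的采样结果完全相同:
   log_p = logits - logsumexp(logits),两者每行只差一个常数,softmax 之后的概率不变。
2. 在相同随机种子下,Categorical(probs=p) 与 Categorical(logits=logits) 的采样结果一致。
3. 采样频率与 p 接近(约 0.846 / 0.116 / 0.039,对比 p = 0.844 / 0.114 / 0.042)。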
"""
import tensorflow as tf
import numpy as np
# Fix the graph-level seed so the sample counts recorded in the comments are reproducible.
tf.random.set_random_seed(1111)
logits = tf.constant([[5., 3, 2]])
p = tf.nn.softmax(logits, axis=1)
log_p = tf.nn.log_softmax(logits, axis=1)
# tf.multinomial returns an ndarray of shape (batch_size, num_samples).
# Passing logits or log_p gives the same samples: log_p differs from logits
# only by a per-row constant, which softmax ignores.
# sample = tf.multinomial(logits, 10000)  # [8423. 1157. 420.]
# sample = tf.multinomial(log_p, 10000)  # [8423. 1157. 420.]
# tf.distributions.Categorical.sample returns an ndarray of shape
# sample_shape + (batch_size,), e.g. (10000, 1) below.
# Option 1: the `logits` argument, set to either the logits defined above or log_p.
# sample = tf.distributions.Categorical(logits=logits).sample((1, 10000))  # [8457. 1155. 388.]
# sample = tf.distributions.Categorical(logits=log_p).sample((1, 10000))  # [8457. 1155. 388.]
# Option 2: the `probs` argument.
sample = tf.distributions.Categorical(probs=p).sample(sample_shape=(10000, ))  # [8457. 1155. 388.]
with tf.Session() as sess:
    s, p_val = sess.run([sample, p])
    # Count how often each class index was drawn. Flattening first makes the
    # count correct for every sample layout above ((num_samples, 1),
    # (1, num_samples), (1, num_samples, 1)); a `cnt[row] += 1` loop over rows
    # uses fancy-index assignment, which adds a repeated index only once.
    num_classes = p_val.shape[-1]
    # float dtype keeps the printed output in the same format as the recorded results.
    cnt = np.bincount(s.ravel(), minlength=num_classes).astype(np.float64)
    print(cnt)
    print(p_val)  # [[0.8437947 0.11419519 0.04201007]]