def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None)
logits是一个二维张量,num_samples指的是采样的个数。
先上代码:
a = tf.constant([1.,2.,3.,4.,5.,6.,7.,8.,9.])
b = tf.reshape(a,[1,9])
dede = tf.math.log(x=b,name='xx')
samples = tf.multinomial(dede, 100)
with tf.Session() as sess:
print('b: ', sess.run(b))
print('dede: ', sess.run(dede))
wqw = sess.run(samples)
print('wqw: ', wqw)
d = {}
for i in wqw[0]:
if 'class'+str(i) in list(d.keys()):
d['class' + str(i)] += 1
else:
d['class' + str(i)] = 1
print(d)
for i in range(len(d)):
print('class%d'%i,d['class%d'%i])
输出结果:
b: [[1. 2. 3. 4. 5. 6. 7. 8. 9.]]
dede: [[0. 0.6931472 1.0986123 1.3862944 1.609438 1.7917595 1.9459102
2.0794415 2.1972246]]
wqw: [[6 3 7 6 3 8 5 7 7 8 5 7 0 7 6 6 4 7 8 7 8 0 2 7 7 8 6 6 6 3 4 5 6 4 5 7
6 0 8 4 7 4 7 5 1 8 7 7 7 4 0 4 4 8 5 4 4 8 7 3 7 7 8 6 6 2 5 4 7 5 4 6
7 8 6 8 8 8 6 6 6 8 8 8 4 5 5 6 5 5 8 4 8 5 0 0 6 2 7 5]]
{'class8': 19, 'class4': 14, 'class7': 21, 'class1': 1, 'class0': 6, 'class3': 4, 'class2': 3, 'class6': 18, 'class5': 14}
class0 6
class1 1
class2 3
class3 4
class4 14
class5 14
class6 18
class7 21
class8 19
multinomial对概率分布dede进行采样,采样100个,wqw 的值:每一次采样,采的是dede的值对应的位置。比如:wqw的第一个值6 指的就是dede位置6的值,也就是1.9459102。
{‘class8’: 19, ‘class4’: 14, ‘class7’: 21, ‘class1’: 1, ‘class0’: 6, ‘class3’: 4, ‘class2’: 3, ‘class6’: 18, ‘class5’: 14}从这可以看出,dede位置8的值被采了19次,位置4的值被采了14次。。。。。。。。