import tensorflow as tf
samples = tf.multinomial([[0.4, 0.6], [0.5, 0.7],[0.2, 0.1],[0.7, 0.8]], 1)
with tf.Session() as sess:
print(sess.run(samples))
结果是:
[[1]
[0]
[1]
[0]]
注意结果并不是:
[[1]
[1]
[0]
[1]]
对于一个batch=4, output class = 2的矩阵, 取了4个值