import tensorflow as tf
if __name__ == '__main__':
    # Demo (TF1 graph mode): sample one action per row from the categorical
    # distribution given by `logits`, then compute the log-probability of the
    # sampled action.
    logits = tf.constant([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
    # Number of actions, taken from the logits so it cannot drift from them.
    num_actions = logits.shape.as_list()[-1]
    # Log-probabilities of every action, same shape as `logits`.
    logp_all = tf.nn.log_softmax(logits)
    # tf.multinomial is deprecated (removed in TF2); tf.random.categorical is
    # its replacement with the same (logits, num_samples) arguments. The
    # squeeze drops the num_samples=1 axis, leaving one action per row.
    pi = tf.squeeze(tf.random.categorical(logits, 1), axis=1)
    # One-hot mask selects log pi(a) for the sampled action in each row.
    logp_pi = tf.reduce_sum(
        tf.one_hot(pi, depth=num_actions) * logp_all, axis=1)
    with tf.Session() as sess:
        # Keep every fetched value instead of discarding all but the first.
        logp_all_val, pi_val, logp_pi_val = sess.run([logp_all, pi, logp_pi])
    print(logp_all_val.shape)
    print(logp_all_val)
    print(pi_val, logp_pi_val)
def one_hot_actions(actions, num_actions=4):
    """One-hot encode a batch of integer action indices.

    Shape note: (32,) -> (32, 4) for a batch of 32 actions with 4 choices.

    Args:
        actions: integer tensor of action indices, e.g. shape (batch,).
            ``tf.one_hot`` only accepts integer indices; float inputs such
            as log-probabilities are rejected.
        num_actions: size of the appended one-hot axis (default 4).

    Returns:
        Tensor with shape ``actions.shape + (num_actions,)``.
    """
    return tf.one_hot(actions, num_actions)
def gather_action_values(q_values, actions):
    """Select, for each batch row, the Q-values at that row's action indices.

    Shape note: q_values (32, 4) gathered with actions (32, 1) -> (32, 1).

    With ``axis=1, batch_dims=1``, row i of ``actions`` indexes into row i
    of ``q_values``: out[i, j] = q_values[i, actions[i, j]]. The output
    shape is q_values.shape[:1] + actions.shape[1:].

    Requires a TensorFlow version whose ``tf.gather`` supports
    ``batch_dims`` (1.14+).

    Args:
        q_values: tensor of shape (batch, num_actions).
        actions: integer tensor of shape (batch, k) of action indices.

    Returns:
        Tensor of shape (batch, k).
    """
    return tf.gather(q_values, actions, axis=1, batch_dims=1)