import tensorflow as tf
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
d = tf.constant([
[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.], [13., 14., 15., 16.]],
[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.], [13., 14., 15., 16.]]
])
after_dropout = tf.nn.dropout(d, 0.6)
result = sess.run(after_dropout)
print(result)
打印结果为:
[[[ 1.25 2.5 3.75 5. ]
[ 6.25 7.5 8.75 10. ]
[ 11.25 12.5 0. 15. ]
[ 0. 0. 18.75 20. ]]
[[ 1.25 2.5 3.75 5. ]
[ 6.25 0. 8.75 10. ]
[ 0. 0. 13.75 0. ]
[ 0. 17.5 18.75 0. ]]]
所以我们可以直观地看到,dropout为0.8的作用为:将所有元素除以0.8,然后随机挑选20%的数值,将这些数值设置为0。