对于 3 维数组，总结 tf.nn.softmax() 在 axis=0、1、2 三个方向上的计算方式：
axis=0 沿着 axis=0方向 (可以认为是时间的方向)取每个单元对应元素进行计算softmax() //通俗理解就是今天8点钟的对应行对应列的元素a[2][0][0] 与昨天8点钟的对应行对应列的元素a[1][0][0] 及 前天8点钟的对应行对应列的元素a[0][0][0] 都取出来,进行计算softmax()
axis=1 沿着 axis=1 方向（行的方向）取每个单元对应元素进行计算 softmax() // 通俗理解就是在同一个矩阵内，把同一列、不同行的元素（例如 a[0][0][0] 与 a[0][1][0]）都取出来，进行计算 softmax()
axis=2 沿着 axis=2 方向（列的方向）取每个单元对应元素进行计算 softmax() // 通俗理解就是在同一个矩阵的同一行内，把不同列的元素（例如 a[0][0][0]、a[0][0][1]、a[0][0][2]）都取出来，进行计算 softmax()
import tensorflow as tf
import numpy as np
# Report the TensorFlow version so the recorded outputs below can be tied to it.
print("tf.__version__=", tf.__version__)

# Demo input: a float32 tensor of shape (2, 2, 3), i.e. two 2x3 matrices.
# Both rows of the first matrix are [1, 2, 3]; both rows of the second
# matrix are [4, 5, 6].
a = tf.cast(
    np.array([[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6]]]),
    tf.float32,
)
# >>> a
# <tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
# array([[[1., 2., 3.],
#         [1., 2., 3.]],
#
#        [[4., 5., 6.],
#         [4., 5., 6.]]], dtype=float32)>

# Apply softmax to the same input once per axis (0, 1 and 2).
s1, s2, s3 = [tf.nn.softmax(a, axis=axis) for axis in range(3)]

# axis=0: each pair (a[0][j][k], a[1][j][k]) is normalized together.
# Every pair here differs by 3 (1 vs 4, 2 vs 5, 3 vs 6), so all pairs
# produce the same two weights.
print(s1)
# tf.Tensor(
# [[[0.04742587 0.04742587 0.04742587]
#   [0.04742587 0.04742587 0.04742587]]
#
#  [[0.95257413 0.95257413 0.95257413]
#   [0.95257413 0.95257413 0.95257413]]], shape=(2, 2, 3), dtype=float32)

# axis=1: each pair (a[i][0][k], a[i][1][k]) is normalized together.
# The two rows of every matrix are identical, so each weight is 0.5.
print(s2)
# tf.Tensor(
# [[[0.5 0.5 0.5]
#   [0.5 0.5 0.5]]
#
#  [[0.5 0.5 0.5]
#   [0.5 0.5 0.5]]], shape=(2, 2, 3), dtype=float32)

# axis=2: each row (a[i][j][0], a[i][j][1], a[i][j][2]) is normalized
# together. [1, 2, 3] and [4, 5, 6] differ only by a constant shift, so
# both rows give the same weights.
print(s3)
# tf.Tensor(
# [[[0.09003057 0.24472848 0.66524094]
#   [0.09003057 0.24472848 0.66524094]]
#
#  [[0.09003057 0.24472848 0.66524094]
#   [0.09003057 0.24472848 0.66524094]]], shape=(2, 2, 3), dtype=float32)
# Second example: the two 2x3 matrices of `a1` differ from each other by a
# different amount at each position, so softmax along axis=0 yields a
# different pair of weights per position.
a1 = np.array([[[1, 2, 3], [2, 4, 6]], [[2, 4, 6], [4, 5, 6]]])
a1 = tf.cast(a1, tf.float32)
# >>> a1
# <tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
# array([[[1., 2., 3.],
#         [2., 4., 6.]],
#
#        [[2., 4., 6.],
#         [4., 5., 6.]]], dtype=float32)>

# axis=0 normalizes each pair (a1[0][j][k], a1[1][j][k]) together, e.g.
# (1, 2) -> (0.26894143, 0.7310586) and (6, 6) -> (0.5, 0.5).
s11 = tf.nn.softmax(a1, axis=0)
print(s11)
# tf.Tensor(
# [[[0.26894143 0.11920291 0.04742587]
#   [0.11920291 0.26894143 0.5       ]]
#
#  [[0.7310586  0.880797   0.95257413]
#   [0.880797   0.7310586  0.5       ]]], shape=(2, 2, 3), dtype=float32)
输出（依次为 tf 版本号以及 s1、s2、s3、s11 的打印结果）
tf.__version__= 2.0.0
tf.Tensor(
[[[0.04742587 0.04742587 0.04742587]
[0.04742587 0.04742587 0.04742587]]
[[0.95257413 0.95257413 0.95257413]
[0.95257413 0.95257413 0.95257413]]], shape=(2, 2, 3), dtype=float32)
tf.Tensor(
[[[0.5 0.5 0.5]
[0.5 0.5 0.5]]
[[0.5 0.5 0.5]
[0.5 0.5 0.5]]], shape=(2, 2, 3), dtype=float32)
tf.Tensor(
[[[0.09003057 0.24472848 0.66524094]
[0.09003057 0.24472848 0.66524094]]
[[0.09003057 0.24472848 0.66524094]
[0.09003057 0.24472848 0.66524094]]], shape=(2, 2, 3), dtype=float32)
tf.Tensor(
[[[0.26894143 0.11920291 0.04742587]
[0.11920291 0.26894143 0.5 ]]
[[0.7310586 0.880797 0.95257413]
[0.880797 0.7310586 0.5 ]]], shape=(2, 2, 3), dtype=float32)