形状与axis
设embeddings是一个shape=[3,4,5]的矩阵,可理解为3块4行5列,如下:
embeddings = [[[-0.30166972 0.25741747 -0.07442257 0.24321035 -0.3538919 ]
[-0.22572032 0.1288028 -0.4686908 -0.07217035 0.05287632]
[ 0.15845934 0.07064888 0.00922218 0.2841002 -0.24992025]
[ 0.43347922 -0.43738696 -0.08176881 0.34185413 -0.2826353 ]]
[[-0.08590135 0.06792518 -0.07807922 -0.28746927 -0.10613027]
[ 0.07476929 0.132256 -0.0926154 0.39621904 0.2497718 ]
[-0.15389556 0.0867373 0.19403657 -0.11003655 0.317669 ]
[ 0.3949038 -0.17275128 0.34710506 -0.02576578 -0.17427891]]
[[-0.27703786 0.02631402 0.22129896 -0.07714707 0.41439041]
[-0.08512023 0.19059369 -0.13418713 -0.12881753 -0.26143318]
[-0.333749 0.27034065 0.45429572 -0.46164128 -0.3955955 ]
[ 0.24430516 -0.3841647 0.37126407 -0.463441 -0.1441828 ]]]
a = tf.math.argmax(embeddings, axis=-1) # tf.math.argmax=tf.argmax,用来返回最大数值对应的index
b = tf.math.argmax(embeddings, axis=1)
c = tf.math.argmax(embeddings, axis=0)
得到
[[1 1 3 0]
[1 3 4 0]
[4 1 2 2]] # axis=-1,沿着最后一个维度“列”, shape=[3,4]
[[3 0 2 3 1]
[3 1 3 1 2]
[3 2 2 0 0]] # axis=1, 沿着第二个维度“行”,shape=[3,5]
[[1 0 2 0 2]
[1 2 1 1 1]
[0 2 2 0 1]
[0 1 2 0 2]] # axis=0,沿着第一个维度“块”,块和快叠加在一起,对应位置(如左上角)的加起来, shape=[4,5]