import numpy as np
a = np.array([[[1, 2, 4], [1, 2, 4]], [[3, 2, 1], [1, 2, 4]], [[3, 2, 1], [1, 2, 4]]])
print(a.shape)
b = np.max(a, axis=0)
print(b.shape)
print(b)
c = np.max(a, axis=1)
print(c.shape)
print(c)
d = np.max(a, axis=2)
print(d.shape)
print(d)
a的形状是(3, 2, 3)
b的形状是
312244(b)
我们发现max和axis是一种降维的方法,变为(2, 3),我们可以理解为axis=0是在一个batch进行的。
c的形状是
133222444(c)
d的形状是
433444(d)
也就是说,axis是用来选定在哪一个维度进行计算的,axis从小到大,范围也越来越小,是一种降维的方法。