项目中经常遇到不同维度矩阵的加减操作,规则如下:
(1)进行加减操作的两个矩阵最后一个维度要相同或其中一个矩阵的最后一个维度元素数为1
a=np.array([[1,2],[3,4]])
print(a.shape)
b=np.array([[[7,8],[8,9]],
[[7,8],[8,9]]])
print(b.shape)
print(b-a,(b-a).shape)
print('==============11111===============')
a=np.array([[1,2],[3,4]])
print(a.shape)
b=np.array([[[[7,8],[8,9]]]])
print(b.shape)
print(b-a,(b-a).shape)
print('==============22222===============')
#最后一个维度的元素数分别为2和3,报错
a=np.array([[1,2,3],[3,4,5]])
print(a.shape)
b=np.array([[[7,8],[8,9]],
[[7,8],[8,9]]])
print(b.shape)
try:
print(b-a,(b-a).shape)
except Exception as e:
print(e)
print('=============33333================')
#其中一个矩阵的最后一个维度元素数为1,不报错
a=np.array([[1],[3]])
print(a.shape)
b=np.array([[[7,8,9,10],[8,9,10,11]]])
print(b.shape)
try:
print(b-a,(b-a).shape)
except Exception as e:
print(e)
print('============44444=================')
#其中一个矩阵的最后一个维度元素数为1,不报错
a=np.array([[1,2,3],[3,4,5]])
print(a.shape)
b=np.array([[[7],[8]]])
print(b.shape)
try:
print(b-a,(b-a).shape)
except Exception as e:
print(e)
print('============55555=================')
#最后一个维度的元素数分别为4和2,报错
a=np.array([[1,2],[3,4]])
print(a.shape)
b=np.array([[[7,8,9,10],[8,9,10,11]]])
print(b.shape)
try:
print(b-a,(b-a).shape)
except Exception as e:
print(e)
print('============66666=================')
输出:
(2)结果矩阵维度
当两个矩阵最后一个维度元素数相同时,加减完后矩阵的形状和高维矩阵的相同;当不同时(其中一个矩阵的最后一维元素数为1),矩阵维度和高维度矩阵相同,但最后一维的元素数和原来两个矩阵中最后一维元素数较多的矩阵相同。
(3)两个矩阵卷积操作的规则与加减类似
a=np.array([[1,2,3],[4,5,6]])
b=np.array([[1,2],[2,3]])
try:
print(a*b,(b-a).shape)
except Exception as e:
print(e)
print('============11111============')
a=np.array([[1,2,3],[4,5,6]])
b=np.array([1,2])
try:
print(a*b,(b-a).shape)
except Exception as e:
print(e)
print('============22222============')
a=np.array([[1,2,3,4],[4,5,6,7]])
b=np.array([1,2])
try:
print(a*b,(b-a).shape)
except Exception as e:
print(e)
print('============33333============')
a=np.array([[1,2,3,4],[4,5,6,7]])
b=np.array([2])
try:
print(a*b,(b-a).shape)
except Exception as e:
print(e)
print('============44444============')
a=np.array([[1,1,1,1],[2,2,2,2]])
b=np.array([[2,4,6,8],[1,3,5,7]])
try:
print(a*b,(b-a).shape)
except Exception as e:
print(e)
print('============55555============')
输出: