学习材料
近期讲课学习numpy,实践发现若干资料讲述均不准确,如:
1.某教材
‘’‘
注意:广播机制实现了对两个或两个以上数组的运算,即使这些数组的shape不是完全相同的,只需要满足如下任意一个条件即可。.
(1)数组的某一维度等长。
(2)其中一个数组的某一维度为1。
广播机制需要扩展维度小的数组,使得它与维度最大的数组的shape值相同,以便使用元素级函数或者运算符进行运算。
‘’‘
2.某博客
‘’‘
前言:数组广播是学习numpy和tensorflow等数学运算的基础,但是很多文章解释得不清楚,本文做一个全面的总结。
Numpy的广播规则:广播的前提——两个数组必须可以转化成维度大小一模一样的才可以运算:
规则1:如果两个数组的维度不相同,那么小维度数组的形状将会在最左边补1.
规则2:如果两个数组的形状在任何一个维度上都不匹配,那么数组的形状会沿着维度为1扩展以匹配另外一个数组的形状。
规则3:如果两个数组的形状在任何一个维度上都不匹配并且没有任何一个维度为1,那么会引起异常。
‘’’
通过实验,上述说法或存在错误,或不易理解。
尝试总结
本人尝试总结如下,欢迎拍砖沟通,-
(1)维度相同的必须同构(对应维数相同)或一个维度为1,可以广播;
(2)低维与高维,必须低维同构(对应维数相同)或一个维度为1,可以广播。任意单行(1维)与任意单列(2维,低维为1)(对应维数相同),可以广播。
算法:从最右侧维数看起,每一个维数要么相同,要么其中1个为1或0,可以广播。否则,不可以广播。
如a=np.ones((4,1) , a.shape为(4,1)
b=np.ones(3),a.shape为(3,)
取a的左右侧维数为1,b的左右侧维数为3,符合,a的下一个维数为4,b的为0,,所以a和b可以广播
再如如a=np.ones((4,1) , a.shape为(4,2)
b=np.ones(3),a.shape为(3,)
取a的左右侧维数为2,b的左右侧维数为3,不符合,所以a和b不可以广播
测试如下:
(1)1维,单行和单列,可以广播;单行与单行,单列与单列,必须同构(等长),可以广播;
import numpy as np
a = np.ones(8)
b = np.ones((4,1))
print(a+b)
**OK通过 #1维的单行和单列
import numpy as np
a = np.ones(8)
b = np.ones(4)
print(a+b)
不通过 #单行不同构
import numpy as np
a = np.ones((8,1))
b = np.ones((4,1))
print(a+b)
不通过 #单列不同构
请大家自行测试,单行与单行,单列与单列,必须同构(等长)
(2)多维相同,必须同构,可以广播;
import numpy as np
a = np.ones((3,2))
b = np.ones((3,4))
print(a+b)
不通过。#第1维相同
import numpy as np
a = np.ones((6,4))
b = np.ones((3,4))
print(a+b)
不通过,#第2维相同
import numpy as np
a = np.ones((4,3))
b = np.ones((3,4))
print(a+b)
不通过,#互为转置
(3)多维不同,低维与高维,必须右侧同构(低维同构?),可以广播。
import numpy as np
a = np.ones((4,3))
b = np.ones((2,4,3))
print(a+b)
通过,# 右侧同构
import numpy as np
a = np.ones((4,3))
b = np.ones((2,5,4,3))
print(a+b)
通过,# 右侧同构
import numpy as np
a = np.ones((4,3))
b = np.ones(4)
print(a+b)
不通过,#右侧不同构
import numpy as np
a = np.ones((4,3))
b = np.ones((4,3,2))
print(a+b)
不通过,# 右侧(低维?)不同构(说法不确认)
import numpy as np
a = np.ones((4,3))
b = np.ones((5,4,3,2))
print(a+b)
不通过,# 右侧(低维?)不同构(说法不确认)
欢迎沟通。
另自己初时不理解的维度表示,贴出来共勉
a=np.ones(8)
b=np.ones((8,1))
b
array([[1., 1., 1., 1., 1., 1., 1., 1.]])
a
array([1., 1., 1., 1., 1., 1., 1., 1.])
a.shape
(8,) #1维
b.shape
(8, 1) #2维
a.ndim
1
b.ndim
2
#理解:a是1维数组,8个元素;b为2维数组,1行8个元素