记录_broadcast
code
seed = 0
n_approx = 5
num_cells = 4
dimension = 3
X = np.random.rand(n_approx, dimension)
print(X, X.shape)
kernel_points = np.random.rand(num_cells, dimension)
print(kernel_points, kernel_points.shape)
>>>
[[0.26214758 0.05170461 0.93400984]
[0.62591246 0.65007225 0.8046075 ]
[0.86740066 0.01002227 0.63433347]
[0.6582109 0.76768031 0.22212046]
[0.12551867 0.37109102 0.19793781]] (5, 3)
>>>
[[0.51271662 0.17935163 0.50461549]
[0.52004753 0.02080544 0.33881502]
[0.23372128 0.87948412 0.67752571]
[0.95115178 0.11724297 0.1471212 ]] (4, 3)
Y = np.expand_dims(X, axis=1)
print(Y,Y.shape)
>>>
[[[0.26214758 0.05170461 0.93400984]]
[[0.62591246 0.65007225 0.8046075 ]]
[[0.86740066 0.01002227 0.63433347]]
[[0.6582109 0.76768031 0.22212046]]
[[0.12551867 0.37109102 0.19793781]]] (5, 1, 3)
differences = Y - kernel_points
print(differences, differences.shape)
>>>
[[[-0.25056904 -0.12764702 0.42939435]
[-0.25789995 0.03089917 0.59519481]
[ 0.02842629 -0.82777951 0.25648413]
[-0.6890042 -0.06553836 0.78688864]]
[[ 0.11319584 0.47072062 0.29999201]
[ 0.10586493 0.62926681 0.46579248]
[ 0.39219118 -0.22941187 0.12708179]
[-0.32523932 0.53282928 0.65748631]]
[[ 0.35468404 -0.16932936 0.12971798]
[ 0.34735313 -0.01078317 0.29551844]
[ 0.63367938 -0.86946185 -0.04319224]
[-0.08375112 -0.1072207 0.48721227]]
[[ 0.14549428 0.58832868 -0.28249502]
[ 0.13816337 0.74687487 -0.11669456]
[ 0.42448961 -0.11180381 -0.45540525]
[-0.29294088 0.65043734 0.07499927]]
[[-0.38719794 0.19173939 -0.30667768]
[-0.39452886 0.35028558 -0.14087721]
[-0.10820261 -0.50839309 -0.4795879 ]
[-0.8256331 0.25384805 0.05081661]]] (5, 4, 3)
conclusion
- shape(5,1,3) - shape(4,3) --> shape(5,4,3) - shape(4,3)
- 直接通过broadcasting:原本 shape(5,3) 的
- 第一行依次 减去shape(4,3)每一行
- 第二行依次 减去shape(4,3)每一行
- …
- 第五行依次 减去shape(4,3)每一行
- 达到各点遍历减去中心点的效果