前言
在完成作业的过程中,对multiply
函数、*
运算符号、dot
函数的功能经常混淆。在这里做一个简单的区分,并不一定严谨。每个函数对于数组和矩阵的操作内容也存在差异。
本博客只是针对常见的用法,例如矩阵传播机制并不进行考虑,如果需要透彻理解这些操作,可以转到别的博客学习,这里只做最简单的应用区分。
import numpy as np
m = np.array([[1, 2, 3], [4, 5, 6]])
n = np.array([[1, 2, 3], [4, 5, 6]])
print("矩阵m:\n", m,type(m))
print("矩阵n:\n", n,type(n))
矩阵m:
[[1 2 3]
[4 5 6]] <class 'numpy.ndarray'>
矩阵n:
[[1 2 3]
[4 5 6]] <class 'numpy.ndarray'>
np.multiply()
函数
只做点乘运算。数组和矩阵对应位置相乘,输出与相乘数组/矩阵的大小一致。
ans=np.multiply(n,m)
print("结果:\n",ans,type(ans))
ans=np.multiply(m,n)
print("结果:\n",ans,type(ans))
结果:
[[ 1 4 9]
[16 25 36]] <class 'numpy.ndarray'>
结果:
[[ 1 4 9]
[16 25 36]] <class 'numpy.ndarray'>
将其从数组转换成矩阵做相同的操作:
mat_n=np.mat(n)
mat_m=np.mat(m)
print("矩阵mat_m:\n", mat_m,type(mat_m))
print("矩阵mat_n:\n", mat_n,type(mat_n))
矩阵mat_m:
[[1 2 3]
[4 5 6]] <class 'numpy.matrix'>
矩阵mat_n:
[[1 2 3]
[4 5 6]] <class 'numpy.matrix'>
ans=np.multiply(mat_n,mat_m)
print("结果:\n",ans,type(ans))
ans=np.multiply(mat_m,mat_n)
print("结果:\n",ans,type(ans))
结果:
[[ 1 4 9]
[16 25 36]] <class 'numpy.matrix'>
结果:
[[ 1 4 9]
[16 25 36]] <class 'numpy.matrix'>
np.dot()
函数
针对数组形式,可以参考下述博客:
针对矩阵形式,可以理解为它是两个二维的数组,执行矩阵乘法运算。
星号(*
)运算符
对数组执行对应位置相乘,即点乘
对矩阵执行矩阵乘法运算
m = np.array([[1, 2, 3], [4, 5, 6]])
n = np.array([[1, 2, 3], [4, 5, 6]])
ans=m*n
print("结果:\n",ans,type(ans))
结果:
[[ 1 4 9]
[16 25 36]] <class 'numpy.ndarray'>
mat_n=np.mat(n).T
mat_m=np.mat(m)
print("矩阵mat_m:\n", mat_m,type(mat_m))
print("矩阵mat_n:\n", mat_n,type(mat_n))
矩阵mat_m:
[[1 2 3]
[4 5 6]] <class 'numpy.matrix'>
矩阵mat_n:
[[1 4]
[2 5]
[3 6]] <class 'numpy.matrix'>
ans=mat_m*mat_n #(3,2)*(2,3)=(2,2)
print("结果:\n",ans,type(ans))
结果:
[[14 32]
[32 77]] <class 'numpy.matrix'>
@
运算符
对数组和矩阵都是执行乘法操作,当运算符两边的数据维度无法满足矩阵运算时,就会报错。
m = np.array([[1, 2, 3], [4, 5, 6]])
n = np.array([[1, 2, 3], [4, 5, 6]])
ans=m@n
print("结果:\n",ans,type(ans))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-36-6e1f6ca7b792> in <module>
1 m = np.array([[1, 2, 3], [4, 5, 6]])
2 n = np.array([[1, 2, 3], [4, 5, 6]])
----> 3 ans=m@n
4 print("结果:\n",ans,type(ans))
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 3)
运行发现,它做的是矩阵乘法运算,而此时m
和n
都是
(
2
,
3
)
(2,3)
(2,3),无法运行。
将n
转置,再运算:
ans=m@n.T
print("结果:\n",ans,type(ans))
结果:
[[14 32]
[32 77]] <class 'numpy.ndarray'>
下面讨论,@
运算符对矩阵的运算操作。
mat_n=np.mat(n).T
mat_m=np.mat(m)
ans=mat_m@mat_n
print("结果:\n",ans,type(ans))
结果:
[[14 32]
[32 77]] <class 'numpy.matrix'>
运行发现,它做的是矩阵乘法运算。
总结
为了防止记混或者出错,并且结合目前我的学习需要:
- 只使用
@
来做矩阵乘法运算 - 只使用
np.multiply
来做点乘运算