1. Introduction
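Aspect | NumPy ndarray | Native Python list |
---|---|---|
Fixed length? | Fixed; changing the length creates a new array | Dynamic length |
Same element type? | Must be the same | May differ |
Advanced math operations | Well supported, efficient, concise code | Poorly supported, slow, verbose code |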
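For efficient computation, many scientific computing libraries use NumPy arrays as their underlying data container. Even when they accept native Python lists as input, they typically convert them to NumPy arrays internally.
When operating on ndarrays, element-wise operations are executed quickly by pre-compiled C code.
- Two key features of NumPy: vectorization and broadcasting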
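A minimal sketch of both features, using arbitrary example values: vectorization replaces an explicit Python loop with a single array expression, and broadcasting stretches arrays of compatible shapes (here (3, 1) and (4,)) to a common shape without copying data.

```python
import numpy as np

# Vectorization: one array expression instead of a Python loop over elements
x = np.arange(5)
squares_loop = np.array([v * v for v in x])  # pure-Python loop
squares_vec = x * x                          # element-wise, runs in compiled code
assert np.array_equal(squares_loop, squares_vec)

# Broadcasting: (3, 1) and (4,) are stretched to the common shape (3, 4)
col = np.array([[0], [10], [20]])  # shape (3, 1)
row = np.array([1, 2, 3, 4])       # shape (4,), treated as (1, 4)
table = col + row
print(table.shape)  # (3, 4)
print(table)
# [[ 1  2  3  4]
#  [11 12 13 14]
#  [21 22 23 24]]
```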
2. Quick Start
2.1 Basic Attributes
import numpy as np
a = np.array([
[1,2,3],
[4,5,6]
])
print(a.ndim) # 2, number of dimensions
print(a.shape) # (2, 3)
print(a.size) # 6, total number of elements
print(a.dtype) # default integer type: int64 on most platforms (int32 on Windows with NumPy < 2.0)
print(a.itemsize) # bytes per element: 8 for int64, 4 for int32
print(a.data) # e.g. <memory at 0x000002E2DD390E48>, a buffer object pointing to the array's data
2.2 Array Creation
Normal creation:
import numpy as np
a = np.array([2,3,4], dtype=float) # normal definition from a Python list
b = np.zeros((3, 4)) # array of zeros
c = np.ones((2,3,4), dtype=np.int16) # array of ones
d = np.empty((2,3)) # uninitialized array; its contents are arbitrary
Arange function:
import numpy as np
a = np.arange(10, 30, 5)
# array between 10 <= x < 30, step = 5
# output: a = [10 15 20 25]
Linspace function:
import numpy as np
a = np.linspace(0,2,9)
# 9 evenly spaced values from 0 to 2, both endpoints included
# output: [0. 0.25 0.5 0.75 1. 1.25 1.5 1.75 2. ]
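The NumPy documentation for arange recommends linspace when the step is not an integer, because the number of elements arange produces is derived from (stop - start) / step and is subject to floating-point rounding. The sketch below compares the two; the interval and step are arbitrary choices, and 0.25 was picked because it is exactly representable, so both calls agree here.

```python
import numpy as np

# linspace: the number of points is fixed and both endpoints are included
b = np.linspace(0, 2, 9)
print(len(b), b[-1])           # 9 2.0

# arange with a float step: half-open interval [start, stop)
a = np.arange(0, 2.25, 0.25)   # safe here only because 0.25 is exact in binary
print(np.allclose(a, b))       # True

# endpoint=False makes linspace exclude the stop value, like arange
c = np.linspace(0, 2, 8, endpoint=False)
print(c)                       # 0.0, 0.25, ..., 1.75
```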
2.3 Print Arrays
import numpy as np
a = np.linspace(0,15,16)
print(a.reshape((4,4))) # reshape returns a new array; a itself keeps shape (16,)
# output:
# [[ 0. 1. 2. 3.]
# [ 4. 5. 6. 7.]
# [ 8. 9. 10. 11.]
# [12. 13. 14. 15.]]
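If an array is too large, NumPy prints only its corners and replaces the middle with "...". Use the following to force the full array to be printed.
import sys
import numpy as np
np.set_printoptions(threshold=sys.maxsize)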
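np.set_printoptions changes the setting for the rest of the session. If full output is only needed once, the np.printoptions context manager applies the option inside a with block and restores the previous settings afterwards. A minimal sketch, where the array size of 2000 is arbitrary (anything above the default threshold of 1000 elements is summarized):

```python
import sys
import numpy as np

big = np.arange(2000)
print(big)  # summarized with "..." because big.size exceeds the threshold

# Temporarily print every element; settings are restored when the block exits
with np.printoptions(threshold=sys.maxsize):
    print(big)

print(np.get_printoptions()["threshold"])  # 1000, the default, again
```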
2.4 Basic Operations
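Vectorized operations:
import numpy as np
a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6], [7, 8]])
a ** 2 # element-wise power
a * b # element-wise product (not the dot product)
a @ b # matrix product, or equivalently a.dot(b)
In-place operators such as *= and += also exist; they modify the array itself instead of creating a new one.
There is also a family of functions such as min, max, sum, cumsum, exp, sqrt and add, which are not covered one by one here.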
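A short sketch of a few of these functions on an arbitrary 3×4 array. Reductions such as sum and min act on the whole array by default, or along one axis when axis is given; functions such as sqrt are applied element-wise.

```python
import numpy as np

b = np.arange(12).reshape(3, 4)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]]

print(b.sum())           # 66, sum of all elements
print(b.sum(axis=0))     # [12 15 18 21], sum of each column
print(b.min(axis=1))     # [0 4 8], minimum of each row
print(b.cumsum(axis=1))  # running sum along each row
print(np.sqrt(b))        # applied element-wise, returns floats
```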
2.5 indexing, slicing, iterating
a[index]
b[start:end:step]
Note: the slice
b = a[1:3]
is a view of the original array, not a copy. Modifying b also modifies a, which differs from slicing a native Python list.
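A small check of this view behaviour, with a Python list for comparison. The values are arbitrary; .copy() is shown as the way to obtain an independent slice when one is needed.

```python
import numpy as np

a = np.arange(6)
b = a[1:3]          # view: shares memory with a
b[0] = 100
print(a[1])         # 100, the write through b is visible in a

lst = list(range(6))
sub = lst[1:3]      # slicing a list creates a new list
sub[0] = 100
print(lst)          # [0, 1, 2, 3, 4, 5], unchanged

c = a[1:3].copy()   # explicit copy when independence is needed
c[0] = -1
print(a[1])         # still 100
```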
2.6 reshape
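Neither of the following changes a in place: each returns a new array (a view of the same data when possible), and a keeps its original shape. To change the shape of a itself, use a.resize().
a.ravel() # flatten to a 1-D array
a.reshape((4,4)) # requires a.size == 16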
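A sketch contrasting the two kinds of reshaping on an arbitrary 16-element array. A separate array m is used for resize because resize refuses to run on an array that other arrays (such as the views r and f) still reference.

```python
import numpy as np

a = np.arange(16)
r = a.reshape((4, 4))          # new array object; a keeps shape (16,)
print(a.shape, r.shape)        # (16,) (4, 4)
print(np.shares_memory(a, r))  # True: here reshape returned a view

f = r.ravel()                  # flattened; a view when possible, a copy otherwise
print(f.shape)                 # (16,)

m = np.arange(16)
m.resize((2, 8))               # modifies m itself and returns None
print(m.shape)                 # (2, 8)
```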
2.7 stacking together different arrays
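import numpy as np
rg = np.random.default_rng(1) # random number generator used below
a = np.floor(10*rg.random((2,2)))
b = np.floor(10*rg.random((2,2)))
c = np.vstack((a,b)) # stack vertically (row-wise), shape (4, 2)
d = np.hstack((a,b)) # stack horizontally (column-wise), shape (2, 4)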
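With fixed inputs instead of random ones, the result of stacking is easy to check. For 2-D arrays, vstack and hstack give the same result as np.concatenate along axis 0 and axis 1, respectively; the matrices below are arbitrary.

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6], [7, 8]])

c = np.vstack((a, b))   # shape (4, 2)
d = np.hstack((a, b))   # shape (2, 4)

# For 2-D inputs these match concatenate along axis 0 and axis 1
assert np.array_equal(c, np.concatenate((a, b), axis=0))
assert np.array_equal(d, np.concatenate((a, b), axis=1))
print(d)
# [[1 2 5 6]
#  [3 4 7 8]]
```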
2.8 splitting one array into several smaller ones
np.hsplit(a, 3) # split a along its columns (axis 1) into 3 equal parts
np.vsplit(a, 3) # split a along its rows (axis 0) into 3 equal parts
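A sketch on an arbitrary 2×6 array. When an integer is passed, it must divide the length of the split axis evenly (otherwise NumPy raises a ValueError); passing a tuple instead splits after the given column (or row) indices.

```python
import numpy as np

a = np.arange(12).reshape(2, 6)
left, mid, right = np.hsplit(a, 3)   # along columns: three 2x2 pieces
print(left.tolist())                 # [[0, 1], [6, 7]]

top, bottom = np.vsplit(a, 2)        # along rows: two 1x6 pieces
print(bottom.tolist())               # [[6, 7, 8, 9, 10, 11]]

# Split before columns 1 and 4: pieces a[:, :1], a[:, 1:4], a[:, 4:]
parts = np.hsplit(a, (1, 4))
print([p.shape for p in parts])      # [(2, 1), (2, 3), (2, 2)]
```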
2.9 copy
shallow copy
import numpy as np
a = np.linspace(0,15,16)
c = a # not a copy at all: no new ndarray object is created, so id(c) == id(a)
c = a.view() # shallow copy: a new ndarray object with a different id, sharing the same underlying data
c = a[:] # equivalent to a.view() above
deep copy
import numpy as np
a = np.linspace(0, 15, 16)
b = a.copy()
print(np.shares_memory(a, b)) # False: b has its own data buffer
print(b.base is None) # True; copy() works the same way for multidimensional arrays
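The three cases above can be compared side by side. The sketch uses np.shares_memory rather than the base attribute, since base can point to an array other than a when a is itself a view; the values written are arbitrary.

```python
import numpy as np

a = np.linspace(0, 15, 16)

c = a            # same object, no copy
v = a.view()     # new object, shared data
d = a.copy()     # new object, independent data

print(c is a)                  # True
print(v is a)                  # False
print(np.shares_memory(a, v))  # True
print(np.shares_memory(a, d))  # False

v[0] = 99.0
print(a[0])      # 99.0, written through the view
d[1] = -1.0
print(a[1])      # 1.0, the copy is independent
```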
2.10 Indexing
Indexing with Arrays of Indices
import numpy as np
a = np.arange(0,15,1)
b = np.array([1,6,8])
print(a[b]) # return [1 6 8]
e.g. find the maximum of each column (along axis 0) and the time at which it occurs
import numpy as np
time = np.linspace(20, 145, 5)
data = np.sin(np.arange(20)).reshape(5,4)
ind = data.argmax(axis=0) # row index of the maximum in each column
time_max = time[ind] # time corresponding to each column's maximum
print(time_max)
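The same index array can also pull out the maximum values themselves: pairing each row index in ind with its column index selects one element per column. A sketch reusing the data above, checked against data.max(axis=0):

```python
import numpy as np

time = np.linspace(20, 145, 5)
data = np.sin(np.arange(20)).reshape(5, 4)
ind = data.argmax(axis=0)            # row index of the maximum in each column

# Element (ind[k], k) for every column k
data_max = data[ind, np.arange(data.shape[1])]
assert np.array_equal(data_max, data.max(axis=0))
print(time[ind])
print(data_max)
```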
a[[1,3,4]] = 0 # set the elements at indices 1, 3 and 4 to 0
Indexing with boolean values
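e.g. print the elements of the matrix greater than 4, then set them to 0
a = np.arange(12).reshape(3,4)
b = a > 4 # b is a boolean array with the same shape as a
print(a[b]) # [ 5  6  7  8  9 10 11]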
a[b] = 0
ix_() function
import numpy as np
a = np.arange(3)
b = np.arange(4)
c = np.arange(5)
ax, bx, cx = np.ix_(a, b, c)
# ax.shape == (3, 1, 1)
# bx.shape == (1, 4, 1)
# cx.shape == (1, 1, 5)
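Because the three arrays returned by ix_ broadcast against each other, one expression evaluates a formula for every combination of elements, and the same function can select a rectangular sub-matrix. A sketch reusing the arrays above; the formula a + b * c and the matrix m are arbitrary examples.

```python
import numpy as np

a = np.arange(3)
b = np.arange(4)
c = np.arange(5)
ax, bx, cx = np.ix_(a, b, c)

# result[i, j, k] == a[i] + b[j] * c[k] for every (i, j, k), no explicit loop
result = ax + bx * cx
print(result.shape)        # (3, 4, 5)
print(result[2, 3, 4])     # 14, i.e. 2 + 3 * 4

# Sub-matrix from rows 0 and 2 and columns 1 and 3
m = np.arange(12).reshape(3, 4)
print(m[np.ix_([0, 2], [1, 3])])
# [[ 1  3]
#  [ 9 11]]
```

Note that m[[0, 2], [1, 3]] without ix_ would instead pair the indices element by element and return only m[0, 1] and m[2, 3].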
use string index
2.11 Linear Algebra
Simple Array Operations
- Transpose
import numpy as np
a = np.array([
[1, 2],
[3, 4]
])
print(a.transpose())
'''
Transpose, returns
[
[1 3]
[2 4]
]
'''
- Inverse
np.linalg.inv(a)
'''
Inverse of matrix a, returns
[
[-2.   1. ]
[ 1.5 -0.5]
]
'''
- Identity matrix
u = np.eye(2)
'''
[
[1. 0.]
[0. 1.]
]
'''
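Inverse and identity matrix fit together: multiplying a matrix by its inverse gives the identity, up to floating-point rounding, so np.allclose is used for the comparison rather than exact equality. A sketch with the same matrix a as above:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
a_inv = np.linalg.inv(a)

print(np.allclose(a @ a_inv, np.eye(2)))   # True
print(np.allclose(a_inv @ a, np.eye(2)))   # True
```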
- Matrix product
j = np.array([[0.0, -1.0], [1.0, 0.0]])
j @ j # returns [[-1., 0.], [0., -1.]]
- Trace
np.trace(u) # return 2.0
- Solve a system of linear equations a @ x = y
a = np.array([
[1, 2],
[3, 4]
])
y = np.array([[5.], [7.]])
np.linalg.solve(a, y) # returns [[-3.], [4.]]
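The solution can be verified by substituting it back into the system: 1·(-3) + 2·4 = 5 and 3·(-3) + 4·4 = 7. Calling solve directly is generally preferred over computing inv(a) @ y, as it avoids forming the inverse explicitly. A minimal check:

```python
import numpy as np

a = np.array([[1, 2], [3, 4]])
y = np.array([[5.], [7.]])
x = np.linalg.solve(a, y)

print(x.ravel())               # x1 = -3, x2 = 4
print(np.allclose(a @ x, y))   # True
```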
- Eigenvalues and eigenvectors
import numpy as np
j = np.array([[0.0, -1.0], [1.0, 0.0]])
print(np.linalg.eig(j))
'''
return (array([0.+1.j, 0.-1.j]), array([[0.70710678+0.j , 0.70710678-0.j ],
[0. -0.70710678j, 0. +0.70710678j]]))
'''
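The result of np.linalg.eig unpacks into the eigenvalues w and a matrix v whose columns are the eigenvectors. A sketch that checks the defining relation j @ v[:, i] == w[i] * v[:, i] for each pair, using the same rotation matrix j as above (its eigenvalues are ±i):

```python
import numpy as np

j = np.array([[0.0, -1.0], [1.0, 0.0]])
w, v = np.linalg.eig(j)   # eigenvalues w; eigenvectors in the columns of v

for i in range(len(w)):
    # Each eigenpair satisfies j @ v_i = w_i * v_i, up to rounding
    print(np.allclose(j @ v[:, i], w[i] * v[:, i]))   # True
```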
2.12 Tricks and Tips
Automatic Reshape Size (pass -1 for one dimension and NumPy infers it from the array size)
a = np.arange(30)
b = a.reshape((2, -1, 3))
print(b.shape) # return (2, 5, 3)
2.13 Histograms
import numpy as np
import matplotlib.pyplot as plt
rg = np.random.default_rng(1)
mu, sigma = 2, 0.5
v = rg.normal(mu,sigma,10000)
plt.hist(v, bins=50, density=True) # draw a normalized histogram with the Matplotlib API
plt.show()
(n, bins) = np.histogram(v, bins=50, density=True)
plt.plot(.5 * (bins[1:] + bins[:-1]), n) # histogram computed with NumPy, plotted against the bin centers
plt.show()
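Without Matplotlib, the values returned by np.histogram can be inspected directly. bins holds the bin edges, so it has one more entry than n, and with density=True the bar areas sum to 1. A sketch with the same generator seed and parameters as above:

```python
import numpy as np

rg = np.random.default_rng(1)
v = rg.normal(2, 0.5, 10000)

n, bins = np.histogram(v, bins=50, density=True)
print(len(n), len(bins))                 # 50 51: bins holds the bin edges
centers = 0.5 * (bins[1:] + bins[:-1])   # bin midpoints, as in the plot above

# With density=True, sum of (height * bin width) equals 1
print(np.isclose(np.sum(n * np.diff(bins)), 1.0))   # True
```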