Numpy,Numerical + Python,主攻高维数组的处理,结合了Python代码简洁和C性能优良的优点,是Python科学计算最最最最基础的包。
在 “可乐学人” 上一篇文章Matplotlib优雅作图笔记中,优雅的我从作图的“高效性”、“美观性”和“交互性”三个角度,讲了如何优雅地用Matplotlib绘制出一幅优雅的图像,并优雅地插入到LaTex论文文档中。今天我将带大家优雅地撸一遍Numpy。
众所周知,Numpy的核心是ndarray,这是一种“提供快速数值运算的节省内存的容器”(Memory-efficient container that provides fast numerical operations),“节省内存”和“快速运算”就是Numpy相比于Python更加高效的根本原因。除此之外,Numpy中区分了 拷贝(copy)和视图(view) 的概念,对这两个概念的理解和灵活运用也将使我们使用Numpy时更加自如。基于这三点,本文也相应分为如下三个部分:
- 数组基础
- 核心:数组在连续的内存块中存储,所有的高维数组在内存中都是一维的
- 内容:数组的创建、数组的合并与分裂
2. 数组计算
-
- 核心:矢量化将循环操作转为数组表达式,广播为不同形状数组之间的计算提供了矢量化的方法
- 内容:数组的逐元素计算、数组的广播机制计算、数组的整合计算
3. 数组操作
-
- 核心:拷贝复制了原数组,速度慢,视图直接取原数组的索引,速度快
- 内容:数组的索引和切片、数组的重塑和打平
本文首发于公众号“可乐学人”,预计阅读时间差不多也就十来分钟吧,希望大家都能有所收获。
数组基础
Numpy的核心是ndarray
对象,即同构多维数组(homogeneous n-dimensional array)。
- 同构是指数组中所有元素都是同一数据类型(dtype)。Numpy关注数值计算,因此dtype基本都是浮点数,而标准的双精度浮点值需要占用8字节,也就是64位,所以dtype默认是"float64"。
- 多维是指数组可以有多个维度,维度也叫轴(axis)。一个拥有三个轴的数组如下图所示,左边是三维数组的直观理解(Intuitive View),右边是Numpy中打印出的形式(Python view):
然而在内存中,高维数组并不长这样,而是在连续内存块(contiguous block of memory)中存储。换句话说,高维数组在内存中是以“一维数组”的方式呈现的。
在内存中有两种存储高维数组的方式:
- 行主序(row-major order):每行的元素彼此相邻;C语言默认行主序
- 列主序(column-major-order):每列的元素彼此相邻;Fortran语言默认列主序
Numpy中默认行主序,所以三维数组在内存中默认存储方式(Memory View)和Numpy打印出的样子是不一样的(Python View):
这就引出了跨度(strides)的概念。跨度,是在某个轴下从当前元素到下个元素需要跨过的字节数。假设上面三维数组的数据类型都是float64(8个字节),那么:
- axis 0:沿着轴0获取下一个元素需要跨过32个字节(4个元素);
- axis 1:沿着轴1获取下一个元素需要跨过16个字节(2个元素);
- axis 2:沿着轴2获取下一个元素需要跨过8个字节(1个元素);
所以,上面这个三维数组的跨度就是(32, 16, 8)。
因此,ndarray并不是简单的索引+数据,而是内存块+数据类型描述符+索引方案,一个数组实质上由数据(data)、数据类型(dtype)、形状(shape)、跨度(strides)四大块内容组成的:
这样的存储方式的好处是什么呢?
- 连续存储,节省内存空间,而且访问数组的时候也会更快,相比之下,Python中的list存储在不连续的区域,所以list的索引就很低效;
- 操作迅速,高维数组的操作变为一维,通过预先编译好的C或Fortran执行,比原生Python快了很多数量级。
一、创建数组的三种基本方法
Numpy提供了三类创建数组的方法:
# 1. 常规创建:np.array(array_like)
arr = np.array([1, 2, 3])
# 2. 固定间隔或个数:np.arange(start, stop, step)和np.linspace(start, stop, num)
arr1 = np.arange(0, 10, 2) # 0, 2, 4, 6, 8
arr2 = np.linspace(0, 10, num=5) # 0到10之间5个等间距数字
# 3. 占位符:np.ones(shape), np.zeros(shape), np.random.random(shape)等
arr1 = np.ones(2) # 全是1
arr2 = np.zeros((2,3)) # 全是0
arr3 = np.random.random((2,3)) # 0到1之间随机值
arr4 = np.empty_like(x) # 形状和x相同的空数组
arr5 = np.random.randn(4) # 满足正态分布的随机值
Python中万物皆对象,数组也是一个对象。我们来创建一个shape为(2,2,2)、数据类型为float64的三维数组,看看数组的跨度和之前的分析是否一致。
# 创建三维数组
arr = np.arange(1, 9, dtype="float64").reshape(2,2,2)
print(arr)
# [[[1. 2.]
# [3. 4.]]
# [[5. 6.]
# [7. 8.]]]
# dtype:数据类型
print( 'The dtype is {}'.format(arr.dtype))
# output: The dtype is float64
# shape:每个维度的元素数量
print( 'The shape is {}'.format(arr.shape))
# output: The shape is (2, 2, 2)
# ndim:维度数量,也叫轴的数量
print( 'The dimension is {}'.format(arr.ndim))
# output: The dimension is 3
# size:所有元素的数量。size=shape的乘积
print( 'The number of elements is {}'.format(arr.size))
# output: The number of elements is 8
# strides:跨度
print( 'The strides is {}'.format(arr.strides))
# output: The strides is (32, 16, 8)
可以看到,strides=(32,16,8)是数组的跨度,这和我们之前的分析是一样的。除此之外,ndim=3是轴的数量,size=8是所有元素的数量。这些数组的性质都是有内在联系的,记
二、数组的合并与分裂
除了这三种创建数组的基础办法,我们还可以通过已有数组的合并和分裂,生成新的数组。
- 合并是多个合成一个,
np.concatenate((arr1, arr2), axis)
- 分裂是一个分成多个,
np.split(arr, indices ,axis)
这里的axis参数意味着需要指定特定的轴,比如np.split(arr, indices ,axis)
中,axis=0时,就意味着按行分裂,axis=1时,就意味着把列分开。
Numpy关于合并和分裂还提供了很多很多很多操作,比如合并时可以用np.vstack(arr1, arr2)
来代替np.concatenate((arr1, arr2), axis=0)
,但是经过如下的测试发现,np.concatenate()
速度更快:
arr1 = np.ones((100, 100))
arr2 = np.zeros((100, 100))
%timeit np.concatenate((arr1, arr2), axis=0)
# 7.52 µs ± 219 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit np.vstack((arr1, arr2))
# 11.6 µs ± 369 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
作为一个优雅但记忆力不好的人,一个操作只要记住一种还不错的方法并且能熟练使用就行了,记得多了就混淆了。所以,合并用concatenate,分裂用split,就足够了。
数组计算
Numpy快速计算的核心概念是矢量化(Vectorization),即,用数组表达式代替显示循环(replace explicit loops with array expressions)。在Python中对list等数据结构进行循环时,会涉及大量的开销。矢量化将循环操作交由性能更好的C和Fortran,从而让Python代码既简洁又高效。
比如计算均方误差时,要用到公式:
如果直接在Python中做,需要进行大量循环操作,写出的代码不容易读,而且执行起来还贼慢。但在Numpy中,可以简洁地将需要完成的计算以数组的形式表现出来,error = (1/n) * np.sum(np.square(predictions - labels))
,表面上是逐元素计算(element-wise),实际上背后的循环操作已经交给效率更高的C和Fortran执行了。
Numpy中,数组的计算可以归纳为以下三种情况:
- 单个数组:逐元素计算 + 整合计算
- 两个形状相同的数组:逐元素计算
- 两个形状不同但相容的数组:触发广播机制后逐元素计算
接下来就详细介绍一下逐元素计算、广播机制和整合计算。
一、逐元素计算
逐元素计算,可以细分为三种情况:
- 一元函数:以e为底的幂运算
np.exp()
, 以n为底的幂运算np.power(arr, n)
, 平方根np.sqrt()
, 自然对数np.log()
等; - 二元函数:加法
np.add()
, 减法np.substract()
, 乘法np.multiply()
, 除法np.divide()
, 除余np.remain()
等; - 比较运算:相等
np.equal()
,大于np.greater()
,小于等于np.less_equal()
等。
# 创建两个二维数组
arr1 = np.arange(1,5).reshape((2,2))
# [[1 2]
# [3 4]]
arr2 = np.arange(5,9).reshape((2,2))
# [[5 6]
# [7 8]]
# 一元函数:平方根
print("Square root of Array 1n {}".format(np.sqrt(arr1)))
# Square root of Array 1
# [[1. 1.41421356]
# [1.73205081 2. ]]
# 二元函数:加法
print("Array 1 + Array 2n {}".format(np.add(arr1, arr2)))
# Array 1 + Array 2
# [[ 6 8]
# [10 12]]
# 比较运算:大于
print("Array 1 > Array 2n {}".format(np.greater(arr1, arr2)))
# Array 1 > Array 2
# [[ False False]
# [ False False ]]
需要补充的有四点:
- 以上的运算大都可以用对应的数学符号表示,比如加法可以写成
arr1+arr2
; - 以上函数可以统称为通用函数(universal function),这些函数还可以接受一个out可选参数,这样就能在数组原地进行操作,从而使得运算更加高效;
arr = np.random.rand(5000,5000)
%timeit np.add(arr, 1)
# 104 ms ± 1.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit np.add(arr, 1, out=arr)
# 17.9 ms ± 551 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
- 数组和数组的乘法是逐元素的,而不是线性代数中的矩阵相乘。
np.dot(arr1, arr2)
可以实现矩阵相乘; - Numpy提供了子模块
numpy.linalg
来进行基础的线性代数运算,然而,更好的解决方案是用scipy提供的scipy.linalg
。
二、广播机制
广播(Broadcasting)是保证不同形状的数组之间矢量化操作的机制。 当两个数组形状相同时,直接逐元素计算就可以了;当数组形状不同时,如果两个数组相容(compatible),就会触发广播机制,从而也可以进行逐元素计算。
两个数组是否相容,比较的是各个维度是否相容。如果满足如下两个条件之一,就说两个维度是相容的:
- 相等(equal)
- 有一个是1(one of them is 1).
具体而言,从两个数组形状元组(shape)的最后一个元素开始,逐个元素地检查维度是否相容,如果所有维度都相容,那么就可以触发广播机制了。
比如,第一个数组的形状为(3, 2),第二个数组的形状为(3, 1),因为轴1的两个数组有一个是1,轴0的两个数字相等,所以这两个数组就是相容的:
arr1 = np.ones((3, 2))
arr2 = np.ones((3, 1))
arr1 + arr2
# array([[2., 2.],
# [2., 2.],
# [2., 2.]])
广播机制非常非常有用,而且非常非常常见。说自己会Numpy却不知道broadcasting就好比说自己是杰伦的粉丝却不知道《七里香》一样。比如最简单的,要对某数组的前三个元素统一赋值为20,就可以用“切片+广播”一行代码搞定:
arr = np.ones(10)
arr[:3] = 20
arr
# array([20., 20., 20., 1., 1., 1., 1., 1., 1., 1.])
再比如要手算两个向量的外积时也可以使用广播机制。正常情况下,Numpy是不区分行向量和列向量的,所有的一维数组的shape都是(n,)的形式,就算进行转置,shape也是(n,)。所以两个一维数组直接相乘,不会触发广播机制,而是直接进行逐元素计算。通过np.newaxis
就可以人为增加一个轴,从而让两个向量的形状变成(1,3)和(3,1),这样就会触发广播机制,成功计算外积:
arr = np.arange(3)
row_vector = arr[np.newaxis, :] # shape: (1, 3)
col_vector = arr[:, np.newaxis] # shape: (3, 1)
row_vector * col_vector
# array([[0, 0, 0],
# [0, 1, 2],
# [0, 2, 4]])
三、整合计算
整合(aggregation),也叫reduction,其实就是求统计信息,比如最大值(max)、最小值(min)、平均值(mean)、标准差(std)等。
以求最大值为例,Numpy不仅可以让我们求整个数组的最大值(arr.max()
),而且可以求每一个维度上的最大值(arr.max(axis=0)
, arr.max(axis=1)
)。
看似简单,但实际上使用起来很容易混淆行和列。看个例子:
arr = np.arange(10).reshape(2, 5)
print('每行的平均值:{}'.format(arr.mean(axis=1)))
# 每行的平均值:[2. 7.]
print('每列的标准差:{}'.format(arr.std(axis=0)))
# 每列的标准差:[2.5 2.5 2.5 2.5 2.5]
小朋友,你一定有两个问号:
- 求每行的平均值,为什么axis=1呢,行不是对应轴0吗?
- 求每列的标准差,为什么不指定axis=1呢?列不是对应轴1吗?
Jake VanderPlas的这句话完美地回答了你的问号:
The axis keyword specifies the dimension of the array that will be collapsed, rather than the dimension that will be returned.
指定哪个轴,哪个轴就会collapse。比如,求每行的平均值,实际上就是把所有列压缩了,所以axis要设为1;求每列的标准差,实际上在把每行的元素collapse,所以axis要设为0.
除此之外,当我们理解了数组在内存中的存储方式,就很容易推出,在行主序下,如果求每行的整合计算,速度应该会更快,因为内存中每行的元素是彼此相邻的。验证一下:
arr = np.random.random((300, 300))
# 求每列的和
%timeit arr.sum(axis=0)
# 62.5 µs ± 3.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# 求每行的和
%timeit arr.sum(axis=1)
# 50.6 µs ± 211 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
数组操作
拷贝和视图是使用Numpy进行数组操作时非常重要的概念:
- 拷贝(copy):也叫深拷贝,复制了一个新的数组,速度慢,但改变新数组不会改变原数组。
- 视图(view):也叫浅拷贝,直接取原数组的索引部分,速度快,但改变新数组会改变原数组。
copy和view各有优缺点,使用时要权衡利弊,谨慎使用。
view的优点在于快。一般来说,Numpy中很多操作都是尽可能用view的,因为可以节省内存,速度更快,比如切片。但是,修改view后的数组也会修改原来的数组,稍不注意就会让结果偏离预期。
copy也有很多用处。比如当中间结果数组a是个很大的数组,而最终结果b只是a的一小部分时,就可以先用copy从a中用切片把b复制出来,再将a的内存释放掉。
a = np.arange(int(1e8)) # 中间结果
b = a[:100].copy() # 最终结果
del a # 释放内存
接下来就撸一遍数组的操作,看看哪些操作是copy,哪些操作是view。把这些铭记在心,才可以在使用的时候游刃有余,避免出错。
一、数组的索引和切片
索引和切片的基本用法和list中一样:
但数组的索引和切片要更加灵活,而且背后的逻辑也不一样。一般而言,数组的获取方式有以下几种:
- 索引:获取特定元素,
arr[index]
,是copy - 布尔索引:用布尔 (boolean) 类型值数组来进行索引,是copy
- 花式索引:用索引数组进行索引,是copy
- 切片:截取一段元素,
arr[start:stop:step]
,是view
比如:
arr = np.arange(10)
# 索引 indexing
print("Array value with index 0: {}".format(arr[0]))
# Array value with index 0: 0
# 布尔索引 boolean indexing
mask = (arr > 6) | (arr < 3)
print("Array values with mask as true: {}".format(arr[mask]))
# Array values with mask as true: [0 1 2 7 8 9]
# 花式索引 Fancy indexing
ind = [2,4,6]
print("Array values with index in list: {}".format(arr[ind]))
# Array values with index in list: [2 4 6]
# 切片 slicing
print("Array values from index 0 to 4: {}".format(arr[:5]))
# Array values from index 0 to 4: [0 1 2 3 4]
下图是各种数组获取方式的混杂,这些骚操作都是要掌握的:
需要特别注意的就是,切片是view,而索引是copy。这意味着,修改索引后数组的值,不会修改原数组;修改切片后数组的值,原数组也会发生改变:
# 创建一维数组
arr1 = np.arange(6)
arr2 = np.arange(6)
# 索引:获取特定元素
arr_copy = arr1[[0, 1, 2]]
arr_copy[:] = 10
print(arr1)
# [0 1 2 3 4 5]
# 切片:截取一段元素
arr_view = arr2[:3]
arr_view[:] = 10
print(arr2)
# [10 10 10 3 4 5]
二、数组的重塑(reshape)、打平(ravel)和转置(transpose)
数组的重塑、打平和转置实际上就是改变数组的形状,或者说改变数组的跨度:
- 重塑是从低维到高维,
arr.reshape(shape)
一般是view - 打平是从高维到低维,
arr.ravel()
一般是view,arr.flatten()
是copy - 转置是特殊的重塑,
arr.T
或arr.transpose()
是view
为了追求速度,数组的变形都是直接修改跨度信息,返回原数组的view,避免进行copy。比如,转置二维数组时直接交换两个轴的跨度就可以了,所以转置是view:
但是,对一个转置后的数组进行打平或再重塑,就不能简单地通过修改跨度来实现了,这时候就只能在新的内存块中copy一份出来。所以在重塑或打平时既可能返回view也可能返回copy,需要特别注意。
# 创建两个一模一样的二维数组
arr1 = np.ones((3, 3))
arr2 = np.ones((3, 3))
# 对arr1直接进行reshape,此时,直接修改跨度信息,所以是view
arr_view = arr1.reshape((1, -1))
arr_view[:] = 10
print(arr1)
# [[10. 10. 10.]
# [10. 10. 10.]
# [10. 10. 10.]]
# 对arr2先转置再进行reshape,此时,不能直接修改跨度信息,所以是copy
arr_copy = arr2.T.reshape((1, -1))
arr_copy[:] = 10
print(arr2)
# [[1. 1. 1.]
# [1. 1. 1.]
# [1. 1. 1.]]
总结
啰里啰唆说了一大堆,总结一下就是,在使用Numpy时要注意以下三点:
- 理解数组的连续存储方式,理解形状、跨度和轴;
- 计算时尽量使用矢量化方法,将循环和条件逻辑转换为数组运算和布尔逻辑运算;
- 操作时尽量使用数组视图,避免复制数据;
Numpy还有很多深奥的逻辑和进阶的用法,这篇文章只是我个人学习Numpy的笔记,如有错误欢迎批评指正,如对你有帮助,可千万别忘了素质三连呀。
本文的参考资料包括但不限于:
- https://numpy.org/
- https://scipy-lectures.org/intro/numpy/index.html
- https://realpython.com/numpy-array-programming/#reader-comments
- https://ipython-books.github.io/45-understanding-the-internals-of-numpy-to-avoid-unnecessary-array-copying/
- https://mp.weixin.qq.com/s/nWu_PE5U7EASwJLYlyZcNA
- https://zhuanlan.zhihu.com/p/28626431