###引出
分段函数根据自变量的取值范围决定不同的计算方式,Numpy中提供了多种计算分段函数的方法。
方便起见,在这里使用一个分段函数的例子:计算三角波形(例子取自张若愚的《Python科学计算》)
三角波形具有周期性,因此我们只需要考虑0~1之间的这个范围就可以了,在这个范围里,又分了三个不同的表达式。
###最直观的自定义函数
分段函数就是分类讨论嘛,只要针对不同的x值,选择范围,计算就可以了,因此:
def triangle_wave(x, c, c0, hc):
x = x - int(x) #三角波周期为1 因此只取小数部分进行计算
if x < c0:
return x / c0 * hc
elif x >= c:
return 0.0
else:
return (c-x)/(c-c0)*hc
其中c,c0,还有hc都是图中给出的固定值,这个函数是最直观的方式,使用多个if-else,简单易懂,使用这个函数进行计算:
import numpy as np
import matplotlib.pyplot as plt
def triangle_wave(x, c, c0, hc):
x = x - int(x) #三角波周期为1 因此只取小数部分进行计算
if x < c0:
return x / c0 * hc
elif x >= c:
return 0.0
else:
return (c-x)/(c-c0)*hc
x = np.linspace(0, 2, 1000)
y = np.array([triangle_wave(t, 0.6, 0.4, 1.0) for t in x])
plt.figure()
plt.plot(x, y)
plt.ylim(-0.2, 1.2) #限制y的范围
plt.show()
这里在给出x的值之后,使用了列表解析的方式得到y的值,也就是说,每取出一个x值,计算一个对应的y值,对于这个特点,我们可以想到Numpy中的 ufunc 的方式(也就是universal function 它是一种能对数组中每个元素进行操作的的函数(大部分numpy中内置的ufunc函数是再C语言级别实现的,因此运行的速度非常快),对上面计算y的部分进行优化:
x = np.linspace(0, 2, 1000)
#使用frompyfunc()将计算单个值的函数转化为计算数组中每个元素的函数
#frompyfunc(func, nin, nout)
#func 是要进行转化的函数 nin是输入参数的个数 nout是返回值的个数
triangle_ufunc1 = np.frompyfunc(triangle_wave, 4, 1)
y = triangle_ufunc1(x, 0.6, 0.4, 1.0)
#此时返回数组的元素类型是object
y = y.astype(np.float)
其实就是使用内置的 frompyfunc 函数将一个普通的自定义函数,转化为 ufunc 方式,注意上面y的生成那行代码,还是非常简单的是吧,而且这种方式运行也会更快一些。
这是最直接的方式。
###使用np.where()
Numpy中提供了判断表达式 np.where(),用于简化这种多判断的情况:
使用的格式:
x = y if condition else z
如果condition成立,x的值就取自y数组,否则就是z数组,如果y, z是单个的值,那么就使用广播的方式扩展为长度一样的数组。例如,改变上面例子中的函数:
import numpy as np
import matplotlib.pyplot as plt
def triangle_wave(x, c, c0, hc):
x = x - int(x)
return np.where(x>=c, 0, np.where(x<c0, x/c0*hc, (c-x)/(c-c0)*hc))
# 使用多个where嵌套的方式
x = np.linspace(0, 2, 1000)
triangle_ufunc = np.frompyfunc(triangle_wave, 4, 1)
y = triangle_ufunc(x, 0.6, 0.4, 1.0)
y = y.astype(np.float)
plt.figure()
plt.plot(x, y)
plt.ylim(-0.2, 1.2)
plt.show()
return np.where(x>=c, 0, np.where (x <c0, x/c0*hc, (c-x)/(c-c0)*hc))
核心就是这一句代码,首先,最外层的where ,condition是x>=c,如果为真,也就是返回值0,否则,就转入第二个判断(where),它的condition是 x< c0 ,同样的方式判断为真或者假分别是后边的两个值,这里是单个值的情况,因此会相当于使用广播的方式将一个值拓展为相等长度的列表(这样就符合判断表达式的条件了)。
这是使用where的方式。
###使用np.select()
通过例子,可以看出来,其实where的嵌套,就是将多个if - else写在一行而已,所以,在有很多判断的时候也会很麻烦,因此,就出现了较为简单的select()方式:
调用格式:
select(condlist, choicelist, default=0)
其中 condlist 是一个长度为N的bool列表 choicelist则是长度为N的候选值的列表
所有候选的列表长度都是一样的(M) 如果choicelist非列表而是单个数值 那么也就相当于一个
长度为M的列表 只是所有的元素都是这个数值而已(其实相当于广播的方式扩展)
其实这么说还挺抽象的,用例子的话:
def triangle_wave(x, c, c0, hc):
x -= int(x)
return np.select([x>=c, x<c0, True], [0, x/c0*hc, (c-x)/(c-c0)*hc])
只有函数部分做了改变,仍然是只有一行核心:
return np.select([x>=c, x<c0, True], [0, x/c0*hc, (c-x)/(c-c0)*hc])
我们来看一下第一个数组:
[x>=c, x<c0, True]
可以发现,前两个就是不同取值的条件,其实True是第三个条件,只不过这里True的意思相当于除了前两种情况之外的意思,也可以把它写开,只是比较麻烦而已,
对应的第二个数组:
[0, x/c0*hc, (c-x)/(c-c0)*hc]
三个参数分别对应第一个数组的条件,因此,使用就很清楚了,无非是使用数组的方式,将不同的条件对应上不同的计算方式而已,可以看出,这样就不用写很多嵌套关系,写法也会清楚一些了。
在Numpy中最基础的分段函数计算方式就是这样了,另外附上使用matplotlib画出的图形,跟最开始的图像是一样的:
以上~