Python - Numpy中对分段函数的一点总结

###引出
分段函数根据自变量的取值范围决定不同的计算方式,Numpy中提供了多种计算分段函数的方法。
方便起见,在这里使用一个分段函数的例子:计算三角波形(例子取自张若愚的《Python科学计算》)
这里写图片描述
三角波形具有周期性,因此我们只需要考虑0~1之间的这个范围就可以了,在这个范围里,又分了三个不同的表达式。

###最直观的自定义函数

分段函数就是分类讨论嘛,只要针对不同的x值,选择范围,计算就可以了,因此:

def triangle_wave(x, c, c0, hc):
    x = x - int(x)  #三角波周期为1 因此只取小数部分进行计算
    if x < c0:
        return x / c0 * hc
    elif x >= c:
        return 0.0
    else:
        return (c-x)/(c-c0)*hc

其中c,c0,还有hc都是图中给出的固定值,这个函数是最直观的方式,使用多个if-else,简单易懂,使用这个函数进行计算:

import numpy as np
import matplotlib.pyplot as plt

def triangle_wave(x, c, c0, hc):
    x = x - int(x)  #三角波周期为1 因此只取小数部分进行计算
    if x < c0:
        return x / c0 * hc
    elif x >= c:
        return 0.0
    else:
        return (c-x)/(c-c0)*hc


x = np.linspace(0, 2, 1000)
y = np.array([triangle_wave(t, 0.6, 0.4, 1.0) for t in x])

plt.figure()
plt.plot(x, y)
plt.ylim(-0.2, 1.2)   #限制y的范围
plt.show()

这里在给出x的值之后,使用了列表解析的方式得到y的值,也就是说,每取出一个x值,计算一个对应的y值,对于这个特点,我们可以想到Numpy中的 ufunc 的方式(也就是universal function 它是一种能对数组中每个元素进行操作的的函数(大部分numpy中内置的ufunc函数是再C语言级别实现的,因此运行的速度非常快),对上面计算y的部分进行优化:

x = np.linspace(0, 2, 1000)
#使用frompyfunc()将计算单个值的函数转化为计算数组中每个元素的函数
#frompyfunc(func, nin, nout)
#func  是要进行转化的函数  nin是输入参数的个数  nout是返回值的个数
triangle_ufunc1 = np.frompyfunc(triangle_wave, 4, 1)
y = triangle_ufunc1(x, 0.6, 0.4, 1.0)
#此时返回数组的元素类型是object
y = y.astype(np.float)

其实就是使用内置的 frompyfunc 函数将一个普通的自定义函数,转化为 ufunc 方式,注意上面y的生成那行代码,还是非常简单的是吧,而且这种方式运行也会更快一些。
这是最直接的方式。
###使用np.where()
Numpy中提供了判断表达式 np.where(),用于简化这种多判断的情况:
使用的格式:

x = y if condition else z

如果condition成立,x的值就取自y数组,否则就是z数组,如果y, z是单个的值,那么就使用广播的方式扩展为长度一样的数组。例如,改变上面例子中的函数:

import numpy as np
import matplotlib.pyplot as plt


def triangle_wave(x, c, c0, hc):
	x = x - int(x)
	return np.where(x>=c, 0, np.where(x<c0, x/c0*hc, (c-x)/(c-c0)*hc))
# 使用多个where嵌套的方式 

x = np.linspace(0, 2, 1000)
triangle_ufunc = np.frompyfunc(triangle_wave, 4, 1)
y = triangle_ufunc(x, 0.6, 0.4, 1.0)
y = y.astype(np.float)

plt.figure()
plt.plot(x, y)
plt.ylim(-0.2, 1.2)
plt.show()
return np.where(x>=c, 0, np.where (x <c0, x/c0*hc, (c-x)/(c-c0)*hc))

核心就是这一句代码,首先,最外层的where ,condition是x>=c,如果为真,也就是返回值0,否则,就转入第二个判断(where),它的condition是 x< c0 ,同样的方式判断为真或者假分别是后边的两个值,这里是单个值的情况,因此会相当于使用广播的方式将一个值拓展为相等长度的列表(这样就符合判断表达式的条件了)。
这是使用where的方式。

###使用np.select()
通过例子,可以看出来,其实where的嵌套,就是将多个if - else写在一行而已,所以,在有很多判断的时候也会很麻烦,因此,就出现了较为简单的select()方式:

调用格式:
select(condlist, choicelist, default=0)
其中 condlist 是一个长度为N的bool列表 choicelist则是长度为N的候选值的列表
所有候选的列表长度都是一样的(M) 如果choicelist非列表而是单个数值 那么也就相当于一个
长度为M的列表 只是所有的元素都是这个数值而已(其实相当于广播的方式扩展)

其实这么说还挺抽象的,用例子的话:

def triangle_wave(x, c, c0, hc):
	x -= int(x)
	return np.select([x>=c, x<c0, True], [0, x/c0*hc, (c-x)/(c-c0)*hc])

只有函数部分做了改变,仍然是只有一行核心:

return np.select([x>=c, x<c0, True], [0, x/c0*hc, (c-x)/(c-c0)*hc])

我们来看一下第一个数组:

[x>=c, x<c0, True]

可以发现,前两个就是不同取值的条件,其实True是第三个条件,只不过这里True的意思相当于除了前两种情况之外的意思,也可以把它写开,只是比较麻烦而已,
对应的第二个数组:

[0, x/c0*hc, (c-x)/(c-c0)*hc]

三个参数分别对应第一个数组的条件,因此,使用就很清楚了,无非是使用数组的方式,将不同的条件对应上不同的计算方式而已,可以看出,这样就不用写很多嵌套关系,写法也会清楚一些了。

在Numpy中最基础的分段函数计算方式就是这样了,另外附上使用matplotlib画出的图形,跟最开始的图像是一样的:
这里写图片描述

以上~

发布了190 篇原创文章 · 获赞 192 · 访问量 39万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 大白 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览