经常用户数据的过滤,
filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回由符合条件元素组成的新列表。该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。
以下是 filter() 方法的语法:
filter(function, iterable)
比如我们有如下数据:
34.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
每一行数据分三列,分别是x值,y值,以及z值
如果我们想分别用pyplot画出z值为1和0的点,那么首先就要筛选出z值不同的x,y数据,可以这么做()假设数据保存在data.txt中
import matplotlib.pyplot as plt
def get_data():
with open('data/data.txt','r') as fs:
data_list=fs.readlines()
data_list=[i.split('\n')[0] for i in data_list]
data_list=[i.split(',') for i in data_list]
data=[[float(i[0]),float(i[1]),float(i[2])] for i in data_list]
return data
data=get_data()
x0=list(filter(lambda x:x[-1]==0.0,data))
x1=list(filter(lambda x:x[-1]==1.0,data))
plot_x0_x=[i[0] for i in x0]
plot_x0_y=[i[1] for i in x0]
plot_x1_x=[i[0] for i in x1]
plot_x1_y=[i[1] for i in x1]
plt.plot(plot_x0_x,plot_x0_y,'ro')
plt.plot(plot_x1_x,plot_x1_y,'bo')
plt.show()
核心代码:
x0=list(filter(lambda x:x[-1]==0.0,data))
x1=list(filter(lambda x:x[-1]==1.0,data))
关于lambda,参见这里