自己在尝试用逻辑回归处理鸢尾花数据时,遇到了很多坑,在这里分享一下代码和作图原理。
1.首先导入包:
import numpy as np
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn import preprocessing
import pandas as pd
from sklearn.preprocessing import StandardScaler,PolynomialFeatures
from sklearn.pipeline import Pipeline
部分包的用法待用到时再做解释
2.导入数据 这里定义了一个函数处理第四行的数据
def iris_type(s):
it = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}
#print(it[s.decode('utf-8')])
return it[s.decode('utf-8')]
if __name__ == "__main__":
path = u'8.iris.data' # 数据文件路径
data = np.loadtxt(path, dtype=float, delimiter=',', converters={4: iris_type})
#print (data)
# 将数据的0到3列组成x,第4列得到y
x, y = np.split(data, (4,), axis=1)
# 为了可视化,仅使用前两列特征
x = x[:, :2]
3.用pipline建立模型
lr = Pipeline([('sc', StandardScaler()), # 先做标准化
('poly',PolynomialFeatures(degree=1)),
('clf', LogisticRegression())]
)
lr.fit(x, y.ravel())
StandardScaler()作用:去均值和方差归一化。且是针对每一个特征维度来做的,而不是针对样本。 StandardScaler对每列分别标准化。
PolynomialFeatures(degree=1):进行特征的构造。它是使用多项式的方法来进行的,如果有a,b两个特征,那么它的2次多项式为(1,a,b,a^2,ab, b^2)。PolynomialFeatures有三个参数:
1.degree:控制多项式的度
2.interaction_only: 默认为False,如果指定为True,那么就不会有特征自己和自己结合的项,上面的二次项中没有a2和b2。
3.include_bias:默认为True。如果为True的话,那么就会有上面的 1那一项。
LogisticRegression()建立逻辑回归模型
4.画图准备
N, M = 500, 500 # 横纵各采样多少个值
x1_min, x1_max = x[:, 0].min(), x[:, 0].max() # 第0列的范围
x2_min, x2_max = x[:, 1].min(), x[:, 1].max() # 第1列的范围
t1 = np.linspace(x1_min, x1_max, N)
t2 = np.linspace(x2_min, x2_max, M)
x1, x2 = np.meshgrid(t1, t2) # 生成网格采样点
x_test = np.stack((x1.flat, x2.flat), axis=1) # 测试点
5.开始画图
cm_light = mpl.colors.ListedColormap(['#77E0A0', '#FF8080', '#A0A0FF'])
cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
y_hat = lr.predict(x_test) # 预测值
y_hat = y_hat.reshape(x1.shape) # 使之与输入的形状相同
#print(y_hat)
plt.pcolormesh(x1, x2, y_hat, cmap=cm_light) # 预测值的显示 其实就是背景
plt.scatter(x[:, 0], x[:, 1], c=y.ravel(), edgecolors='k', s=50, cmap=cm_dark) # 样本的显示
plt.xlabel('petal length')
plt.ylabel('petal width')
plt.xlim(x1_min, x1_max)
plt.ylim(x2_min, x2_max)
plt.grid()
plt.savefig('2.png')
plt.show()
重点关注scatter的用法,不清楚的朋友可以去查查资料
6.训练集上的预测结果
y_hat = lr.predict(x)
y = y.reshape(-1)
result = y_hat == y
print(y_hat)
print(result)
acc = np.mean(result)
print('准确度: %.2f%%' % (100 * acc))
打印结果如下:
[0. … 1.]
[ True … False]
准确度: 80.00%