目录
简介
Logistic回归,尽管它的名字是一个分类,但是属于回归的线性模型。Logistic回归在文献中也称为logit回归,最大熵分类(MaxEnt)或对数线性分类器。
计算过程
。。。
scikit-learn实现
linear_model.LogisticRegression,
Logistic回归分类器。实现可以适合二元,一对多或多元逻辑回归与可选的L2或L1正则化。
linear_model.LogisticRegressionCV
,Logistic回归CV分类器。使用内置交叉验证实现Logistic回归,以找出最佳C参数
这是一个简单的逻辑回归案例,帮助理解训练、预测模型:
import numpy as np
from sklearn.datasets import make_moons #制作两个交叉的半圈
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
# 定义决策边界制图函数
def plot_decision_boundary(pred_func):
# 设定最大最小值,附加一点点边缘填充
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
h = 0.01
#使用meshgrid生成坐标矩阵
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
# 用预测函数预测一下
Z = pred_func(np.c_[xx.ravel(), yy.ravel()]) #返回一个扁平数组
Z = Z.reshape(xx.shape)
# 然后画出图
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral) #轮廓图,即填充区域颜色
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral) #散点图
#生成数据,并制图
np.random.seed(0)
X,y=make_moons(500,noise=0.2) #生成两个半园型数组
X.shape
plt.scatter(X[:,0],X[:,1],s=40,c=y,cmap=plt.cm.Spectral)
plt.title('The Data Distribution')
plt.show()
clf=LogisticRegression() #实例化逻辑回归
clf.fit(X,y) #训练模型
plot_decision_boundary(lambda x:clf.predict(x)) #进行预测,给定模型预测X的目标值(低级方法)
plt.title('Logistic Regression')
plt.show()
这是一个预测模型案例,帮助理解逻辑回归:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model #广义线性模型,利用最小角度回归和坐标下降计算的岭回归,贝叶斯回归
from sklearn import decomposition #矩阵分解,矩阵分解算法,包括PCA,NMF或ICA。该模块的大多数算法可以被视为降维技术
from sklearn import datasets #数据集,包括加载和获取常用参考数据集的方法。它还具有一些人工数据生成器
from sklearn.pipeline import Pipeline #管道,实现了用于构建复合估计器的实用程序,作为变换和估计器链
from sklearn.model_selection import GridSearchCV #超参数优化器
#model_selection和grid_search是迭代了??
logistic = linear_model.LogisticRegression() #实例化逻辑回归
pca = decomposition.PCA() #主成分分析
pipe = Pipeline(steps=[('pca', pca), ('logistic', logistic)]) #使用最终估算器进行变换的流水线。
digits = datasets.load_digits() #下载数据
X_digits = digits.data
y_digits = digits.target
pca.fit(X_digits) #训练模型
#制图
plt.figure(1, figsize=(4, 3))
plt.clf() #清楚当前数据
plt.axes([.2, .2, .7, .7]) #从图中移除Axes 斧头(默认为当前轴)
plt.plot(pca.explained_variance_, linewidth=2) #绘制y与x作为线和/或标记
plt.axis('tight') #获取或设置某些轴属性的便捷方法
plt.xlabel('n_components') #x轴标题
plt.ylabel('explained_variance_') #y轴标题
#预