LogisticRegression训练及决策边界可视化

import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
import numpy as np
from sklearn.model_selection import train_test_split

# 生成四组数据,也就是要进行分类的数据有四个类别
x1 = np.random.normal(loc=5, scale=1, size=(100, 1))
x2 = np.random.normal(loc=-2, scale=1, size=(100, 1))
x3 = np.random.normal(loc=4, scale=1, size=(100, 1))
x4 = np.random.normal(loc=0, scale=1, size=(100, 1))
y1 = np.random.normal(loc=10, scale=1, size=(100, 1))
y2 = np.random.normal(loc=1, scale=1, size=(100, 1))
y3 = np.random.normal(loc=0.5, scale=1, size=(100, 1))
y4 = np.random.normal(loc=-5, scale=1, size=(100, 1))

# 对数据进行处理,x和y是两个属性(或特征),也就是决定这个样例是是什么类别是由两个属性来判断的;
# 对应iris数据:xy属性就是feature_names, 下面的target对应的就是target_names,也就是标签
c0 = np.hstack((x1, y1))
c1 = np.hstack((x2, y2))
c2 = np.hstack((x3, y3))
c3 = np.hstack((x4, y4))
c = np.vstack((c0, c1, c2, c3))
# 四个标签,分别为0,1,2,3
target_0 = np.zeros((100, 1))
target_1 = np.ones((100, 1))
target_2 = target_1  * 2
target_3 = target_1 * 3
y = np.vstack((target_0, target_1, target_2, target_3))

# 生成训练数据和测试数据
X_train, X_test, y_train, y_test = train_test_split(c, y, test_size=0.25)
model = LogisticRegression(C=1e5, solver='lbfgs', multi_class='multinomial')
model.fit(X_train, y_train)
x_min, y_min = np.min(X_train, axis=0)
x_max, y_max = np.max(X_train, axis=0)
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 500), np.linspace(y_min, y_max, 500))
# 或者用 z=model.predict(np.c_[xx.ravel(), yy.ravel()])也是可以的
z = model.predict(np.hstack((xx.ravel().reshape(-1, 1), yy.ravel().reshape(-1, 1))))

# 绘图部分,绘图时使用plt.cm.Paired可以让点的颜色和网格块的颜色相匹配
fig = plt.figure()
plt.pcolormesh(xx, yy, z.reshape(xx.shape), shading='auto', cmap=plt.cm.Paired, alpha=0.5)
# 上一条语句也可以用plt.contourf(xx, yy, z.reshape(xx.shape), cmap=plt.cm.Paired),效果一样
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, edgecolor='k', cmap=plt.cm.Paired)
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.xlabel('x values')
plt.ylabel('y values')
plt.show()

# 下面段代码显示出原始数据点的分布
plt.scatter(x1, y1, label='class 1', alpha=0.5)
plt.scatter(x2, y2, label='class 2', alpha=0.5)
plt.scatter(x3, y3, label='class 3', alpha=0.5)
plt.scatter(x4, y4, label='class 4', alpha=0.5)
plt.legend()
plt.show()

结果如下图所示: 

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值