使用sklearn中的神经网络模块MLPClassifier处理分类问题

MLPClassifier:参数详解--https://blog.csdn.net/weixin_38278334/article/details/83023958
生成网格点坐标矩阵--https://blog.csdn.net/lllxxq141592654/article/details/81532855
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neural_network import MLPClassifier   #MLPClassifier(多层感知器分类器)
from sklearn import datasets
import matplotlib
%matplotlib inline
# 生成所有测试样本点  
def make_meshgrid(x, y, h=.02):
    x_min, x_max = x.min() - 1, x.max() + 1
    y_min, y_max = y.min() - 1, y.max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),    
    np.arange(y_min, y_max, h))   
    return xx, yy 
xx
array([[ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88],
       [ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88],
       [ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88],
       ..., 
       [ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88],
       [ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88],
       [ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88]])
yy
array([[ 1.  ,  1.  ,  1.  , ...,  1.  ,  1.  ,  1.  ],
       [ 1.02,  1.02,  1.02, ...,  1.02,  1.02,  1.02],
       [ 1.04,  1.04,  1.04, ...,  1.04,  1.04,  1.04],
       ..., 
       [ 5.34,  5.34,  5.34, ...,  5.34,  5.34,  5.34],
       [ 5.36,  5.36,  5.36, ...,  5.36,  5.36,  5.36],
       [ 5.38,  5.38,  5.38, ...,  5.38,  5.38,  5.38]])
# 对测试样本进行预测,并显示
def plot_test_results(ax, clf, xx, yy, **params):
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    ax.contourf(xx, yy, Z, **params)
# 载入iris数据集
iris = datasets.load_iris()
# 只使用前面连个特征
X = iris.data[:,:2]
# 样本标签值
y = iris.target
X
array([[ 5.1,  3.5],
       [ 4.9,  3. ],
       [ 4.7,  3.2],
       [ 4.6,  3.1],
       [ 5. ,  3.6],
       [ 5.4,  3.9],
       [ 4.6,  3.4],
       [ 5. ,  3.4],
       [ 4.4,  2.9],
       [ 4.9,  3.1],
       [ 5.4,  3.7],
       [ 4.8,  3.4],
       [ 4.8,  3. ],
       [ 4.3,  3. ],
       [ 5.8,  4. ],
       [ 5.7,  4.4],
       [ 5.4,  3.9],
       [ 5.1,  3.5],
       [ 5.7,  3.8],
       [ 5.1,  3.8],
       [ 5.4,  3.4],
       [ 5.1,  3.7],
       [ 4.6,  3.6],
       [ 5.1,  3.3],
       [ 4.8,  3.4],
       [ 5. ,  3. ],
       [ 5. ,  3.4],
       [ 5.2,  3.5],
       [ 5.2,  3.4],
       [ 4.7,  3.2],
       [ 4.8,  3.1],
       [ 5.4,  3.4],
       [ 5.2,  4.1],
       [ 5.5,  4.2],
       [ 4.9,  3.1],
       [ 5. ,  3.2],
       [ 5.5,  3.5],
       [ 4.9,  3.1],
       [ 4.4,  3. ],
       [ 5.1,  3.4],
       [ 5. ,  3.5],
       [ 4.5,  2.3],
       [ 4.4,  3.2],
       [ 5. ,  3.5],
       [ 5.1,  3.8],
       [ 4.8,  3. ],
       [ 5.1,  3.8],
       [ 4.6,  3.2],
       [ 5.3,  3.7],
       [ 5. ,  3.3],
       [ 7. ,  3.2],
       [ 6.4,  3.2],
       [ 6.9,  3.1],
       [ 5.5,  2.3],
       [ 6.5,  2.8],
       [ 5.7,  2.8],
       [ 6.3,  3.3],
       [ 4.9,  2.4],
       [ 6.6,  2.9],
       [ 5.2,  2.7],
       [ 5. ,  2. ],
       [ 5.9,  3. ],
       [ 6. ,  2.2],
       [ 6.1,  2.9],
       [ 5.6,  2.9],
       [ 6.7,  3.1],
       [ 5.6,  3. ],
       [ 5.8,  2.7],
       [ 6.2,  2.2],
       [ 5.6,  2.5],
       [ 5.9,  3.2],
       [ 6.1,  2.8],
       [ 6.3,  2.5],
       [ 6.1,  2.8],
       [ 6.4,  2.9],
       [ 6.6,  3. ],
       [ 6.8,  2.8],
       [ 6.7,  3. ],
       [ 6. ,  2.9],
       [ 5.7,  2.6],
       [ 5.5,  2.4],
       [ 5.5,  2.4],
       [ 5.8,  2.7],
       [ 6. ,  2.7],
       [ 5.4,  3. ],
       [ 6. ,  3.4],
       [ 6.7,  3.1],
       [ 6.3,  2.3],
       [ 5.6,  3. ],
       [ 5.5,  2.5],
       [ 5.5,  2.6],
       [ 6.1,  3. ],
       [ 5.8,  2.6],
       [ 5. ,  2.3],
       [ 5.6,  2.7],
       [ 5.7,  3. ],
       [ 5.7,  2.9],
       [ 6.2,  2.9],
       [ 5.1,  2.5],
       [ 5.7,  2.8],
       [ 6.3,  3.3],
       [ 5.8,  2.7],
       [ 7.1,  3. ],
       [ 6.3,  2.9],
       [ 6.5,  3. ],
       [ 7.6,  3. ],
       [ 4.9,  2.5],
       [ 7.3,  2.9],
       [ 6.7,  2.5],
       [ 7.2,  3.6],
       [ 6.5,  3.2],
       [ 6.4,  2.7],
       [ 6.8,  3. ],
       [ 5.7,  2.5],
       [ 5.8,  2.8],
       [ 6.4,  3.2],
       [ 6.5,  3. ],
       [ 7.7,  3.8],
       [ 7.7,  2.6],
       [ 6. ,  2.2],
       [ 6.9,  3.2],
       [ 5.6,  2.8],
       [ 7.7,  2.8],
       [ 6.3,  2.7],
       [ 6.7,  3.3],
       [ 7.2,  3.2],
       [ 6.2,  2.8],
       [ 6.1,  3. ],
       [ 6.4,  2.8],
       [ 7.2,  3. ],
       [ 7.4,  2.8],
       [ 7.9,  3.8],
       [ 6.4,  2.8],
       [ 6.3,  2.8],
       [ 6.1,  2.6],
       [ 7.7,  3. ],
       [ 6.3,  3.4],
       [ 6.4,  3.1],
       [ 6. ,  3. ],
       [ 6.9,  3.1],
       [ 6.7,  3.1],
       [ 6.9,  3.1],
       [ 5.8,  2.7],
       [ 6.8,  3.2],
       [ 6.7,  3.3],
       [ 6.7,  3. ],
       [ 6.3,  2.5],
       [ 6.5,  3. ],
       [ 6.2,  3.4],
       [ 5.9,  3. ]])
y
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
#创建神经网络,并训练
clf = MLPClassifier(solver='lbfgs',alpha=1e-5,hidden_layer_sizes=(30,20,10),random_state=1)
clf.fit(X,y)
print(clf)
MLPClassifier(activation='relu', alpha=1e-05, batch_size='auto', beta_1=0.9,
       beta_2=0.999, early_stopping=False, epsilon=1e-08,
       hidden_layer_sizes=(30, 20, 10), learning_rate='constant',
       learning_rate_init=0.001, max_iter=200, momentum=0.9,
       nesterovs_momentum=True, power_t=0.5, random_state=1, shuffle=True,
       solver='lbfgs', tol=0.0001, validation_fraction=0.1, verbose=False,
       warm_start=False)
X0, X1 = X[:, 0], X[:, 1]
# 网格点矩阵,生成所有测试样本点
xx, yy = make_meshgrid(X0, X1)
xx   
array([[ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88],
       [ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88],
       [ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88],
       ..., 
       [ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88],
       [ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88],
       [ 3.3 ,  3.32,  3.34, ...,  8.84,  8.86,  8.88]])
yy
array([[ 1.  ,  1.  ,  1.  , ...,  1.  ,  1.  ,  1.  ],
       [ 1.02,  1.02,  1.02, ...,  1.02,  1.02,  1.02],
       [ 1.04,  1.04,  1.04, ...,  1.04,  1.04,  1.04],
       ..., 
       [ 5.34,  5.34,  5.34, ...,  5.34,  5.34,  5.34],
       [ 5.36,  5.36,  5.36, ...,  5.36,  5.36,  5.36],
       [ 5.38,  5.38,  5.38, ...,  5.38,  5.38,  5.38]])
title = ('MLPClassifier')
fig, ax = plt.subplots(figsize = (5, 5))
plt.subplots_adjust(wspace=0.4, hspace=0.4)
# 显示测试样本的分类结果
plot_test_results(ax, clf, xx, yy, cmap=plt.cm.coolwarm, alpha=0.8)
# 显示训练样本
ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
ax.set_xlim(xx.min(), xx.max())
ax.set_ylim(yy.min(), yy.max())
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_xticks(())
ax.set_yticks(())
ax.set_title(title)
plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-7jSPGTzF-1587480671986)(output_12_0.png)]

©️2020 CSDN 皮肤主题: 数字20 设计师:CSDN官方博客 返回首页