I.准备
1.import...
In [112]:
import warnings

import matplotlib.pyplot as plt  # was missing: plt is used below, so this cell raised NameError on a fresh kernel
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

plt.rcParams['font.sans-serif'] = ['SimHei']   # render Chinese characters in plots
plt.rcParams['axes.unicode_minus'] = False     # render minus signs correctly with the CJK font
warnings.filterwarnings("ignore")              # suppress warnings
2.read data
In [113]:
# Load the breast-cancer dataset (30 feature columns + 1 label column) from Excel.
# NOTE(review): hardcoded absolute Windows path — the notebook will not run on
# another machine; consider a configurable DATA_DIR / relative path.
cancer = pd.read_excel('C:\\Users\\91333\\Documents\\semester6\\data science\\3.NB&DT\\Week3_CancerDataset.xlsx')
3. 划分训练集测试集
In [114]:
# All columns except the last are features; the last column is the class label.
# Hold out 40% of the samples for testing; fixed random_state keeps the split reproducible.
feature_frame = cancer.iloc[:, 0:-1]
label_series = cancer.iloc[:, -1]
X_train, X_test, y_train, y_test = train_test_split(feature_frame, label_series, test_size=0.4, random_state=99)
4. training the model
In [115]:
# Sweep the hidden-layer width and the optimizer; for every combination
# report the final training loss and the test-set accuracy.
units = [1, 10, 100]
args = ['lbfgs', 'sgd', 'adam']
for solver_name in args:
    for n_hidden in units:
        # One hidden layer of `n_hidden` logistic units.
        # NOTE: per sklearn docs, momentum only affects solver='sgd',
        # and learning_rate_init is ignored by 'lbfgs'.
        clf = MLPClassifier(hidden_layer_sizes=[n_hidden], activation='logistic',
                            solver=solver_name, momentum=0.2,
                            learning_rate_init=0.001, random_state=0)
        clf.fit(X_train, y_train)
        print('神经元个数={:<5},优化算法是{:<6},损失函数值为{:<20},准确率:{:.3f}'.format(n_hidden, solver_name, clf.loss_, clf.score(X_test, y_test)))
神经元个数=1 ,优化算法是lbfgs ,损失函数值为0.6725333173221658 ,准确率:0.667
神经元个数=10 ,优化算法是lbfgs ,损失函数值为0.11928464006120038 ,准确率:0.956
神经元个数=100 ,优化算法是lbfgs ,损失函数值为0.16189482436631242 ,准确率:0.860
神经元个数=1 ,优化算法是sgd ,损失函数值为0.6728320345604173 ,准确率:0.667
神经元个数=10 ,优化算法是sgd ,损失函数值为0.6851164704109012 ,准确率:0.667
神经元个数=100 ,优化算法是sgd ,损失函数值为0.46615206207580406 ,准确率:0.895
神经元个数=1 ,优化算法是adam ,损失函数值为0.6727370689888383 ,准确率:0.667
神经元个数=10 ,优化算法是adam ,损失函数值为0.4302821710780715 ,准确率:0.917
神经元个数=100 ,优化算法是adam ,损失函数值为0.12622508759611045 ,准确率:0.925
神经元个数为10,优化算法为lbfgs时,测试集准确率最高,为0.956,损失函数值为0.119,模型共三层,输入层、一个隐藏层、一个输出层。 事实上,lbfgs, sgd, adam中,adam适用于较大的数据集,lbfgs适用于较小的数据集。 隐藏层的激活函数可选relu, logistic, tanh, identity,这里选用logistic(即sigmoid函数);对于二分类问题,MLPClassifier的输出层会自动使用logistic激活,与隐藏层激活函数的选择无关。
5.看一下训练出的参数
In [116]:
# Refit the best configuration found above (10 hidden units, lbfgs)
# so we can inspect the learned parameters.
best_mlp = MLPClassifier(hidden_layer_sizes=[10], activation='logistic',
                         solver='lbfgs', momentum=0.2,
                         learning_rate_init=0.001, random_state=0)
best_mlp.fit(X_train, y_train)
# W and b are read by the following cells, so these names must stay.
W = best_mlp.coefs_        # list of weight matrices, one per layer transition
b = best_mlp.intercepts_   # list of bias vectors, one per layer transition
print(len(W))              # 2 transitions: input->hidden and hidden->output
2
1) 第二层(隐藏层)
In [117]:
W[0]
Out[117]:
array([[ 2.17986487e-02, 9.63614528e-02, 6.89647432e-03,
2.00429978e-02, -3.40926516e-02, 6.51511372e-02,
-2.85450149e-02, 1.75485015e-01, 4.09236456e-01,
-5.20502619e-02],
[ 1.30273654e-01, 1.31254603e-02, -6.65523892e-02,
1.90054093e-01, -1.91557794e-01, -1.84394875e-01,
-2.15355881e-01, 1.49614721e-01, 3.36722737e-02,
1.65232328e-01],
[ 2.13734746e-01, 1.35223811e-01, -2.63719148e-01,
1.25272882e-01, -1.70463066e-01, 6.24949011e-02,
-1.63549231e-01, 2.01932936e-01, 1.31287833e-01,
-3.81084956e-02],
[-1.05128815e-01, 1.41100252e-01, -9.35730369e-01,
3.05598016e-02, -2.14888841e-01, 5.37372601e-02,
3.15498927e-02, 6.35012773e-02, -6.52857414e-03,
8.11935271e-02],
[-6.27380310e-02, -2.81184057e-02, 8.77604885e-02,
-1.96385361e-01, 7.44712108e-02, 7.61996763e-02,
-1.29339190e-01, -1.65697847e-01, -8.96392519e-02,
-6.08612089e-02],
[ 3.13470197e-02, -2.74197088e-02, 2.17608915e-01,
-1.77710550e-01, -1.30003762e-01, -1.51245513e-01,
6.83660346e-02, -1.10163971e-01, -4.64520865e-02,
-1.14129103e-01],
[-1.52290268e-01, -1.73993277e-01, 6.95356867e-02,
-1.61572731e-01, -1.35493937e-01, -5.86219348e-02,
1.43340515e-01, -1.79917357e-01, 1.06875978e-01,
-1.80365971e-01],
[ 2.12767357e-01, -1.39995567e-02, 2.12852705e-01,
4.68197312e-02, 1.06845352e-01, -2.05779836e-01,
-9.69904599e-02, -1.69604580e-01, -1.09102208e-01,
-1.70260644e-01],
[-8.12812710e-02, -3.82858664e-02, -1.95616528e-01,
8.59501953e-02, 2.97414920e-02, -1.04767927e-01,
1.03675971e-02, -1.81315126e-01, 2.20665474e-02,
1.91706166e-01],
[-8.10196973e-02, 7.47591298e-02, -1.64806035e-01,
9.66029028e-02, -9.40426463e-02, -1.41474005e-01,
3.86276157e-02, -2.14294551e-01, 1.44968641e-01,
-2.21182791e-01],
[ 7.94056268e-02, -1.02677925e-01, 1.03932033e-01,
2.06394546e-01, -1.12196595e-01, 3.40094463e-02,
4.10826905e-02, 3.22814714e-02, -1.18232715e-01,
2.02179236e-01],
[-2.36115365e-02, 1.54694262e-01, 8.13866651e-02,
-9.04563929e-02, 1.40129303e-01, -4.62195817e-02,
1.70064845e-01, 3.64048295e-02, 2.32282629e-01,
8.59767525e-02],
[ 1.00589485e-01, 7.69591210e-04, 1.96245965e-01,
6.43001478e-02, -3.40032285e-02, 4.75153172e-02,
-2.14847245e-01, -8.84933860e-02, 8.52082453e-03,
-9.37427754e-02],
[ 5.27016326e-02, -2.70767046e-02, -2.18714086e-01,
-9.00788833e-02, 3.12434744e-02, 4.07861397e-02,
3.21221264e-02, 6.91620567e-02, -3.68420939e-02,
-3.06257285e-02],
[ 1.77081531e-01, -5.91414997e-02, -2.87015814e-02,
1.75016979e-01, 1.36733742e-01, 9.10482861e-02,
-1.78523164e-01, 1.87324619e-01, 9.51748079e-02,
2.22764719e-01],
[-1.56542083e-01, 1.64389648e-01, -1.50899824e-01,
5.16041864e-02, -1.67986647e-01, 1.55406186e-01,
1.37234601e-01, 3.08587502e-02, -4.67524146e-02,
-1.92392441e-01],
[ 8.81636358e-02, -2.07464757e-02, 9.89250594e-02,
1.63611398e-01, 2.12348502e-01, 1.58887155e-01,
-2.18049970e-01, -6.25273696e-02, 9.32542991e-02,
-1.46636788e-01],
[ 9.39409018e-03, -1.99014452e-01, -1.34006892e-01,
-2.15008521e-01, 1.31153411e-01, -1.23283977e-01,
-6.90601447e-02, 1.91163759e-01, 8.82996029e-02,
-2.09061632e-01],
[-1.49733907e-01, 5.42472152e-02, 3.43671489e-02,
-1.17046372e-01, 1.93902255e-01, 5.08924779e-02,
1.59099825e-02, 4.01520784e-02, 1.01104992e-01,
-8.39776921e-02],
[-4.54503205e-02, -1.29571988e-01, -1.40164579e-01,
1.98438578e-01, 1.06973611e-01, -4.26072311e-03,
-1.21725847e-01, -1.09694037e-01, -1.97909465e-01,
-2.92868591e-02],
[-8.40436793e-02, 8.81798079e-02, -9.82333240e-02,
-1.43075925e-01, -2.12259086e-01, -1.93236027e-01,
7.93779503e-02, -2.00730339e-02, 2.37778766e-01,
1.77137215e-01],
[ 2.18966982e-01, -1.26097182e-01, -5.64045637e-02,
-1.05690571e-01, -2.14057705e-01, 1.15356124e-01,
-8.18017107e-02, -5.05067796e-02, -1.18384203e-01,
1.47832733e-01],
[ 5.76018390e-02, 1.69582594e-01, -3.78126633e-01,
1.33095552e-01, -1.40382161e-01, 2.02289479e-01,
7.90347121e-02, -1.23209437e-01, -1.20663615e-01,
1.03090784e-01],
[-1.09865325e-01, -7.09408849e-02, -1.14796862e+00,
-2.11819686e-01, -1.30631929e-01, -3.00677651e-02,
-7.78124436e-02, -1.96728413e-03, -2.94984499e-03,
3.87543483e-02],
[ 1.62483080e-01, -1.70794073e-01, 6.98862148e-03,
-1.64303371e-01, 9.68406854e-02, -4.64158512e-02,
2.92039379e-02, -1.41422430e-01, -1.71338760e-01,
-5.33357774e-03],
[-6.44774487e-02, 1.96674887e-01, 1.17343912e-01,
1.11043027e-01, 1.80284763e-01, -1.86026859e-01,
2.32967244e-02, 3.77339099e-02, 1.05139018e-01,
-9.28184338e-02],
[-1.15735288e-01, -1.78497442e-01, -2.16996590e-01,
1.91810267e-01, 7.58777968e-02, 1.27337626e-01,
-9.74761035e-02, 3.85901420e-02, -3.30720723e-01,
-6.41812924e-03],
[ 2.13229847e-01, 1.68132412e-01, -7.24550899e-02,
2.06118398e-01, -1.19811107e-01, 2.00647601e-01,
1.97098893e-01, 1.33612222e-01, 1.83204938e-02,
1.67141734e-01],
[-9.24286741e-02, 1.55824660e-01, 5.11902804e-02,
-2.17368559e-01, -6.82192778e-02, -1.57126548e-01,
2.15144456e-01, -9.63705373e-03, -2.89222433e-02,
6.22827351e-02],
[-5.86847458e-02, -1.62145276e-01, 1.43329217e-01,
-1.38501268e-01, 5.05459577e-03, -1.23109001e-01,
-1.79592866e-01, 1.61747092e-01, 2.00764927e-01,
2.05789955e-01]])
In [118]:
W[0].shape
Out[118]:
(30, 10)
In [119]:
b[0]
Out[119]:
array([ 0.18181722, 0.12257258, -0.07959383, -0.18733715, -0.04148301,
-0.11975069, -0.16443246, -0.1996385 , 0.13287139, -0.21849628])
In [120]:
b[0].shape
Out[120]:
(10,)
第一层权重矩阵为30×10,截距向量为10×1,输入一个样本的30个feature,输出10个隐藏层unit值。
2)第三层(输出层)
In [121]:
W[1]
Out[121]:
array([[ 0.23041471],
[-1.09548145],
[-0.34417969],
[-0.34947558],
[ 0.14650817],
[-1.00328244],
[-0.0665444 ],
[-0.73831002],
[14.80089679],
[-0.59309476]])
In [122]:
b[1]
Out[122]:
array([-0.98222453])
第二层权重矩阵为10×1, 截距项为1×1,输入一个样本的10个隐藏层激活值,得到输出层的单个数值。