Import the data
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
housing = fetch_california_housing()
scaler = StandardScaler()
# Standardize the features (note: fitting on all data before splitting leaks test-set statistics)
x_data = scaler.fit_transform(housing.data)
# First split off a test set, then split the remainder into train / validation
x_train_full, x_test, y_train_full, y_test = train_test_split(x_data, housing.target)
x_train, x_valid, y_train, y_valid = train_test_split(x_train_full, y_train_full)
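One caveat on the block above: the scaler is fitted on the entire dataset before splitting, so test-set statistics leak into preprocessing. A leak-free variant fits the scaler on the training split only. A minimal sketch, using synthetic stand-in data so it runs without downloading the housing dataset (variable names are illustrative):

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for housing.data: 1000 samples, 8 features as in the notebook
rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(1000, 8))
y = rng.normal(size=1000)

# Split first, then fit the scaler on the training split only
x_train_raw, x_test_raw, y_train, y_test = train_test_split(X, y, random_state=42)

scaler = StandardScaler()
x_train = scaler.fit_transform(x_train_raw)  # fit + transform on training data
x_test = scaler.transform(x_test_raw)        # transform only: no test-set leakage
```

The test split is standardized with the training split's mean and variance, which is how the model would see genuinely unseen data at deployment time.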
Build the model
- This is a multi-output model with one main output and one auxiliary output. The main output is 1-dimensional and its MSE is computed against the labels. The auxiliary output is 8-dimensional (matching the 8 input features) and its MSE is computed against the input features: the goal is for hidden2, after passing through the second_output layer, to stay close to the input features, which acts as a form of regularization.
- The model also contains a concatenation layer that merges the original input_ layer with hidden2; the result then passes through the main output layer main_output, whose MSE against the labels should be small.
- This is an arbitrarily chosen, small network; if the structure is unclear, try drawing the network graph yourself.
input_ = tf.keras.layers.Input(shape=[8],name='input_')
hidden1 = tf.keras.layers.Dense(15,activation='elu',kernel_initializer='he_normal',name='hidden1')(input_)
hidden2 = tf.keras.layers.Dense(15,activation='elu',kernel_initializer='he_normal',name='hidden2')(hidden1)
second_output = tf.keras.layers.Dense(8,name='second_output')(hidden2)
concat = tf.keras.layers.Concatenate()([hidden2,input_])
main_output = tf.keras.layers.Dense(1,name='main_output')(concat)
model = tf.keras.Model(inputs=[input_],outputs=[main_output,second_output])
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_ (InputLayer) [(None, 8)] 0 []
hidden1 (Dense) (None, 15) 135 ['input_[0][0]']
hidden2 (Dense) (None, 15) 240 ['hidden1[0][0]']
concatenate (Concatenate) (None, 23) 0 ['hidden2[0][0]',
'input_[0][0]']
main_output (Dense) (None, 1) 24 ['concatenate[0][0]']
second_output (Dense) (None, 8) 128 ['hidden2[0][0]']
==================================================================================================
Total params: 527
Trainable params: 527
Non-trainable params: 0
__________________________________________________________________________________________________
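The parameter counts in the summary can be verified by hand: a Dense layer has in_dim * units weights plus one bias per unit. A quick check in plain Python, mirroring the layer sizes above:

```python
def dense_params(in_dim, units):
    # Weight matrix (in_dim x units) plus one bias per unit
    return in_dim * units + units

hidden1 = dense_params(8, 15)           # 8*15 + 15 = 135
hidden2 = dense_params(15, 15)          # 15*15 + 15 = 240
second_output = dense_params(15, 8)     # 15*8 + 8 = 128
main_output = dense_params(15 + 8, 1)   # concat of hidden2 (15) and input_ (8) -> 24

total = hidden1 + hidden2 + second_output + main_output
print(total)  # 527, matching model.summary()
```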
Compile the model
- Two identical loss functions ('mse') are set, one per output (main and auxiliary), with loss weights: the main loss counts for 0.8 and the auxiliary loss for 0.2.
- To speed up training, a relatively large learning rate is used together with momentum optimization (momentum=0.9); see the notes in the same section if this is unfamiliar.
- To keep the update direction more stable, Nesterov accelerated gradients are enabled, and gradient clipping is applied with clipnorm=1.
- To reduce oscillation when approaching the optimum, power (inverse-time) learning-rate scheduling is set with decay=1.0/100.
model.compile(loss=['mse','mse']
,loss_weights=[0.8,0.2]
,optimizer=tf.keras.optimizers.SGD(learning_rate=0.05,momentum=0.9,nesterov=True,clipnorm=1,decay=1.0/100)
)
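The legacy decay argument implements inverse-time ("power") scheduling: the effective learning rate at optimizer step t is lr0 / (1 + decay * t). A quick numeric sketch of how the rate shrinks over training (plain Python; the step counts are illustrative):

```python
def inverse_time_lr(lr0, decay, step):
    # Keras legacy `decay`: lr(t) = lr0 / (1 + decay * t), t = optimizer iteration
    return lr0 / (1.0 + decay * step)

lr0, decay = 0.05, 1.0 / 100
for step in (0, 100, 1000, 10000):
    print(step, inverse_time_lr(lr0, decay, step))
# At step 0 the rate is 0.05; after 100 steps it halves to 0.025,
# and by step 10000 it has shrunk to roughly 0.0005.
```

In newer TensorFlow versions, `decay` is deprecated in favor of passing a `tf.keras.optimizers.schedules.InverseTimeDecay` schedule as the learning rate.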
Train the model
- Early stopping is set here to prevent overfitting.
- Note: as described when building the model above, (y_train, x_train) is passed in the label position; validation_data works the same way.
- Finally, the model is evaluated.
earlystop = tf.keras.callbacks.EarlyStopping(patience=5,restore_best_weights=True)
history=model.fit(x_train,(y_train,x_train),epochs=100,validation_data=(x_valid,(y_valid,x_valid)),callbacks=[earlystop])
model.evaluate(x_test,(y_test,x_test))
Epoch 1/100
363/363 [==============================] - 4s 6ms/step - loss: 0.6339 - main_output_loss: 0.6846 - second_output_loss: 0.4311 - val_loss: 0.3725 - val_main_output_loss: 0.4425 - val_second_output_loss: 0.0926
Epoch 2/100
363/363 [==============================] - 2s 6ms/step - loss: 0.4054 - main_output_loss: 0.4276 - second_output_loss: 0.3164 - val_loss: 0.3457 - val_main_output_loss: 0.4130 - val_second_output_loss: 0.0767
Epoch 3/100
363/363 [==============================] - 3s 7ms/step - loss: 0.3458 - main_output_loss: 0.3705 - second_output_loss: 0.2469 - val_loss: 0.3213 - val_main_output_loss: 0.3840 - val_second_output_loss: 0.0706
Epoch 4/100
363/363 [==============================] - 3s 8ms/step - loss: 0.3185 - main_output_loss: 0.3549 - second_output_loss: 0.1729 - val_loss: 0.3210 - val_main_output_loss: 0.3850 - val_second_output_loss: 0.0653
Epoch 5/100
363/363 [==============================] - 3s 8ms/step - loss: 0.3202 - main_output_loss: 0.3662 - second_output_loss: 0.1364 - val_loss: 0.3174 - val_main_output_loss: 0.3814 - val_second_output_loss: 0.0614
Epoch 6/100
363/363 [==============================] - 3s 7ms/step - loss: 0.3141 - main_output_loss: 0.3692 - second_output_loss: 0.0937 - val_loss: 0.3212 - val_main_output_loss: 0.3863 - val_second_output_loss: 0.0607
Epoch 7/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2964 - main_output_loss: 0.3503 - second_output_loss: 0.0806 - val_loss: 0.3154 - val_main_output_loss: 0.3788 - val_second_output_loss: 0.0616
Epoch 8/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2862 - main_output_loss: 0.3393 - second_output_loss: 0.0739 - val_loss: 0.3089 - val_main_output_loss: 0.3711 - val_second_output_loss: 0.0599
Epoch 9/100
363/363 [==============================] - 3s 7ms/step - loss: 0.3075 - main_output_loss: 0.3677 - second_output_loss: 0.0665 - val_loss: 0.3092 - val_main_output_loss: 0.3719 - val_second_output_loss: 0.0587
Epoch 10/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2824 - main_output_loss: 0.3379 - second_output_loss: 0.0605 - val_loss: 0.3081 - val_main_output_loss: 0.3707 - val_second_output_loss: 0.0580
Epoch 11/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2864 - main_output_loss: 0.3432 - second_output_loss: 0.0592 - val_loss: 0.3083 - val_main_output_loss: 0.3712 - val_second_output_loss: 0.0564
Epoch 12/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2844 - main_output_loss: 0.3416 - second_output_loss: 0.0557 - val_loss: 0.3091 - val_main_output_loss: 0.3720 - val_second_output_loss: 0.0575
Epoch 13/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2845 - main_output_loss: 0.3422 - second_output_loss: 0.0537 - val_loss: 0.3097 - val_main_output_loss: 0.3730 - val_second_output_loss: 0.0563
Epoch 14/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2869 - main_output_loss: 0.3454 - second_output_loss: 0.0527 - val_loss: 0.3100 - val_main_output_loss: 0.3737 - val_second_output_loss: 0.0554
Epoch 15/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2844 - main_output_loss: 0.3428 - second_output_loss: 0.0509 - val_loss: 0.3046 - val_main_output_loss: 0.3671 - val_second_output_loss: 0.0547
Epoch 16/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2811 - main_output_loss: 0.3388 - second_output_loss: 0.0500 - val_loss: 0.3061 - val_main_output_loss: 0.3689 - val_second_output_loss: 0.0548
Epoch 17/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2831 - main_output_loss: 0.3416 - second_output_loss: 0.0492 - val_loss: 0.3057 - val_main_output_loss: 0.3682 - val_second_output_loss: 0.0554
Epoch 18/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2799 - main_output_loss: 0.3377 - second_output_loss: 0.0489 - val_loss: 0.3044 - val_main_output_loss: 0.3669 - val_second_output_loss: 0.0543
Epoch 19/100
363/363 [==============================] - 3s 9ms/step - loss: 0.2778 - main_output_loss: 0.3352 - second_output_loss: 0.0479 - val_loss: 0.3100 - val_main_output_loss: 0.3739 - val_second_output_loss: 0.0545
Epoch 20/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2873 - main_output_loss: 0.3473 - second_output_loss: 0.0471 - val_loss: 0.3055 - val_main_output_loss: 0.3684 - val_second_output_loss: 0.0538
Epoch 21/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2786 - main_output_loss: 0.3366 - second_output_loss: 0.0469 - val_loss: 0.3077 - val_main_output_loss: 0.3712 - val_second_output_loss: 0.0538
Epoch 22/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2827 - main_output_loss: 0.3418 - second_output_loss: 0.0463 - val_loss: 0.3042 - val_main_output_loss: 0.3670 - val_second_output_loss: 0.0529
Epoch 23/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2788 - main_output_loss: 0.3369 - second_output_loss: 0.0462 - val_loss: 0.3063 - val_main_output_loss: 0.3697 - val_second_output_loss: 0.0527
Epoch 24/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2784 - main_output_loss: 0.3365 - second_output_loss: 0.0460 - val_loss: 0.3031 - val_main_output_loss: 0.3658 - val_second_output_loss: 0.0523
Epoch 25/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2787 - main_output_loss: 0.3370 - second_output_loss: 0.0456 - val_loss: 0.3023 - val_main_output_loss: 0.3648 - val_second_output_loss: 0.0524
Epoch 26/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2796 - main_output_loss: 0.3382 - second_output_loss: 0.0450 - val_loss: 0.3034 - val_main_output_loss: 0.3662 - val_second_output_loss: 0.0522
Epoch 27/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2747 - main_output_loss: 0.3321 - second_output_loss: 0.0454 - val_loss: 0.3035 - val_main_output_loss: 0.3665 - val_second_output_loss: 0.0517
Epoch 28/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2786 - main_output_loss: 0.3371 - second_output_loss: 0.0447 - val_loss: 0.3024 - val_main_output_loss: 0.3651 - val_second_output_loss: 0.0516
Epoch 29/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2759 - main_output_loss: 0.3336 - second_output_loss: 0.0450 - val_loss: 0.3012 - val_main_output_loss: 0.3636 - val_second_output_loss: 0.0517
Epoch 30/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2745 - main_output_loss: 0.3318 - second_output_loss: 0.0451 - val_loss: 0.3043 - val_main_output_loss: 0.3675 - val_second_output_loss: 0.0513
Epoch 31/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2747 - main_output_loss: 0.3323 - second_output_loss: 0.0446 - val_loss: 0.3023 - val_main_output_loss: 0.3650 - val_second_output_loss: 0.0516
Epoch 32/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2741 - main_output_loss: 0.3315 - second_output_loss: 0.0446 - val_loss: 0.3015 - val_main_output_loss: 0.3640 - val_second_output_loss: 0.0511
Epoch 33/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2749 - main_output_loss: 0.3326 - second_output_loss: 0.0443 - val_loss: 0.3008 - val_main_output_loss: 0.3633 - val_second_output_loss: 0.0507
Epoch 34/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2738 - main_output_loss: 0.3312 - second_output_loss: 0.0442 - val_loss: 0.3008 - val_main_output_loss: 0.3634 - val_second_output_loss: 0.0507
Epoch 35/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2738 - main_output_loss: 0.3313 - second_output_loss: 0.0440 - val_loss: 0.3013 - val_main_output_loss: 0.3639 - val_second_output_loss: 0.0509
Epoch 36/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2760 - main_output_loss: 0.3340 - second_output_loss: 0.0436 - val_loss: 0.3002 - val_main_output_loss: 0.3626 - val_second_output_loss: 0.0505
Epoch 37/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2724 - main_output_loss: 0.3296 - second_output_loss: 0.0438 - val_loss: 0.3009 - val_main_output_loss: 0.3635 - val_second_output_loss: 0.0504
Epoch 38/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2725 - main_output_loss: 0.3297 - second_output_loss: 0.0437 - val_loss: 0.2998 - val_main_output_loss: 0.3622 - val_second_output_loss: 0.0504
Epoch 39/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2738 - main_output_loss: 0.3313 - second_output_loss: 0.0437 - val_loss: 0.3004 - val_main_output_loss: 0.3629 - val_second_output_loss: 0.0501
Epoch 40/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2743 - main_output_loss: 0.3320 - second_output_loss: 0.0433 - val_loss: 0.3016 - val_main_output_loss: 0.3644 - val_second_output_loss: 0.0502
Epoch 41/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2738 - main_output_loss: 0.3315 - second_output_loss: 0.0432 - val_loss: 0.3003 - val_main_output_loss: 0.3629 - val_second_output_loss: 0.0501
Epoch 42/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2721 - main_output_loss: 0.3293 - second_output_loss: 0.0433 - val_loss: 0.2987 - val_main_output_loss: 0.3610 - val_second_output_loss: 0.0499
Epoch 43/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2728 - main_output_loss: 0.3302 - second_output_loss: 0.0431 - val_loss: 0.3011 - val_main_output_loss: 0.3639 - val_second_output_loss: 0.0499
Epoch 44/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2725 - main_output_loss: 0.3299 - second_output_loss: 0.0430 - val_loss: 0.3011 - val_main_output_loss: 0.3640 - val_second_output_loss: 0.0495
Epoch 45/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2707 - main_output_loss: 0.3277 - second_output_loss: 0.0431 - val_loss: 0.3016 - val_main_output_loss: 0.3645 - val_second_output_loss: 0.0498
Epoch 46/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2728 - main_output_loss: 0.3303 - second_output_loss: 0.0428 - val_loss: 0.3013 - val_main_output_loss: 0.3642 - val_second_output_loss: 0.0495
Epoch 47/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2719 - main_output_loss: 0.3292 - second_output_loss: 0.0427 - val_loss: 0.3006 - val_main_output_loss: 0.3634 - val_second_output_loss: 0.0494
162/162 [==============================] - 1s 3ms/step - loss: 0.3038 - main_output_loss: 0.3699 - second_output_loss: 0.0391
[0.3037593960762024, 0.36992380023002625, 0.039101772010326385]
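The three numbers returned by evaluate are the weighted total loss followed by the two per-output losses; with loss_weights=[0.8, 0.2], the total should equal 0.8 * main + 0.2 * auxiliary. Checking this against the values above in plain Python:

```python
main_loss = 0.36992380023002625    # main_output_loss from model.evaluate
second_loss = 0.039101772010326385  # second_output_loss from model.evaluate

# Total loss is the weighted sum set in model.compile(loss_weights=[0.8, 0.2])
total = 0.8 * main_loss + 0.2 * second_loss
print(total)  # ~0.30376, matching the reported loss of 0.3037593960762024
```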
Plot the training curves
import pandas as pd
import matplotlib.pyplot as plt

# Plot all training/validation loss curves recorded in history
pd.DataFrame(history.history).plot(figsize=(18,10))
plt.show()