tensorflow2笔记:双损失函数(同时两个损失函数)

预先导入数据

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
housing=fetch_california_housing()
scaler=StandardScaler()
x_data=scaler.fit_transform(housing.data)
x_train_full,x_test,y_train_full,y_test=train_test_split(x_data,housing.target)
x_train,x_valid,y_train,y_valid=train_test_split(x_train_full,y_train_full)

建立模型

  • 这是一个多输出的模型,包含一个主要输出和辅助输出。主要输出维度1维,与标签求mse。辅助输出维度为8维(输入特征也是8维)与输入特征求’mse’,主要让hidden2层的经过second_output辅助输出层后与输入的特征值相差不大,是一个正则化过程。
  • 同时这里还包括了一个合并层,将最开始的输入input_层与hidden2层合并,再经过一个主要输出层main_output后与标签的mse较小。
  • 这是一个随便设置的、规模较小的神经网络,如果不清楚的,可以自己画出网络图。
input_ = tf.keras.layers.Input(shape=[8],name='input_')
hidden1 = tf.keras.layers.Dense(15,activation='elu',kernel_initializer='he_normal',name='hidden1')(input_)
hidden2 = tf.keras.layers.Dense(15,activation='elu',kernel_initializer='he_normal',name='hidden2')(hidden1)
second_output = tf.keras.layers.Dense(8,name='second_output')(hidden2)
concat = tf.keras.layers.Concatenate()([hidden2,input_])
main_output = tf.keras.layers.Dense(1,name='main_output')(concat)
model = tf.keras.Model(inputs=[input_],outputs=[main_output,second_output])
model.summary()
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_ (InputLayer)            [(None, 8)]          0           []                               
                                                                                                  
 hidden1 (Dense)                (None, 15)           135         ['input_[0][0]']                 
                                                                                                  
 hidden2 (Dense)                (None, 15)           240         ['hidden1[0][0]']                
                                                                                                  
 concatenate (Concatenate)      (None, 23)           0           ['hidden2[0][0]',                
                                                                  'input_[0][0]']                 
                                                                                                  
 main_output (Dense)            (None, 1)            24          ['concatenate[0][0]']            
                                                                                                  
 second_output (Dense)          (None, 8)            128         ['hidden2[0][0]']                
                                                                                                  
==================================================================================================
Total params: 527
Trainable params: 527
Non-trainable params: 0
__________________________________________________________________________________________________

模型编译

  • 这里设置了两个一样的损失函数,分别对应两个输出(主要输出和辅助输出),并且设置了损失的权重,主要损失占0.8,次要损失占0.2。
  • 为了加快训练速率,先设置一个较大的学习率,同时动量优化momentum=0.9,(不清楚的可以看一下同一分栏的笔记)。
  • 为了让训练的方向更准确,加了nesterov加速梯度,同时设置了clipnorm=1梯度裁剪。
  • 为了让最终接近目标时减小震荡,设置了学习率幂调度decay=1.0/100
model.compile(loss=['mse','mse']
             ,loss_weights=[0.8,0.2]
             ,optimizer=tf.keras.optimizers.SGD(learning_rate=0.05,momentum=0.9,nesterov=True,clipnorm=1,decay=1.0/100)
             )

模型训练

  • 这里设置了提前停止,防止过拟合。
  • 注意:按照上面建立模型所言,这里将(y_train,x_train)放在标签的位置,validation_data同样的道理。
  • 最后进行模型评估。
earlystop = tf.keras.callbacks.EarlyStopping(patience=5,restore_best_weights=True)
history=model.fit(x_train,(y_train,x_train),epochs=100,validation_data=(x_valid,(y_valid,x_valid)),callbacks=[earlystop])
model.evaluate(x_test,(y_test,x_test))
Epoch 1/100
363/363 [==============================] - 4s 6ms/step - loss: 0.6339 - main_output_loss: 0.6846 - second_output_loss: 0.4311 - val_loss: 0.3725 - val_main_output_loss: 0.4425 - val_second_output_loss: 0.0926
Epoch 2/100
363/363 [==============================] - 2s 6ms/step - loss: 0.4054 - main_output_loss: 0.4276 - second_output_loss: 0.3164 - val_loss: 0.3457 - val_main_output_loss: 0.4130 - val_second_output_loss: 0.0767
Epoch 3/100
363/363 [==============================] - 3s 7ms/step - loss: 0.3458 - main_output_loss: 0.3705 - second_output_loss: 0.2469 - val_loss: 0.3213 - val_main_output_loss: 0.3840 - val_second_output_loss: 0.0706
Epoch 4/100
363/363 [==============================] - 3s 8ms/step - loss: 0.3185 - main_output_loss: 0.3549 - second_output_loss: 0.1729 - val_loss: 0.3210 - val_main_output_loss: 0.3850 - val_second_output_loss: 0.0653
Epoch 5/100
363/363 [==============================] - 3s 8ms/step - loss: 0.3202 - main_output_loss: 0.3662 - second_output_loss: 0.1364 - val_loss: 0.3174 - val_main_output_loss: 0.3814 - val_second_output_loss: 0.0614
Epoch 6/100
363/363 [==============================] - 3s 7ms/step - loss: 0.3141 - main_output_loss: 0.3692 - second_output_loss: 0.0937 - val_loss: 0.3212 - val_main_output_loss: 0.3863 - val_second_output_loss: 0.0607
Epoch 7/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2964 - main_output_loss: 0.3503 - second_output_loss: 0.0806 - val_loss: 0.3154 - val_main_output_loss: 0.3788 - val_second_output_loss: 0.0616
Epoch 8/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2862 - main_output_loss: 0.3393 - second_output_loss: 0.0739 - val_loss: 0.3089 - val_main_output_loss: 0.3711 - val_second_output_loss: 0.0599
Epoch 9/100
363/363 [==============================] - 3s 7ms/step - loss: 0.3075 - main_output_loss: 0.3677 - second_output_loss: 0.0665 - val_loss: 0.3092 - val_main_output_loss: 0.3719 - val_second_output_loss: 0.0587
Epoch 10/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2824 - main_output_loss: 0.3379 - second_output_loss: 0.0605 - val_loss: 0.3081 - val_main_output_loss: 0.3707 - val_second_output_loss: 0.0580
Epoch 11/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2864 - main_output_loss: 0.3432 - second_output_loss: 0.0592 - val_loss: 0.3083 - val_main_output_loss: 0.3712 - val_second_output_loss: 0.0564
Epoch 12/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2844 - main_output_loss: 0.3416 - second_output_loss: 0.0557 - val_loss: 0.3091 - val_main_output_loss: 0.3720 - val_second_output_loss: 0.0575
Epoch 13/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2845 - main_output_loss: 0.3422 - second_output_loss: 0.0537 - val_loss: 0.3097 - val_main_output_loss: 0.3730 - val_second_output_loss: 0.0563
Epoch 14/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2869 - main_output_loss: 0.3454 - second_output_loss: 0.0527 - val_loss: 0.3100 - val_main_output_loss: 0.3737 - val_second_output_loss: 0.0554
Epoch 15/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2844 - main_output_loss: 0.3428 - second_output_loss: 0.0509 - val_loss: 0.3046 - val_main_output_loss: 0.3671 - val_second_output_loss: 0.0547
Epoch 16/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2811 - main_output_loss: 0.3388 - second_output_loss: 0.0500 - val_loss: 0.3061 - val_main_output_loss: 0.3689 - val_second_output_loss: 0.0548
Epoch 17/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2831 - main_output_loss: 0.3416 - second_output_loss: 0.0492 - val_loss: 0.3057 - val_main_output_loss: 0.3682 - val_second_output_loss: 0.0554
Epoch 18/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2799 - main_output_loss: 0.3377 - second_output_loss: 0.0489 - val_loss: 0.3044 - val_main_output_loss: 0.3669 - val_second_output_loss: 0.0543
Epoch 19/100
363/363 [==============================] - 3s 9ms/step - loss: 0.2778 - main_output_loss: 0.3352 - second_output_loss: 0.0479 - val_loss: 0.3100 - val_main_output_loss: 0.3739 - val_second_output_loss: 0.0545
Epoch 20/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2873 - main_output_loss: 0.3473 - second_output_loss: 0.0471 - val_loss: 0.3055 - val_main_output_loss: 0.3684 - val_second_output_loss: 0.0538
Epoch 21/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2786 - main_output_loss: 0.3366 - second_output_loss: 0.0469 - val_loss: 0.3077 - val_main_output_loss: 0.3712 - val_second_output_loss: 0.0538
Epoch 22/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2827 - main_output_loss: 0.3418 - second_output_loss: 0.0463 - val_loss: 0.3042 - val_main_output_loss: 0.3670 - val_second_output_loss: 0.0529
Epoch 23/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2788 - main_output_loss: 0.3369 - second_output_loss: 0.0462 - val_loss: 0.3063 - val_main_output_loss: 0.3697 - val_second_output_loss: 0.0527
Epoch 24/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2784 - main_output_loss: 0.3365 - second_output_loss: 0.0460 - val_loss: 0.3031 - val_main_output_loss: 0.3658 - val_second_output_loss: 0.0523
Epoch 25/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2787 - main_output_loss: 0.3370 - second_output_loss: 0.0456 - val_loss: 0.3023 - val_main_output_loss: 0.3648 - val_second_output_loss: 0.0524
Epoch 26/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2796 - main_output_loss: 0.3382 - second_output_loss: 0.0450 - val_loss: 0.3034 - val_main_output_loss: 0.3662 - val_second_output_loss: 0.0522
Epoch 27/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2747 - main_output_loss: 0.3321 - second_output_loss: 0.0454 - val_loss: 0.3035 - val_main_output_loss: 0.3665 - val_second_output_loss: 0.0517
Epoch 28/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2786 - main_output_loss: 0.3371 - second_output_loss: 0.0447 - val_loss: 0.3024 - val_main_output_loss: 0.3651 - val_second_output_loss: 0.0516
Epoch 29/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2759 - main_output_loss: 0.3336 - second_output_loss: 0.0450 - val_loss: 0.3012 - val_main_output_loss: 0.3636 - val_second_output_loss: 0.0517
Epoch 30/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2745 - main_output_loss: 0.3318 - second_output_loss: 0.0451 - val_loss: 0.3043 - val_main_output_loss: 0.3675 - val_second_output_loss: 0.0513
Epoch 31/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2747 - main_output_loss: 0.3323 - second_output_loss: 0.0446 - val_loss: 0.3023 - val_main_output_loss: 0.3650 - val_second_output_loss: 0.0516
Epoch 32/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2741 - main_output_loss: 0.3315 - second_output_loss: 0.0446 - val_loss: 0.3015 - val_main_output_loss: 0.3640 - val_second_output_loss: 0.0511
Epoch 33/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2749 - main_output_loss: 0.3326 - second_output_loss: 0.0443 - val_loss: 0.3008 - val_main_output_loss: 0.3633 - val_second_output_loss: 0.0507
Epoch 34/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2738 - main_output_loss: 0.3312 - second_output_loss: 0.0442 - val_loss: 0.3008 - val_main_output_loss: 0.3634 - val_second_output_loss: 0.0507
Epoch 35/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2738 - main_output_loss: 0.3313 - second_output_loss: 0.0440 - val_loss: 0.3013 - val_main_output_loss: 0.3639 - val_second_output_loss: 0.0509
Epoch 36/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2760 - main_output_loss: 0.3340 - second_output_loss: 0.0436 - val_loss: 0.3002 - val_main_output_loss: 0.3626 - val_second_output_loss: 0.0505
Epoch 37/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2724 - main_output_loss: 0.3296 - second_output_loss: 0.0438 - val_loss: 0.3009 - val_main_output_loss: 0.3635 - val_second_output_loss: 0.0504
Epoch 38/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2725 - main_output_loss: 0.3297 - second_output_loss: 0.0437 - val_loss: 0.2998 - val_main_output_loss: 0.3622 - val_second_output_loss: 0.0504
Epoch 39/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2738 - main_output_loss: 0.3313 - second_output_loss: 0.0437 - val_loss: 0.3004 - val_main_output_loss: 0.3629 - val_second_output_loss: 0.0501
Epoch 40/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2743 - main_output_loss: 0.3320 - second_output_loss: 0.0433 - val_loss: 0.3016 - val_main_output_loss: 0.3644 - val_second_output_loss: 0.0502
Epoch 41/100
363/363 [==============================] - 3s 7ms/step - loss: 0.2738 - main_output_loss: 0.3315 - second_output_loss: 0.0432 - val_loss: 0.3003 - val_main_output_loss: 0.3629 - val_second_output_loss: 0.0501
Epoch 42/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2721 - main_output_loss: 0.3293 - second_output_loss: 0.0433 - val_loss: 0.2987 - val_main_output_loss: 0.3610 - val_second_output_loss: 0.0499
Epoch 43/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2728 - main_output_loss: 0.3302 - second_output_loss: 0.0431 - val_loss: 0.3011 - val_main_output_loss: 0.3639 - val_second_output_loss: 0.0499
Epoch 44/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2725 - main_output_loss: 0.3299 - second_output_loss: 0.0430 - val_loss: 0.3011 - val_main_output_loss: 0.3640 - val_second_output_loss: 0.0495
Epoch 45/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2707 - main_output_loss: 0.3277 - second_output_loss: 0.0431 - val_loss: 0.3016 - val_main_output_loss: 0.3645 - val_second_output_loss: 0.0498
Epoch 46/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2728 - main_output_loss: 0.3303 - second_output_loss: 0.0428 - val_loss: 0.3013 - val_main_output_loss: 0.3642 - val_second_output_loss: 0.0495
Epoch 47/100
363/363 [==============================] - 3s 8ms/step - loss: 0.2719 - main_output_loss: 0.3292 - second_output_loss: 0.0427 - val_loss: 0.3006 - val_main_output_loss: 0.3634 - val_second_output_loss: 0.0494
162/162 [==============================] - 1s 3ms/step - loss: 0.3038 - main_output_loss: 0.3699 - second_output_loss: 0.0391





[0.3037593960762024, 0.36992380023002625, 0.039101772010326385]

绘制图像

import pandas as pd
pd.DataFrame(history.history).plot(figsize=(18,10))

在这里插入图片描述

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 12
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

起名大废废

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值