TensorFlow2.0入门到进阶2.14 —— wide&deep模型多输入多输出

1、wide&deep理论及前期博客

wide&deep模型:https://blog.csdn.net/caoyuan666/article/details/105869670

函数API实现wide&deep模型

子类API实现wide&deep模型

2、多输入

本实验使用数据为房价预测的数据集,如果不清楚的小伙伴请看:
一个房价预测回归项目轻松入门TensorFlow

多输入一般用于多套输入特征的情况下使用。

通过查看数据集维度,可知本数据集共有8个特征

from sklearn.datasets import fetch_california_housing

housing=fetch_california_housing()

#print(housing.DESCR)
print(housing.data.shape)
print(housing.target.shape)

输出结果:
(20640, 8)
(20640,)

用一个输入维度分别为5和6的多输入例子来展示:

#多输入 函数式的方法
input_wide=keras.layers.Input(shape=[5])
input_deep=keras.layers.Input(shape=[6])
hidden1=keras.layers.Dense(30,activation='relu')(input_deep)
hidden2=keras.layers.Dense(30,activation='relu')(hidden1)
concat=keras.layers.concatenate([input_wide,hidden2])
output=keras.layers.Dense(1)(concat)
model=keras.models.Model(inputs=[input_wide,input_deep],
                       outputs=output)

model.summary()
model.compile(loss='mean_squared_error',
              optimizer='adam',)

模型结构:

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_6 (InputLayer)            [(None, 6)]          0                                            
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 30)           210         input_6[0][0]                    
__________________________________________________________________________________________________
input_5 (InputLayer)            [(None, 5)]          0                                            
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 30)           930         dense_6[0][0]                    
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 35)           0           input_5[0][0]                    
                                                                 dense_7[0][0]                    
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 1)            36          concatenate_2[0][0]              
==================================================================================================
Total params: 1,176
Trainable params: 1,176
Non-trainable params: 0
__________________________________________________________________________________________________

之前我们数据集为8维特征,这时候当然要拆分一下,第一个截取前5维特征,第二个截取后6维特征,中间的部分特征使用了两次:

callbacks=[keras.callbacks.EarlyStopping(patience=5,min_delta=1e-2)]

x_train_scaled_wide=x_train_scaled[:,:5]
x_train_scaled_deep=x_train_scaled[:,2:]
x_valid_scaled_wide=x_valid_scaled[:,:5]
x_valid_scaled_deep=x_valid_scaled[:,2:]
x_test_scaled_wide=x_test_scaled[:,:5]
x_test_scaled_deep=x_test_scaled[:,2:]

history=model.fit([x_train_scaled_wide,x_train_scaled_deep],y_train,
         epochs=20,
         validation_data=([x_valid_scaled_wide,x_valid_scaled_deep],y_valid),
         callbacks = callbacks )

3、多输出

  • 多输出:比如预测未来一天、一个月、一年的房价
  • 由于本数据集只提供了当前房价,所以本程序将用wide_deep和deep两种方式预测的结果来展示多输出

input_wide=keras.layers.Input(shape=[5])
input_deep=keras.layers.Input(shape=[6])
hidden1=keras.layers.Dense(30,activation='relu')(input_deep)
hidden2=keras.layers.Dense(30,activation='relu')(hidden1)
concat=keras.layers.concatenate([input_wide,hidden2])
output=keras.layers.Dense(1)(concat)
output2=keras.layers.Dense(1)(hidden2)

model=keras.models.Model(inputs=[input_wide,input_deep],
                       outputs=[output,output2])

model.summary()
model.compile(loss='mean_squared_error',
              optimizer='adam',)

模型结构:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 6)]          0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 30)           210         input_2[0][0]                    
__________________________________________________________________________________________________
input_1 (InputLayer)            [(None, 5)]          0                                            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 30)           930         dense[0][0]                      
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 35)           0           input_1[0][0]                    
                                                                 dense_1[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 1)            36          concatenate[0][0]                
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1)            31          dense_1[0][0]                    
==================================================================================================
Total params: 1,207
Trainable params: 1,207
Non-trainable params: 0
__________________________________________________________________________________________________
  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

努力改掉拖延症的小白

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值