TensorFlow2.0入门到进阶2.12 ——函数API实现wide&deep模型

1、wide&deep原理

wide&deep模型:https://blog.csdn.net/caoyuan666/article/details/105869670

2、代码

函数式API 在创建模型时就像调用函数一样,将上一层结果像函数变量一样输入的下一层的函数中:

#复合函数:f(x)=h(g(x))
input = keras.layers.Input(shape=x_train.shape[1:])
hidden1=keras.layers.Dense(30,activation='relu')(input)
hidden2=keras.layers.Dense(30,activation='relu')(hidden1)
 
#将wide和deep数据拼接
concat = keras.layers.concatenate([input,hidden2])
output = keras.layers.Dense(1)(concat)

#由于函数式API没有将模型返回保存,所以需要使用model将模型固化下来
model = keras.models.Model(inputs=[input],
                         outputs=[output])

model.summary()
model.compile(loss='mean_squared_error',
              optimizer='adam',)

网络结构结构:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 8)]          0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 30)           270         input_1[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 30)           930         dense[0][0]                      
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 38)           0           input_1[0][0]                    
                                                                 dense_1[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 1)            39          concatenate[0][0]                
==================================================================================================
Total params: 1,239
Trainable params: 1,239
Non-trainable params: 0
__________________________________________________________________________________________________

wide层:input(对于输入数据只经过一层input)
deep层:hidden2(经过两层隐层,相对较深,这里只是举例,其实有点前)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

努力改掉拖延症的小白

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值