keras Model模型( functional API) 多输入多输出!

更强大!Model模型~

用Sequential只能定义一些简单的模型,如果你想要定义多输入、多输出以及共享网络层,就需要使用Model模型了。

声明方法

inputs = Input(shape=(784,))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
out = Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=out)

在model模型的声明中,需要使用 y = l a y e r ( . . . ) ( x ) y = layer(...)(x) y=layer(...)(x)这样的格式来构建没一个层次,并在构造函数中声明你的模型的输入和输出是什么。


多输入多输出

来考虑下面的模型。我们试图预测 Twitter 上的一条新闻标题有多少转发和点赞数。模型的主要输入将是新闻标题本身,即一系列词语,但是为了增添趣味,我们的模型还添加了其他的辅助输入来接收额外的数据,例如新闻标题的发布的时间等。 该模型也将通过两个损失函数进行监督学习。较早地在模型中使用主损失函数,是深度学习模型的一个良好正则方法。

模型结构如下图所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-w7tuPas2-1586093539846)(model.assets/multi-input-multi-output-graph.png)]

from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model
main_input = Input(shape=(100,), dtype='int32', name='main_input')
x = Embedding(output_dim=512, input_dim=10000, input_length=100)(main_input)
lstm_out = LSTM(32)(x)
auxiliary_output = Dense(1, activation='sigmoid', name='aux_output')(lstm_out)
auxiliary_input = Input(shape=(5,), name='aux_input')
x = keras.layers.concatenate([lstm_out, auxiliary_input]) # 共享层,下面介绍

x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)

main_output = Dense(1, activation='sigmoid', name='main_output')(x)
model = Model(inputs=[main_input, auxiliary_input], outputs=[main_output, auxiliary_output])

共享网络层

多输入依赖于共享,直接看例子。

import keras
from keras.layers import Input, LSTM, Dense
from keras.models import Model

tweet_a = Input(shape=(280, 256))
tweet_b = Input(shape=(280, 256))

要在不同的输入上共享同一个层,只需实例化该层一次,然后根据需要传入你想要的输入即可:

shared_lstm = LSTM(64)
encoded_a = shared_lstm(tweet_a)
encoded_b = shared_lstm(tweet_b)
merged_vector = keras.layers.concatenate([encoded_a, encoded_b], axis=-1)
predictions = Dense(1, activation='sigmoid')(merged_vector)
model = Model(inputs=[tweet_a, tweet_b], outputs=predictions)

总结

在Model模型中,多输入、多输出、共享的实现都是简单的,只要按照一定的逻辑创建好有向图即可!

关于本文内容,部分借鉴自keras文档。

实验代码见https://github.com/1173710224/keras-cnn-captcha.git的model-example分支。
欢迎关注公众号BBIT
让我们共同学习共同进步!

在这里插入图片描述

  • 1
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值