Tensorflow2.0自定义层及其模型

  •  前言

Tensorflow2.0中引入了keras库,极大的简化了我们搭建网络的复杂度;同时eager模式的引入,更加方便了我们对代码的编写及其调试。我们知道原始的keras框架可定制性不强,但在tensorflow2.0中可以自定义我们的每一层和模型。

  • 正文:

下列为使用Sequential(容器)搭建多全连接层网络,我没有去查看源码,但我确定它肯定是继承了keras.Model类;因此它可以调用build()配置参数,调用fit()进行训练。

在tf2.0中,我们创建的Sequential的model直接可以直接model(输入)进行正向传播,拿到输出值,不需要调用第一层的call()方法将输入值输入到第一层,获得输出再输入第二层,这样一步一步得到最终的输出;其实这些操作在model.call()中已经被实现;即执行model(输入)操作时,底层会自动调用model.call()方法。

1)自定义层:

如果我们不仅限使用上述的layers.Dense(),而是自定义我们的全连接层,该怎么做呢?其实非常简单,只需要我们自定义类继承keras.layers.Layers类,同时我们要实现该父类的一些方法,包括:

  1. __init()__ 
  2. call() :这里实现自定义逻辑

这里我们自定义一个全连接层MyDense(),同时实现了 __init()__ 方法,在这里定义了2个变量分别为w和b;注意:我们的自定义的2个变量一定要使用add_variable()这种方法创建,因为我们要让创建的变量交由上上文管理器进行管理,而不能使用tf.constant()类似的方法创建变量。call() 方法中返回该层进行自定义操作的结果,下面这里直接传统的全连接运算,返回运算结果。training也是要经常根据自己的业务逻辑进行处理的。

2)自定义模型:

对比自定义层,自定义模型稍微复杂些,同样自定义类继承keras.Model类,同时我们要实现该父类的一些方法,包括:

  1. __init()__ 
  2. call() :这里实现自定义逻辑
  3. compile()
  4. fit()
  5. evaluate()
  6. predict()

 下面自定义了一个名为MyModel的模型,在__init__()中我们自定义的添加一些全连接层;同时在call()方法中指定这些全连接层如何进行传递的;我们没有重写compile()、fit()、evaluate()、predict()等方法。

后面会抽出时间去读一些这些方法的源码,做到查缺补漏。

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值