一、函数式API(Fucntional API)
本文代码基于tensorflow2.0 python 3.7
- tf.keras.Sequential 模型是层的简单堆叠,无法表示任意模型。
import
二、模型子类化—实现自定义模型
"模型子类化"就是自己实现一个类来继承Model类,构建一个Model类的子类,
需要实现两个方法,即:
__init__() call()
通过对 tf.keras.Model 进行子类化并定义自己的前向传播来构建完全可自定义的模型。
- 在 __init__ 方法中创建层并将它们设置为类实例的属性
- 在 call 方法中定义前向传播
下面给出典型的ResNet网络代码:
import
总结:一般情况下,简单的应用可以直接使用函数式API编程,对于复杂的网络的定义和训练可以使用类继承的方式,这样的代码逻辑和封装性较好