文 / 李锡涵,Google Developers Expert
本文节选自《简单粗暴 TensorFlow 2.0》
在《【入门教程】TensorFlow 2.0 基础:张量、自动求导与优化器》中,我们手工实现了一个简单的线性回归模型。不过,当模型变得日益复杂时,直接将模型代码写在主程序中就显得臃肿且难以重用了。于是,我们希望有一个高层的模型类,能够帮助我们将模型进行封装。本篇文章即介绍模型类的建立方法。
在 TensorFlow 中,推荐使用 Keras ( tf.keras
) 构建模型。Keras 是一个广为流行的高级神经网络 API,简单、快速而不失灵活性,现已得到 TensorFlow 的官方内置和全面支持。
Keras 有两个重要的概念: 模型(Model) 和 层(Layer) 。层将各种计算流程和变量进行了封装(例如基本的全连接层,CNN 的卷积层、池化层等),而模型则将各种层进行组织和连接,并封装成一个整体,描述了如何将输入数据通过各种层以及运算而得到输出。在需要模型调用的时候,使用 y_pred = model(X)
的形式即可。Keras 在 tf.keras.layers
下内置了深度学习中大量常用的的预定义层,同时也允许我们自定义层。
Keras 模型以类的形式呈现,我们可以通过继承 tf.keras.Model
这个 Python 类来定义自己的模型。在继承类中,我们需要重写 __init__()
(构造函数,初始化)和 call(input)
(模型调用)两个方法,同时也可以根据需要增加自定义的方法。
1class MyModel(tf.keras.Model):
2 def __init__(self):
3 super().__init__()
4 # 此处添加初始化代码(包含call方法中会用到的层),例如
5 # layer1 = tf.keras.layers.BuiltInLayer(...)
6 # layer2 = MyCustomLayer(...)
7
8 def call(self, input):
9 # 此处添加模型调用的代码(处理输入并返回输出),例如
10 # x = layer1(input)
11 # output = layer2(x)
12 return output
13
14 # 还可以添加自定义的方法
Keras 模型类定义示意图
继承 tf.keras.Model
后,我们同时可以使用父类的若干方法和属性,例如在实例化类 model = Model()
后,可以通过 model.variables
这一属性直接获得模型中的所有变量,免去我们一个个显式指定变量的麻烦。