刚刚学习tensorflow2.0版本,总结了一下tensorflow2.0应用的基本框架,希望能帮到以后跟我一样刚刚接触的萌新。
在tensorflow2.0中,首先我们要设计一个自己的模型,所以要创建一个class类,这里我们用一个简单的CNN手写体识别网络来举例。在创建我们自己的model类时,首先要继承tensorflow为我们写好的模块父类,tf.keras.Model,并且至少重写其中的两个函数,init()和call()。
init函数是模块的初始化函数,在模块被创建时运行一次,我们把要实现的各个网络层在这里命名。在例子中我们共初始化了两个卷积层,两个池化层,两个线性层,一个flatten层和一个dropout层。call()函数会在我们调用模块时运行,例如我们实例化了一个cnn = CNN(),此时y = cnn(x)等价于y = cnn.call(x)。在这一函数中我们进行网络的搭建。至此模型的搭建就算结束了。
class CNN(tf.keras.Model):
def __init__(self):
super().__init__() # 继承父类的init函数,这里需要注意,在python2中需要写成super(CNN, self).__init__()
self.conv1 = tf.keras.layers.Conv2D(
filters=32,
kernel_size = [5,5],
activation=tf.nn.relu,
padding='same'
)
self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2)
self.conv2 = tf.keras.layers.Conv2D(
filters=64,
kernel_size=[5,5],
activation=tf.nn.relu,
padding='same'
)
self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2,2], strides=2)
self.flatten = tf.keras.layers.Reshape(target_shape=(7 * 7 * 64,))
self.drop1 = tf.keras.layers.Dropout(0.5)
self.dense1 = tf.keras.layers.Dense(units=1024, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(units=10, activation=tf.nn.softmax)
@tf.function
def call(self, inputs):
x = self.conv1(inputs)
x = self.pool1(x)
x = self.conv2(x)
x = self.pool2(x)
x = self.flatten(x)
x = self.drop1(x)
x = self.dense1(x)
outputs = self.dense2(x)
return outputs
在训练模型时,最简单的可以分为5步。
- 运用现在的模型生成预测,即y_pred
- 通过对比y_pred和y_true生成模型的loss
- 将一个batch内的loss压缩成一个实数
- 获得loss关于模型参数的梯度
- 通过获得的梯度优化模型
相应的python代码自然也就是5行。在那之前,我们要实例化一个优化器optimizer,它会帮我们根据梯度自动优化模型参数。相应代码如下:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) # 优化器类型根据需求自行选择
for index in range(epoch_num):
with tf.GradientTape() as tape:
y_pred = cnn(x) # 对应步骤1
loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred) # 对应步骤2,loss模型根据需求自行选择
loss = tf.reduce_mean(loss) # 对应步骤3
print('epoch: %d, loss: %f' %(index, loss)) # 打印出每一次训练的loss
grads = tape.gradient(loss, model.variables) # 对应步骤4
optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables)) # 对应步骤5
至此,一个简单模型的搭建和预测步骤就到此结束了。复杂的结构也可以参考这一简单框架。当我们需要应用模型时,只需要调用y_pred = cnn(x)即可。