Tensorflow2.x中的模型保存和加载
模型保存和读取
在tensorflow2.x中我们经常需要将训练后的模型保存下来以之后进行使用。这里我们将模型保存进行简单的总结。
1、测试环境
系统 | 显卡 | 处理器 | Cuda版本 | Tensorflow版本 |
---|---|---|---|---|
Windows10 Pro | Nvidia RTX2070Super | Intel core i5 9600KF | 10.1 | Tensorflow-GPU 2.3 |
2、基于网络模式模型保存和读取
这里我们简单定义一个5层全连接网络
# 定义一个5层的全连接网络
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)
network.evaluate(ds_val)
# 网络保存
network.save('model.h5')
print('saved total model.')
del network
在模型训练完毕后将训练结果保存,该结果包含了所有的模型参数和网络结构
此时如果我们需要读取训练后的模型可以通过以下语句进行读取
network = keras.models.load_model('model.h5')
基于网络的模型保存和读取主要适用于使用Sequential容器定义的网络结构,如果是自定义的网络结构如resnet则需要使用张量方式进行模型保存
3、基于张量方式的模型保存和读取
基于张量方式的模型保存主要需要保存其网络的参数,在模型读取的时候需要重建整个网络然后读取权重和参数才可以继续使用网络,该方法也可保存自定义网络结构
这里我们继续定义一个5层的全连接网络,在训练完成后保存网络
# 定义一个5层的全连接网络
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)
network.evaluate(ds_val)
# 网络保存
network.save_weights('weights.ckpt')
print('model weights saved.')
del network
模型读取
模型读取的时候需要重建整个网络结构后进行权重和参数的读取
# 重建整个网络结构
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
# 模型参数读取
network.load_weights('weights.ckpt')
print('weights loaded')
3、基于saved_model方法的模型保存和读取
使用Tensorflow提供的saved_model接口可以轻松的保存和读取模型到path目录中
# 保存模型结构与模型参数到文件
# model_savedmodel为自定义的路径
tf.saved_model.save(network, 'model_savedmodel')
print('saving savedmodel')
模型的读取也十分的简单
print('load savedmodel from file')
# 从文件中恢复网络结构与参数,model_savedmodel为pb文件的路径
network = tf.saved_model.load('model_savedmodel')