Tensorflow:模型保存与加载

模型训练完成后,需要将模型保存到文件系统上,从而方便后续的模型测试与部署工作。实际上,在训练时间隔性地保存模型状态也是非常好的习惯,这一点对于训练大规模的网络尤其重要,一般大规模的网络需要训练数天乃至数周的时长,一旦训练过程被中断或者发生宕机等意外,之前训练的进度将全部丢失。如果能够间断的保存模型状态到文件系统,即使发生宕机等意外,也可以从最近一次的网络状态文件中恢复,从而避免浪费大量的训练时间。因此模型的保存与加载非常重要。
在 Keras 中,有三种常用的模型保存与加载方法。

一、张量方式

网络的状态主要体现在网络的结构以及网络层内部张量参数上,因此在拥有网络结构源文件的条件下,直接保存网络张量参数到文件上是最轻量级的一种方式。我们以MNIST 手写数字图片识别模型为例,通过调用Model.save_weights(path)方法即可讲当前的网络参
数保存到path 文件上:

network.save_weights('weights.ckpt')

上述代码将 network 模型保存到weights.ckpt 文件上,在需要的时候,只需要先创建好网络对象,然后调用网络对象的load_weights(path)方法即可将指定的模型文件中保存的张量数值写入到当前网络参数中去:

# 保存模型参数到文件上
network.save_weights('weights.ckpt')
print('saved weights.')
del network # 删除网络对象
# 重新创建相同的网络结构
network = Sequential([layers.Dense(256, activation='relu'),
	layers.Dense(128, activation='relu'),
	layers.Dense(64, activation='relu'),
	layers.Dense(32, activation='relu'),
	layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
# 从参数文件中读取数据并写入当前网络
network.load_weights('weights.ckpt')
print('loaded weights!')

这种保存与加载网络的方式最为轻量级,文件中保存的仅仅是参数张量的数值,并没有其
他额外的结构参数。但是它需要使用相同的网络结构才能够恢复网络状态,因此一般在拥
有网络源文件的情况下使用。

二、网络方式

我们来介绍一种不需要网络源文件,仅仅需要模型参数文件即可恢复出网络模型的方式。通过Model.save(path)函数可以将模型的结构以及模型的参数保存到一个path 文件上,在不需要网络源文件的条件下,通过keras.models.load_model(path)即可恢复网络结构和网络参数。

我们首先将MNIST 手写数字图片识别模型保存到文件上,并且删除网络对象:

# 保存模型结构与模型参数到文件
network.save('model.h5')
print('saved total model.')
del network # 删除网络对象

此时通过model.h5 文件即可恢复出网络的结构和状态:

# 从文件恢复网络结构与网络参数
network = tf.keras.models.load_model('model.h5')

可以看到,model.h5 文件除了保存了模型参数外,还保存了网络结构信息,不需要提前创建模型即可直接从文件中恢复出网络network 对象。

三、SavedModel

TensorFlow 之所以能够被业界青睐,除了优秀的神经网络层API 支持之外,还得益于
它强大的生态系统,包括移动端和网页端的支持。当需要将模型部署到其他平台时,采用
TensorFlow 提出的 SavedModel 方式更具有平台无关性。

通过 tf.keras.experimental.export_saved_model(network, path)即可将模型以 SavedModel 方式保存到path 目录中:

# 保存模型结构与模型参数到文件
tf.keras.experimental.export_saved_model(network, 'model-savedmodel')
print('export saved model.')
del network # 删除网络对象

此时在文件系统model-savedmodel 目录上出现了如下网络文件:
在这里插入图片描述
用户无需关心文件的保存格式,只需要通过

network = tf.keras.experimental.load_from_saved_model('model-savedmodel')

即可恢复出网络结构和参数,方便各个平台能够无缝对接训练好的网络模型。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
TensorFlow是一个强大的机器学习库,可以用于构建和训练各种深度学习模型。在训练完一个模型后,我们通常会将模型保存到磁盘上以备将来使用。TensorFlow提供了一种简单的方法来保存加载模型。 在 TensorFlow 中,我们可以使用 `tf.saved_model` API 将模型保存到一个文件夹中。该文件夹包含了模型的图和变量等信息。保存模型时,我们可以指定要保存的内容,例如只保存变量或只保存图等等。 以下是一个保存加载模型的示例: ```python import tensorflow as tf from tensorflow import keras # 构建一个简单的模型 model = keras.Sequential([ keras.layers.Dense(64, activation='relu', input_shape=(784,)), keras.layers.Dense(10, activation='softmax') ]) # 训练模型 model.compile(optimizer=tf.train.AdamOptimizer(), loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(x_train, y_train, epochs=5) # 保存模型 export_path = './saved_model' tf.saved_model.simple_save( keras.backend.get_session(), export_path, inputs={'input_image': model.input}, outputs={t.name: t for t in model.outputs}) # 加载模型 with tf.Session(graph=tf.Graph()) as sess: tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_path) graph = tf.get_default_graph() inputs = graph.get_tensor_by_name('input_image:0') outputs = graph.get_tensor_by_name('dense_1/Softmax:0') result = sess.run(outputs, feed_dict={inputs: x_test}) print(result) ``` 在上面的示例中,我们首先构建了一个简单的神经网络模型并训练了它。然后我们使用 `tf.saved_model.simple_save` 将模型保存到一个文件夹中。在保存模型时,我们指定了模型的输入和输出。最后,我们使用 `tf.saved_model.loader.load` 方法加载模型并运行它。 值得注意的是,在加载模型时,我们需要使用 `tf.Session` 和 `tf.Graph` 来创建一个新的计算图,并使用 `tf.saved_model.loader.load` 方法从文件夹中加载模型。在加载模型后,我们可以使用 `tf.get_default_graph` 方法获取默认的计算图,并使用 `graph.get_tensor_by_name` 方法获取输入和输出张量的引用。 这就是 TensorFlow保存加载模型的基本方法。当我们需要在生产环境中使用模型时,我们可以使用保存模型轻松地加载和运行它。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

南淮北安

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值