Tensorflow2.x中的模型保存和加载

Tensorflow2.x中的模型保存和加载

模型保存和读取

在tensorflow2.x中我们经常需要将训练后的模型保存下来以之后进行使用。这里我们将模型保存进行简单的总结。

1、测试环境

系统显卡处理器Cuda版本Tensorflow版本
Windows10 ProNvidia RTX2070SuperIntel core i5 9600KF10.1Tensorflow-GPU 2.3

2、基于网络模式模型保存和读取

这里我们简单定义一个5层全连接网络

# 定义一个5层的全连接网络
network = Sequential([layers.Dense(256, activation='relu'),
                     layers.Dense(128, activation='relu'),
                     layers.Dense(64, activation='relu'),
                     layers.Dense(32, activation='relu'),
                     layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()




network.compile(optimizer=optimizers.Adam(lr=0.01),
		loss=tf.losses.CategoricalCrossentropy(from_logits=True),
		metrics=['accuracy']
	)

network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)
 
network.evaluate(ds_val)

# 网络保存
network.save('model.h5')
print('saved total model.')
del network

在模型训练完毕后将训练结果保存,该结果包含了所有的模型参数和网络结构

此时如果我们需要读取训练后的模型可以通过以下语句进行读取

network = keras.models.load_model('model.h5')

基于网络的模型保存和读取主要适用于使用Sequential容器定义的网络结构,如果是自定义的网络结构如resnet则需要使用张量方式进行模型保存

3、基于张量方式的模型保存和读取

基于张量方式的模型保存主要需要保存其网络的参数,在模型读取的时候需要重建整个网络然后读取权重和参数才可以继续使用网络,该方法也可保存自定义网络结构

这里我们继续定义一个5层的全连接网络,在训练完成后保存网络

# 定义一个5层的全连接网络
network = Sequential([layers.Dense(256, activation='relu'),
                     layers.Dense(128, activation='relu'),
                     layers.Dense(64, activation='relu'),
                     layers.Dense(32, activation='relu'),
                     layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()




network.compile(optimizer=optimizers.Adam(lr=0.01),
		loss=tf.losses.CategoricalCrossentropy(from_logits=True),
		metrics=['accuracy']
	)

network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)
 
network.evaluate(ds_val)

# 网络保存
network.save_weights('weights.ckpt')
print('model weights saved.')
del network

模型读取

模型读取的时候需要重建整个网络结构后进行权重和参数的读取

# 重建整个网络结构
network = Sequential([layers.Dense(256, activation='relu'),
                     layers.Dense(128, activation='relu'),
                     layers.Dense(64, activation='relu'),
                     layers.Dense(32, activation='relu'),
                     layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()




network.compile(optimizer=optimizers.Adam(lr=0.01),
		loss=tf.losses.CategoricalCrossentropy(from_logits=True),
		metrics=['accuracy']
	)


# 模型参数读取
network.load_weights('weights.ckpt')
print('weights loaded')

3、基于saved_model方法的模型保存和读取

使用Tensorflow提供的saved_model接口可以轻松的保存和读取模型到path目录中

# 保存模型结构与模型参数到文件
# model_savedmodel为自定义的路径
tf.saved_model.save(network, 'model_savedmodel')
print('saving savedmodel')

模型的读取也十分的简单

print('load savedmodel from file')
# 从文件中恢复网络结构与参数,model_savedmodel为pb文件的路径
network = tf.saved_model.load('model_savedmodel')
  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值