tensorflow2的模型保存与加载(save_weights、save和saved_model.save)

1、save/load weights

只保存网络的一个参数,不管其他的状态,这种模式适合自己对代码有个清晰的认识

用法流程如下:

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')  # 提供保存的路径

# Restore the weights
model = create_model()  # 重新创建网络
model.load_weights('./checkpoints/my_checkpoint')

loss, acc = model.evaluate(test_images, test_labels)  # 查看accuracy是否变化
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

示例:

network.save_weights('weights.ckpt')
print('saved weights.')
del network

network = Sequential([layers.Dense(256, activation='relu'),
                     layers.Dense(128, activation='relu'),
                     layers.Dense(64, activation='relu'),
                     layers.Dense(32, activation='relu'),
                     layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),
		loss=tf.losses.CategoricalCrossentropy(from_logits=True),
		metrics=['accuracy']
	)
network.load_weights('weights.ckpt')
network.evaluate(ds_val)

2、save/load entire model

这种方法是最简单粗暴的,它把所有的模型和状态都保存起来,可以进行完美的恢复

用法如下:

network.save('model.h5')
print('saved total model.')
del network

print('loaded model from file.')
network = tf.keras.models.load_model('model.h5', compile=False)  # 不需要重新创建网络

network.evaluate(ds_val)

3、saved_model

模型的一种保存格式,跟pytorch的ONNX对应,也就是说当训练的一个模型交给工厂的生产环境的时候,可以把这个模型直接交给用户来部署,而不需要给一个源代码或相关的信息,这个模型就包含的所有的这样一个信息。比如,你通过python写的源文件,你可以用c++解析和读取这个工作。

用法如下:

tf.saved_model.save(m, '/tmp/saved_model')

imported = tf.saved_model.load(path)
f = imported.signatures["serving_default"]
print(f(x = tf.ones([1, 28, 28, 3])))
  • 8
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值