TensorFlow2.0模型保存与调用

TensorFlow2.0模型保存与调用

模型保存

首先建立一个手写数字识别的神经网络,将其训练后保存为模型文件:

import tensorflow as tf
from tensorflow.keras import layers

# 数据处理部分,包括训练集与测试集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(10000, 28, 28, 1).astype('float32') / 255
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# 通过函数式构建网络模型
encode_input = tf.keras.Input(shape=(28,28,1), name='img')
# 第一层卷积,padding设置为SAME,保持分辨率
h1 = layers.Conv2D(16, 3, activation='relu', padding='SAME')(encode_input)
# 最大池化操作
h1 = layers.MaxPool2D()(h1)
# 第二层卷积
h1 = layers.Conv2D(32, 3, activation='relu', padding='SAME')(h1)
# 最大池化操作
h1 = layers.MaxPool2D()(h1)
# 将二维数据拉直成为一维数据
h1 = layers.Flatten()(h1)
# 第三层全连接操作
out = layers.Dense(10, activation='softmax')(h1)

model = tf.keras.Model(inputs=encode_input, outputs=out)

# 设置模型参数
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
             loss='categorical_crossentropy',
             metrics=['accuracy'])

# 输入数据开始训练模型
model.fit(x_train, y_train, epochs=2, batch_size=16)

# 查看模型输出结果
print(model.predict(x_test, batch_size=8))

# 保存完整模型
# .h5格式
model.save("../weights/minist.h5", save_format="")
del model

# .pb模式
#tf.saved_model.save(model, "../weights/minist")

通过model.save()函数能非常方便的保存模型,TensorFlow2.0主要包括两种方式保存完整模型:

  • .h5文件,具体代码model.save(“xxx.h5”)
  • pb文件,具体代码mode.save(“xxx”, save_format=“tf”)
    生成的模型文件

模型调用

通过函数tf.keras.models.load_model()函数能快速加载.h5模型与pb模型

import numpy as np
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(10000, 28, 28, 1).astype('float32') / 255
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

model = tf.keras.models.load_model("../weights/minist.h5")
# 加载pb模型
# model = tf.keras.models.load_model("../weights/minist")

model.summary()

print("第二次输出:")
print(model.predict(x_test, batch_size=8))

对比下两次输出结果就可以知道模型是否加载正确,是不是非常简单呢!

  • 6
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值