【20200302】Tensorflow 2.1:关于tf.keras保存和载入模型的笔记

由于tensorflow api可能因为各版本变化,变来变去的,导致网上的许多网友的教程办法都不对,坑坑坑坑的满满

这里我放上最近学习tensorflow利用自定义网络来训练mnist的一个代码实例

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential

import os
# 只显示 warning 和 Error
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

def preprocess(x, y):
    # x是一张图片,不是一个batch
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)

    return x, y


# 从dataset加载mnist数据集
# train:[60k, 28, 28], test: [10k, 28, 28]
(x, y), (x_test, y_test) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape)

# 设置batch分块大小
batch_size = 100

# 将x,y转成tensor数据类型
db_train = tf.data.Dataset.from_tensor_slices((x, y))
# 对x,y进行数据预处理
db_train = db_train.map(preprocess).shuffle(60000).batch(batch_size)

# 将x_test,y_test 转成tensor数据类型
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
# 对x_test,y_test 进行数据预处理
db_test = db_test.map(preprocess).shuffle(10000).batch(batch_size)

# 设置学习率
lr = 0.01

# 自定义Dense类
class MyDense(layers.Layer):
    def __init__(self, inp_dim, outp_dim):
        super(MyDense, self).__init__()
        # 这里不再使用旧API的add_variable函数
        self.kernel = self.add_weight('w', [inp_dim, outp_dim])
        self.bias = self.add_weight('b', [outp_dim])

    def call(self, inputs, training=None):
        # 这里是自定义计算方法,这里是一个最简单的线性计算
        out = inputs @ self.kernel + self.bias
        return out

# 自定义模型
class MyModel(keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        
        self.func1 = MyDense(28 * 28, 256) # [784 -> 256]
        self.func2 = MyDense(256, 128)     # [256 -> 128]
        self.func3 = MyDense(128, 64)      # [128 -> 64]
        self.func4 = MyDense(64, 32)       # [64 -> 32]
        self.func5 = MyDense(32, 10)       # [32 -> 10]

    def call(self, inputs, training=None):
        x = tf.nn.relu(self.func1(inputs))
        x = tf.nn.relu(self.func2(x))
        x = tf.nn.relu(self.func3(x))
        x = tf.nn.relu(self.func4(x))
        x = self.func5(x)

        return x

def main():
    model = MyModel()
    # 设置优化器,设置损失函数,设置统计结果
    model.compile(optimizer=optimizers.Adam(lr),
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    # 设置训练数据集,设置训练周期,设置测试数据集,设置测试频率
    model.fit(db_train, epochs=3, validation_data=db_test, validation_freq=2)
    # 对训练后的模型网络进行校验
    model.evaluate(db_test)
    # 保存训练后的模型
    tf.saved_model.save(model, 'model.h5')
    print('model saved')
    del(model)

    print('model loaded')
    # 读取训练后的模型
    model = tf.keras.models.load_model('model.h5', compile=False)
    model.compile(optimizer=optimizers.Adam(lr),
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    # 测试该模型
    model.evaluate(db_test)


if __name__ == "__main__":
    main()

之所以说坑的地方在于:保存的数据,与读取后的数据可能有差别,这就在于其h5 跟 tf的pb格式差别

1、旧的写法:

network.save('model.h5')
print('saved total model.')
del network

print('loaded model from file.')
network = tf.keras.models.load_model('model.h5', compile=False)

这里会出现save方法报Tensorflow SavedModel format的模型无法被保存为h5格式就会报错,可以的话试着把模型以h5格式建立

 

2、旧的写法:

tf.saved_model.save(model, 'model.h5')
print('model saved')
del(model)

print('model loaded')
model = tf.saved_model.load('model.h5')

这里保存没问题,但是这里读取的时候model并不是Keras models,也是一个神坑

官方的api中这么提到:

所以要从save_model保存的数据还原成tf.keras对象的模型,要使用tf.keras.models.load_model来完成

 

3、目前可行的用法:

tf.saved_model.save(model, 'model.h5')
print('model saved')
del(model)

print('model loaded')
model = tf.keras.models.load_model('model.h5', compile=False)

 

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值