5.模型的保存和恢复

本文详细介绍了如何在TensorFlow中保存和恢复模型,包括使用检查点回调函数在训练过程中保存模型,理解检查点文件,手动保存权重以及保存整个模型。通过实例展示了在MNIST数据集上训练模型,利用ModelCheckpoint回调函数设置检查点,恢复模型后的准确率提升,以及使用Model.save_weights()和Model.save()保存和恢复模型的完整过程。
摘要由CSDN通过智能技术生成

这里我们使用TensorFlow的关于MNIST的数据集的前1000张图片来进行模型的训练和测试.

一.准备

1.1 得到数据集

下载数据集的代码:且我们只取得前面1000个样本.并且都除以255进行归一化处理.

from __future__ import absolute_import,division,print_function
import os

import tensorflow as tf
from tensorflow import keras

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.mnist.load_data()

train_labels=train_labels[:1000]
test_labels=test_labels[:1000]

print(train_images.shape)
print(test_images.shape)
train_images=train_images[:1000].reshape(-1,28*28)/255.0

test_images=test_images[:1000].reshape(-1,28*28)/255.0
print(test_images.shape)
print(test_images.shape)

结果:

11493376/11490434 [==============================] - 14s 1us/step
(60000, 28, 28)
(10000, 28, 28)
(1000, 784)
(1000, 784)

1.2 定义一个模型

我们这里训练的模型只有三层,最后一层通过softmax输出对于每一个样本预测的概率值.

def create_model():
    model=tf.keras.models.Sequential([
        keras.layers.Dense(512,activation=tf.nn.relu,input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10,activation=tf.nn.softmax)
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.sparse_categorical_crossentropy,
                  metrics=['accuracy'])
    return model

model=create_model()
model.summary()

结果:这里的第一层是全连接层,所以参数的个数是512*(784+1)=401920,第二层的参数是0个,...

Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

 

二.在训练过程中保存检查点

使用tf.keras.callbacks.ModelCheckPoint这个回调函数来实现检查点机制,需要配置以下的参数.

2.1 检查点回调函数的使用

先训练模型,然后将其传给ModelCheckpoint回调函数:

在这里直接运行代码的时候会出现错误:ImportError: `save_weights` requires h5py.所以我先使用pip install h5py,但是提示我先安装cython,所以我先pip install cython.发现可以成功的import h5py.


checkpoint_path='./cp.ckpt'
checkpoint_dir=os.path.dirname(checkpoint_path)

cp_callback=tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                               save_weights_only=True,
                                               verbose=1)

model=create_model()
model.fit(train_images,train_labels,epochs=10,
          validation_data=(test_images,test_labels),
          callbacks=[cp_callback])

结果展示:

Epoch 8/10
  32/1000 [..............................] - ETA: 0s - loss: 0.1069 - acc: 1.0000
 160/1000 [===>..........................] - ETA: 0s - loss: 0.0643 - acc: 1.0000
 288/1000 [=======>......................] - ETA: 0s - loss: 0.0758 - acc: 0.9861
 384/1000 [==========>...................] - ETA: 0s - loss: 0.0713 - acc: 0.9896
 512/1000 [==============>...............] - ETA: 0s - loss: 0.0659 - acc: 0.9922
 640/1000 [===&
  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值