实战Keras2.0 API:keras.models.load_model加载已保存模型继续训练

1、数据准备

# 加载MNIST数据集
data= np.load('/Users/code/MNIST_data/mnist.npz', allow_pickle=True)
x_train, y_train = data['x_train'], data['y_train']
x_test, y_test = data['x_test'], data['y_test']

# 归一化像素值到0~1之间
x_train = x_train.astype('float32') / 255.  
x_test = x_test.astype('float32') / 255.  

#随机排序
perm = np.random.permutation(x_train.shape[0])  
x_train = x_train[perm]    
y_train = y_train[perm]
 

# 对label进行one-hot编码 
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)


2、使用keras.models.load_model加载模型

model_path='/Users/code/model/mnist.h5'
new_model = models.load_model(model_path) 


3、编译模型

#optimizer优化器
#loss损失函数
#metrics准确率
new_model.compile(optimizer='adam', 
    loss='categorical_crossentropy',  
    metrics=['accuracy'])


4、模型训练

#train_images, train_labels输入训练集的特征和标签
#validation_data导入测试集的特征和标签
#epochs要迭代多少次数据集
#batch_size每批数据多少个
#verbose打印训练过程,verbose=1会显示进度条,verbose=2会显示每轮次的训练损失和验证损失
#callbacks=[callback]回调函数,非必需
new_model.fit(x_train, y_train, 
          validation_data=(x_test, y_test),  
          epochs=5, 
          batch_size=128,verbose=1)


5、打印模型

new_model.summary()

结果输出:

Epoch 1/5
469/469 [==============================] - 52s 109ms/step - loss: 0.0211 - accuracy: 0.9930 - val_loss: 0.0271 - val_accuracy: 0.9913
Epoch 2/5
469/469 [==============================] - 49s 105ms/step - loss: 0.0145 - accuracy: 0.9954 - val_loss: 0.0274 - val_accuracy: 0.9910
Epoch 3/5
469/469 [==============================] - 52s 110ms/step - loss: 0.0111 - accuracy: 0.9964 - val_loss: 0.0268 - val_accuracy: 0.9917
Epoch 4/5
469/469 [==============================] - 53s 112ms/step - loss: 0.0089 - accuracy: 0.9970 - val_loss: 0.0276 - val_accuracy: 0.9924
Epoch 5/5
469/469 [==============================] - 52s 111ms/step - loss: 0.0080 - accuracy: 0.9973 - val_loss: 0.0225 - val_accuracy: 0.9940
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 28, 28, 32)        320       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 14, 14, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 12, 12, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 6, 6, 64)         0         
 2D)                                                             
                                                                 
 conv2d_2 (Conv2D)           (None, 4, 4, 128)         73856     
                                                                 
 flatten (Flatten)           (None, 2048)              0         
                                                                 
 dense (Dense)               (None, 128)               262272    
                                                                 
 dense_1 (Dense)             (None, 10)                1290      
                                                                 
=================================================================
Total params: 356,234
Trainable params: 356,234

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

缘起性空、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值