深度学习2.0-23.Keras高层接口之模型的加载与保存

模型的保存与加载

在这里插入图片描述

1. save_weights / load_weights：只保存和加载模型参数（权重），不包含网络结构

在这里插入图片描述
在这里插入图片描述

实战
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics


def preprocess(x, y):
    """Convert one MNIST sample into model-ready tensors.

    Meant to be applied per element through ``Dataset.map``, so ``x`` is a
    single image, not a batch.

    Args:
        x: pixel values of one image; cast to float32, divided by 255 and
            flattened to a vector of 28 * 28 = 784 values.
        y: integer class label of that image.

    Returns:
        Tuple ``(image, label)`` where ``label`` is a depth-10 one-hot vector.
    """
    image = tf.reshape(tf.cast(x, dtype=tf.float32) / 255., [28 * 28])
    label = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)
    return image, label


batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

# Training pipeline: per-sample preprocessing, full shuffle, then batching.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)

# Validation pipeline: same preprocessing, no shuffling.
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)


def _build_network():
    """Build and compile the 784-256-128-64-32-10 MLP used in this demo.

    The model is built with an explicit input shape so its variables exist
    right away. ``load_weights`` can then restore into it immediately
    instead of relying on deferred restoration into an unbuilt model.

    Returns:
        A compiled ``Sequential`` model whose last layer outputs raw logits.
    """
    model = Sequential([layers.Dense(256, activation='relu'),
                        layers.Dense(128, activation='relu'),
                        layers.Dense(64, activation='relu'),
                        layers.Dense(32, activation='relu'),
                        layers.Dense(10)])
    model.build(input_shape=(None, 28 * 28))
    # `lr` is a deprecated alias that newer Keras optimizers reject;
    # `learning_rate` is accepted across TF 2.x.
    model.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                  # The final Dense layer has no activation, hence from_logits=True.
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    return model


network = _build_network()
network.summary()

network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)

network.evaluate(ds_val)

# Save only the model parameters (weights), not the architecture.
network.save_weights('weights.ckpt')
print('saved weights.')
del network

# Weights carry no architecture, so rebuild the identical network first,
# then restore the saved parameters into it.
network = _build_network()
network.load_weights('weights.ckpt')
print('loaded weights!')
network.evaluate(ds_val)
运行结果（节选，从第 2 个 epoch 开始）：
Epoch 2/3
  1/469 [..............................] - ETA: 13:50 - loss: 0.1488 - accuracy: 0.9688
 19/469 [>.............................] - ETA: 43s - loss: 0.1262 - accuracy: 0.9634  
 39/469 [=>............................] - ETA: 20s - loss: 0.1342 - accuracy: 0.9619
 57/469 [==>...........................] - ETA: 13s - loss: 0.1324 - accuracy: 0.9635
 76/469 [===>..........................] - ETA: 10s - loss: 0.1379 - accuracy: 0.9623
 93/469 [====>.........................] - ETA: 8s - loss: 0.1348 - accuracy: 0.9632 
110/469 [======>.......................] - ETA: 6s - loss: 0.1370 - accuracy: 0.9618
130/469 [=======>......................] - ETA: 5s - loss: 0.1375 - accuracy: 0.9615
150/469 [========>.....................] - ETA: 4s - loss: 0.1384 - accuracy: 0.9618
169/469 [=========>....................] - ETA: 3s - loss: 0.1384 - accuracy: 0.9614
187/469 [==========>...................] - ETA: 3s - loss: 0.1369 - accuracy: 0.9619
207/469 [============>.................] - ETA: 2s - loss: 0.1385 - accuracy: 0.9612
229/469 [=============>................] - ETA: 2s - loss: 0.1387 - accuracy: 0.9612
251/469 [===============>..............] - ETA: 2s - loss: 0.1393 - accuracy: 0.9610
274/469 [================>.............] - ETA: 1s - loss: 0.1388 - accuracy: 0.9615
297/469 [=================>............] - ETA: 1s - loss: 0.1378 - accuracy: 0.9616
319/469 [===================>..........] - ETA: 1s - loss: 0.1373 - accuracy: 0.9618
342/469 [====================>.........] - ETA: 0s - loss: 0.1366 - accuracy: 0.9621
363/469 [======================>.......] - ETA: 0s - loss: 0.1356 - accuracy: 0.9624
385/469 [=======================>......] - ETA: 0s - loss: 0.1362 - accuracy: 0.9623
407/469 [=========================>....] - ETA: 0s - loss: 0.1358 - accuracy: 0.9624
429/469 [==========================>...] - ETA: 0s - loss: 0.1350 - accuracy: 0.9627
450/469 [===========================>..] - ETA: 0s - loss: 0.1342 - accuracy: 0.9629
466/469 [============================>.] - ETA: 0s - loss: 0.1343 - accuracy: 0.9629
469/469 [==============================] - 3s 7ms/step - loss: 0.1344 - accuracy: 0.9629 - val_loss: 0.1209 - val_accuracy: 0.9648
Epoch 3/3
  1/469 [..............................] - ETA: 14:16 - loss: 0.1254 - accuracy: 0.9609
 20/469 [>.............................] - ETA: 42s - loss: 0.1014 - accuracy: 0.9695  
 39/469 [=>............................] - ETA: 21s - loss: 0.1063 - accuracy: 0.9700
 60/469 [==>...........................] - ETA: 13s - loss: 0.1006 - accuracy: 0.9703
 82/469 [====>.........................] - ETA: 9s - loss: 0.1041 - accuracy: 0.9690 
105/469 [=====>........................] - ETA: 7s - loss: 0.1089 - accuracy: 0.9676
128/469 [=======>......................] - ETA: 5s - loss: 0.1072 - accuracy: 0.9684
151/469 [========>.....................] - ETA: 4s - loss: 0.1056 - accuracy: 0.9692
171/469 [=========>....................] - ETA: 3s - loss: 0.1089 - accuracy: 0.9688
189/469 [===========>..................] - ETA: 3s - loss: 0.1094 - accuracy: 0.9688
208/469 [============>.................] - ETA: 2s - loss: 0.1122 - accuracy: 0.9681
228/469 [=============>................] - ETA: 2s - loss: 0.1099 - accuracy: 0.9687
250/469 [==============>...............] - ETA: 2s - loss: 0.1093 - accuracy: 0.9691
270/469 [================>.............] - ETA: 1s - loss: 0.1088 - accuracy: 0.9692
291/469 [=================>............] - ETA: 1s - loss: 0.1081 - accuracy: 0.9696
312/469 [==================>...........] - ETA: 1s - loss: 0.1079 - accuracy: 0.9700
334/469 [====================>.........] - ETA: 1s - loss: 0.1082 - accuracy: 0.9700
356/469 [=====================>........] - ETA: 0s - loss: 0.1086 - accuracy: 0.9699
378/469 [=======================>......] - ETA: 0s - loss: 0.1083 - accuracy: 0.9699
401/469 [========================>.....] - ETA: 0s - loss: 0.1071 - accuracy: 0.9700
422/469 [=========================>....] - ETA: 0s - loss: 0.1081 - accuracy: 0.9698
441/469 [===========================>..] - ETA: 0s - loss: 0.1089 - accuracy: 0.9697
459/469 [============================>.] - ETA: 0s - loss: 0.1083 - accuracy: 0.9700
469/469 [==============================] - 3s 6ms/step - loss: 0.1082 - accuracy: 0.9701
 1/79 [..............................] - ETA: 0s - loss: 0.0620 - accuracy: 0.9844
11/79 [===>..........................] - ETA: 0s - loss: 0.1625 - accuracy: 0.9616
21/79 [======>.......................] - ETA: 0s - loss: 0.1902 - accuracy: 0.9576
32/79 [===========>..................] - ETA: 0s - loss: 0.1910 - accuracy: 0.9570
41/79 [==============>...............] - ETA: 0s - loss: 0.1845 - accuracy: 0.9573
50/79 [=================>............] - ETA: 0s - loss: 0.1695 - accuracy: 0.9605
60/79 [=====================>........] - ETA: 0s - loss: 0.1499 - accuracy: 0.9645
70/79 [=========================>....] - ETA: 0s - loss: 0.1389 - accuracy: 0.9667
79/79 [==============================] - 0s 5ms/step - loss: 0.1372 - accuracy: 0.9664
saved weights.
loaded weights!
 1/79 [..............................] - ETA: 6s - loss: 0.0620 - accuracy: 0.9844
11/79 [===>..........................] - ETA: 0s - loss: 0.1625 - accuracy: 0.9616
21/79 [======>.......................] - ETA: 0s - loss: 0.1902 - accuracy: 0.9576
30/79 [==========>...................] - ETA: 0s - loss: 0.1884 - accuracy: 0.9581
39/79 [=============>................] - ETA: 0s - loss: 0.1914 - accuracy: 0.9559
49/79 [=================>............] - ETA: 0s - loss: 0.1724 - accuracy: 0.9600
59/79 [=====================>........] - ETA: 0s - loss: 0.1522 - accuracy: 0.9640
69/79 [=========================>....] - ETA: 0s - loss: 0.1404 - accuracy: 0.9665
78/79 [============================>.] - ETA: 0s - loss: 0.1388 - accuracy: 0.9663
79/79 [==============================] - 1s 6ms/step - loss: 0.1372 - accuracy: 0.9664

2. save / load_model：保存和加载整个模型（网络结构与参数等）

在这里插入图片描述

实战
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

# Data preprocessing
def preprocess(x, y):
    """Convert one MNIST sample into model-ready tensors.

    Meant to be applied per element through ``Dataset.map``, so ``x`` is a
    single image, not a batch.

    Args:
        x: pixel values of one image; cast to float32, divided by 255 and
            flattened to a vector of 28 * 28 = 784 values.
        y: integer class label of that image.

    Returns:
        Tuple ``(image, label)`` where ``label`` is a depth-10 one-hot vector.
    """
    image = tf.reshape(tf.cast(x, dtype=tf.float32) / 255., [28 * 28])
    label = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)
    return image, label


batchsz = 128
# Load the dataset
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

# Training pipeline: per-sample preprocessing, full shuffle, then batching.
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)

# Validation pipeline: same preprocessing, no shuffling.
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)


def _compile_network(model):
    """Compile *model* in place with this demo's optimizer, loss and metric.

    Used for both the freshly built model and the one reloaded with
    ``compile=False``, so both are configured identically.
    """
    # `lr` is a deprecated alias that newer Keras optimizers reject;
    # `learning_rate` is accepted across TF 2.x.
    model.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                  # The final Dense layer has no activation, hence from_logits=True.
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])


network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))
network.summary()

_compile_network(network)

network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)

network.evaluate(ds_val)

# Save the entire model (architecture and weights) to a single HDF5 file.
network.save('model.h5')
print('saved total model.')
del network

# Load the entire model. With compile=False the model comes back
# uncompiled, so it is compiled explicitly before evaluation.
network = tf.keras.models.load_model('model.h5', compile=False)
print('loaded model from file.')
_compile_network(network)

# ds_val is still the preprocessed validation pipeline built above, so it is
# reused directly instead of re-preprocessing x_val / y_val by hand.
network.evaluate(ds_val)
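3. SavedModel：用于生产环境（工业部署）的模型格式

通过 tf.saved_model.save(model, path) 导出、tf.saved_model.load(path) 加载；该格式不依赖定义模型的 Python 代码，可直接交给 TensorFlow Serving 等部署工具使用。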

在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值