模型的保存与加载
1.load/save_weights
实战
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
def preprocess(x, y):
    """Normalize a single MNIST image and one-hot encode its label.

    Note: x is a single 28x28 image, not a batch.
    """
    image = tf.cast(x, dtype=tf.float32) / 255.
    image = tf.reshape(image, [28 * 28])
    label = tf.cast(y, dtype=tf.int32)
    label = tf.one_hot(label, depth=10)
    return image, label
batchsz = 128

# Load MNIST and build the shuffled/batched training and validation pipelines.
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)
sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

# 5-layer MLP classifier. The last layer outputs raw logits (no softmax),
# so the loss below must be constructed with from_logits=True.
network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))
network.summary()

# `learning_rate` replaces the deprecated `lr` keyword (removed in newer
# Keras releases; `learning_rate` works on every TF 2.x version).
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )
network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)
network.evaluate(ds_val)

# Save only the model parameters (weights). The architecture is NOT stored,
# so an identical network must be rebuilt before the weights can be restored.
network.save_weights('weights.ckpt')
print('saved weights.')
del network
# Rebuild an architecture identical to the one whose weights were saved;
# load_weights can only restore parameters into a matching model.
network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
# `learning_rate` replaces the deprecated `lr` keyword.
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )
# Restore the saved parameters (restore is deferred until the variables
# are created on first use, so build() is not strictly required here).
network.load_weights('weights.ckpt')
print('loaded weights!')
network.evaluate(ds_val)
Epoch 2/3
1/469 [..............................] - ETA: 13:50 - loss: 0.1488 - accuracy: 0.9688
19/469 [>.............................] - ETA: 43s - loss: 0.1262 - accuracy: 0.9634
39/469 [=>............................] - ETA: 20s - loss: 0.1342 - accuracy: 0.9619
57/469 [==>...........................] - ETA: 13s - loss: 0.1324 - accuracy: 0.9635
76/469 [===>..........................] - ETA: 10s - loss: 0.1379 - accuracy: 0.9623
93/469 [====>.........................] - ETA: 8s - loss: 0.1348 - accuracy: 0.9632
110/469 [======>.......................] - ETA: 6s - loss: 0.1370 - accuracy: 0.9618
130/469 [=======>......................] - ETA: 5s - loss: 0.1375 - accuracy: 0.9615
150/469 [========>.....................] - ETA: 4s - loss: 0.1384 - accuracy: 0.9618
169/469 [=========>....................] - ETA: 3s - loss: 0.1384 - accuracy: 0.9614
187/469 [==========>...................] - ETA: 3s - loss: 0.1369 - accuracy: 0.9619
207/469 [============>.................] - ETA: 2s - loss: 0.1385 - accuracy: 0.9612
229/469 [=============>................] - ETA: 2s - loss: 0.1387 - accuracy: 0.9612
251/469 [===============>..............] - ETA: 2s - loss: 0.1393 - accuracy: 0.9610
274/469 [================>.............] - ETA: 1s - loss: 0.1388 - accuracy: 0.9615
297/469 [=================>............] - ETA: 1s - loss: 0.1378 - accuracy: 0.9616
319/469 [===================>..........] - ETA: 1s - loss: 0.1373 - accuracy: 0.9618
342/469 [====================>.........] - ETA: 0s - loss: 0.1366 - accuracy: 0.9621
363/469 [======================>.......] - ETA: 0s - loss: 0.1356 - accuracy: 0.9624
385/469 [=======================>......] - ETA: 0s - loss: 0.1362 - accuracy: 0.9623
407/469 [=========================>....] - ETA: 0s - loss: 0.1358 - accuracy: 0.9624
429/469 [==========================>...] - ETA: 0s - loss: 0.1350 - accuracy: 0.9627
450/469 [===========================>..] - ETA: 0s - loss: 0.1342 - accuracy: 0.9629
466/469 [============================>.] - ETA: 0s - loss: 0.1343 - accuracy: 0.9629
469/469 [==============================] - 3s 7ms/step - loss: 0.1344 - accuracy: 0.9629 - val_loss: 0.1209 - val_accuracy: 0.9648
Epoch 3/3
1/469 [..............................] - ETA: 14:16 - loss: 0.1254 - accuracy: 0.9609
20/469 [>.............................] - ETA: 42s - loss: 0.1014 - accuracy: 0.9695
39/469 [=>............................] - ETA: 21s - loss: 0.1063 - accuracy: 0.9700
60/469 [==>...........................] - ETA: 13s - loss: 0.1006 - accuracy: 0.9703
82/469 [====>.........................] - ETA: 9s - loss: 0.1041 - accuracy: 0.9690
105/469 [=====>........................] - ETA: 7s - loss: 0.1089 - accuracy: 0.9676
128/469 [=======>......................] - ETA: 5s - loss: 0.1072 - accuracy: 0.9684
151/469 [========>.....................] - ETA: 4s - loss: 0.1056 - accuracy: 0.9692
171/469 [=========>....................] - ETA: 3s - loss: 0.1089 - accuracy: 0.9688
189/469 [===========>..................] - ETA: 3s - loss: 0.1094 - accuracy: 0.9688
208/469 [============>.................] - ETA: 2s - loss: 0.1122 - accuracy: 0.9681
228/469 [=============>................] - ETA: 2s - loss: 0.1099 - accuracy: 0.9687
250/469 [==============>...............] - ETA: 2s - loss: 0.1093 - accuracy: 0.9691
270/469 [================>.............] - ETA: 1s - loss: 0.1088 - accuracy: 0.9692
291/469 [=================>............] - ETA: 1s - loss: 0.1081 - accuracy: 0.9696
312/469 [==================>...........] - ETA: 1s - loss: 0.1079 - accuracy: 0.9700
334/469 [====================>.........] - ETA: 1s - loss: 0.1082 - accuracy: 0.9700
356/469 [=====================>........] - ETA: 0s - loss: 0.1086 - accuracy: 0.9699
378/469 [=======================>......] - ETA: 0s - loss: 0.1083 - accuracy: 0.9699
401/469 [========================>.....] - ETA: 0s - loss: 0.1071 - accuracy: 0.9700
422/469 [=========================>....] - ETA: 0s - loss: 0.1081 - accuracy: 0.9698
441/469 [===========================>..] - ETA: 0s - loss: 0.1089 - accuracy: 0.9697
459/469 [============================>.] - ETA: 0s - loss: 0.1083 - accuracy: 0.9700
469/469 [==============================] - 3s 6ms/step - loss: 0.1082 - accuracy: 0.9701
1/79 [..............................] - ETA: 0s - loss: 0.0620 - accuracy: 0.9844
11/79 [===>..........................] - ETA: 0s - loss: 0.1625 - accuracy: 0.9616
21/79 [======>.......................] - ETA: 0s - loss: 0.1902 - accuracy: 0.9576
32/79 [===========>..................] - ETA: 0s - loss: 0.1910 - accuracy: 0.9570
41/79 [==============>...............] - ETA: 0s - loss: 0.1845 - accuracy: 0.9573
50/79 [=================>............] - ETA: 0s - loss: 0.1695 - accuracy: 0.9605
60/79 [=====================>........] - ETA: 0s - loss: 0.1499 - accuracy: 0.9645
70/79 [=========================>....] - ETA: 0s - loss: 0.1389 - accuracy: 0.9667
79/79 [==============================] - 0s 5ms/step - loss: 0.1372 - accuracy: 0.9664
saved weights.
loaded weights!
1/79 [..............................] - ETA: 6s - loss: 0.0620 - accuracy: 0.9844
11/79 [===>..........................] - ETA: 0s - loss: 0.1625 - accuracy: 0.9616
21/79 [======>.......................] - ETA: 0s - loss: 0.1902 - accuracy: 0.9576
30/79 [==========>...................] - ETA: 0s - loss: 0.1884 - accuracy: 0.9581
39/79 [=============>................] - ETA: 0s - loss: 0.1914 - accuracy: 0.9559
49/79 [=================>............] - ETA: 0s - loss: 0.1724 - accuracy: 0.9600
59/79 [=====================>........] - ETA: 0s - loss: 0.1522 - accuracy: 0.9640
69/79 [=========================>....] - ETA: 0s - loss: 0.1404 - accuracy: 0.9665
78/79 [============================>.] - ETA: 0s - loss: 0.1388 - accuracy: 0.9663
79/79 [==============================] - 1s 6ms/step - loss: 0.1372 - accuracy: 0.9664
2.save/load entire model
实战
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
# Data preprocessing
def preprocess(x, y):
    """Normalize a single MNIST image and one-hot encode its label.

    Note: x is a single 28x28 image, not a batch.
    """
    image = tf.cast(x, dtype=tf.float32) / 255.
    image = tf.reshape(image, [28 * 28])
    label = tf.cast(y, dtype=tf.int32)
    label = tf.one_hot(label, depth=10)
    return image, label
batchsz = 128

# Dataset loading: build shuffled/batched training and validation pipelines.
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)
sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

# 5-layer MLP classifier; raw logits out, hence from_logits=True below.
network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))
network.summary()

# `learning_rate` replaces the deprecated `lr` keyword (removed in newer
# Keras releases; `learning_rate` works on every TF 2.x version).
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )
network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)
network.evaluate(ds_val)

# Save the ENTIRE model (architecture + weights) to a single HDF5 file,
# so it can be restored later without rebuilding the network by hand.
network.save('model.h5')
print('saved total model.')
del network
# Load the entire model (architecture + weights) from the HDF5 file.
# compile=False skips restoring the training configuration, so we re-compile.
network = tf.keras.models.load_model('model.h5', compile=False)
# Bug fix: this message was previously printed BEFORE the model was loaded.
print('loaded model from file.')
# `learning_rate` replaces the deprecated `lr` keyword.
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )
# Rebuild the validation pipeline from the raw arrays (equivalent to
# mapping preprocess over them, but done in one batched pass here).
x_val = tf.cast(x_val, dtype=tf.float32) / 255.
x_val = tf.reshape(x_val, [-1, 28 * 28])
y_val = tf.cast(y_val, dtype=tf.int32)
y_val = tf.one_hot(y_val, depth=10)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(128)
network.evaluate(ds_val)