文章目录
一、保存网络权重
.save_weights()
保存网络中所有的 w1,b1,w2,b2…,一些其他细节并不保存
.load_weights()
加载权重文件
del net_name : 删除网络 net_name
删除网络后,要想恢复原始模型,必须建立与之前相同的Sequential
network.save_weights('weights.ckpt')
print('saved weights.')
del network
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
network.load_weights('weights.ckpt')
print('loaded weights!')
二、保存网络模型
.save()
保存模型的所有细节,删掉网络后,只需要恢复保存的模型就可使用,不用新建squential。
.load_model()
network.save('model.h5')
print('saved total model.')
del network
print('loaded model from file.')
network = tf.keras.models.load_model('model.h5', compile=False)
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)