1. model.save() and tf.keras.models.load_model()
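This method saves the model's architecture and weights, plus its training configuration (loss, optimizer) and optimizer state if the model was compiled. The loaded model can be used directly, without rebuilding or recompiling it.
Saving the model: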
model.save('my_model.h5')
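Loading the model:
import tensorflow as tf

# Load the entire model (architecture, weights and training configuration)
loaded_model = tf.keras.models.load_model('my_model.h5')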
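Here is a minimal round-trip sketch. It assumes TensorFlow 2.x is installed (tf.keras backed by either Keras 2 or Keras 3). The toy two-layer model, the random data and the temporary file path are illustrative only, and the model is left uncompiled to keep the sketch short. The code saves the whole model to an H5 file, loads it back, and checks that both models give the same predictions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Toy model and data, for illustration only.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
x = np.random.rand(16, 4).astype("float32")

# Save architecture + weights to a single HDF5 file.
path = os.path.join(tempfile.mkdtemp(), "my_model.h5")
model.save(path)

# Reload: the architecture does not need to be redefined in code.
loaded_model = tf.keras.models.load_model(path)

original_pred = model.predict(x, verbose=0)
loaded_pred = loaded_model.predict(x, verbose=0)
```

If the weights were restored correctly, `original_pred` and `loaded_pred` should be equal element by element.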
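Note: if the model was compiled with a custom loss function (or another custom object), load_model() fails by default because it cannot find that function by name. Either pass the function in, e.g. tf.keras.models.load_model('my_model.h5', custom_objects={'custom_loss': custom_loss}), or load with compile=False and call compile() again yourself.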
Example:
import os
import tensorflow as tf

# class_num, epoch_num and PrintPredictionsCallback are assumed to be defined elsewhere in the project
model_file = "data/model/multi_labels_model.h5"  # path to the model file


def model_handle(x_train, y_train):
    if os.path.exists(model_file):
        print("---load the model---")
        model = tf.keras.models.load_model(model_file)  # load the existing model
    else:
        # Build the model
        model = tf.keras.Sequential([
            tf.keras.layers.LSTM(128),
            tf.keras.layers.Dense(class_num, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
        ])
        # Compile with a built-in loss so that load_model() later works without custom_objects
        model.compile(loss="BinaryCrossentropy", optimizer='adam', metrics=['accuracy'])
        history = model.fit(x_train, y_train, epochs=epoch_num, batch_size=1, verbose=1,
                            callbacks=[PrintPredictionsCallback(x_train, y_train)])
        model.summary()
        model.save(model_file)
    return model
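For a model compiled with a custom loss, here is a minimal sketch of one workaround, assuming TensorFlow 2.x. `custom_loss` is a hypothetical hand-written mean absolute error, not the author's function, and the data is random. Loading with `compile=False` restores the architecture and weights but skips the training configuration, so the custom function never has to be found by name. The model is then compiled again with the function itself. Passing `custom_objects={"custom_loss": custom_loss}` to `load_model()` is the other common route.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def custom_loss(y_true, y_pred):
    # Hypothetical custom loss: mean absolute error written by hand.
    return tf.reduce_mean(tf.abs(y_true - y_pred))


x = np.random.rand(16, 4).astype("float32")
y = np.random.rand(16, 1).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(loss=custom_loss, optimizer="adam")
model.fit(x, y, epochs=1, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "custom_loss_model.h5")
model.save(path)

# compile=False: restore architecture + weights only, ignore the saved loss.
loaded_model = tf.keras.models.load_model(path, compile=False)

# The weights come back exactly, even though the loss was not restored.
weights_match = all(
    np.allclose(a, b) for a, b in zip(model.get_weights(), loaded_model.get_weights())
)

# Recompile with the custom function before training again.
loaded_model.compile(loss=custom_loss, optimizer="adam")
history = loaded_model.fit(x, y, epochs=1, verbose=0)
```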
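2. model.save_weights() and model.load_weights()
This method saves and loads only the model's weights. The architecture is not stored, so the same model has to be built in code before the weights can be loaded.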
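Saving the weights:
# Save only the weights (Keras 3 / TF 2.16+ requires the file name to end in .weights.h5)
model.save_weights('my_model.weights.h5')
Loading the weights:
import tensorflow as tf

# Create a new model instance (its architecture must match the original model)
new_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(32,)),
    tf.keras.layers.Dense(1)
])
# new_model.build(input_shape=x_train.shape)  # needed if the model was created without an input_shape
# Load the weights into the new model
new_model.load_weights('my_model.weights.h5')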
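Because only the weights are stored, a model saved this way may be compiled with custom functions such as a custom loss, since loading never needs to look them up. Compile the model again after loading if you want to continue training or evaluate it.
Note: when loading H5-format weights into a subclassed model, or any model whose variables have not been created yet (for example a Sequential model defined without an input shape), build the model with its input shape first, otherwise you get this error: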
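Here is a sketch of the weights-only workflow with a custom loss, assuming TensorFlow 2.x. `build_model()` and `custom_loss` are hypothetical names, and the data is random. The file name uses the `.weights.h5` suffix because Keras 3 requires it; older tf.keras also saves this name as HDF5, since it ends in `.h5`. A fresh instance of the same architecture receives the saved weights and is then compiled again with the custom loss.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # The same architecture must be rebuilt in code before load_weights().
    return tf.keras.Sequential([
        tf.keras.Input(shape=(32,)),
        tf.keras.layers.Dense(10, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


def custom_loss(y_true, y_pred):
    # Hypothetical custom loss; only weights are saved, so it is never serialized.
    return tf.reduce_mean(tf.abs(y_true - y_pred))


x = np.random.rand(8, 32).astype("float32")
y = np.random.rand(8, 1).astype("float32")

model = build_model()
model.compile(loss=custom_loss, optimizer="adam")
model.fit(x, y, epochs=1, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "my_model.weights.h5")
model.save_weights(path)

new_model = build_model()
new_model.load_weights(path)
# Recompile with the custom loss before training or evaluating again.
new_model.compile(loss=custom_loss, optimizer="adam")
```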
ValueError: Unable to load weights saved in HDF5 format into a subclassed Model which has not created its variables yet. Call the Model first, then load the weights.
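Here is a minimal sketch of the fix for a subclassed model, assuming TensorFlow 2.x. `TinyModel` and its layer sizes are hypothetical. Calling the new model once on a batch with the right input shape creates its variables. After that, `load_weights()` can match the saved weights to them.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyModel(tf.keras.Model):
    """Hypothetical subclassed model, used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(8, activation="relu")
        self.head = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.head(self.hidden(inputs))


x = np.random.rand(4, 5).astype("float32")

model = TinyModel()
model(x)  # the first call creates the variables
path = os.path.join(tempfile.mkdtemp(), "tiny_model.weights.h5")
model.save_weights(path)

restored = TinyModel()
# In tf.keras 2.x, calling load_weights() at this point (before any call)
# raises the "has not created its variables yet" error shown above.
restored(x)  # create the variables with the same input shape first
restored.load_weights(path)

same_output = np.allclose(model.predict(x, verbose=0), restored.predict(x, verbose=0))
```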
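Example:
# model_file, class_num, epoch_num, custom_loss and PrintPredictionsCallback are assumed to be defined elsewhere
def model_handle(x_train, y_train):
    # Build the model: softmax for single-label multi-class output, sigmoid for multi-label output
    model = tf.keras.Sequential([
        tf.keras.layers.LSTM(128),
        tf.keras.layers.Dense(class_num, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
    ])
    if os.path.exists(model_file):
        print("-----load model weights-----")
        # The model has no input shape yet, so build it before loading H5 weights, otherwise load_weights() raises an error
        model.build(input_shape=x_train.shape)
        model.load_weights(model_file)
        # The model is not compiled in this branch; call model.compile(...) before training or evaluating it
    else:
        # Compile the model with a custom loss function
        model.compile(loss=custom_loss, optimizer='adam', metrics=['accuracy'])
        # model.compile(loss="BinaryCrossentropy", optimizer='adam', metrics=['accuracy'])
        history = model.fit(x_train, y_train, epochs=epoch_num, batch_size=1, verbose=1,
                            callbacks=[PrintPredictionsCallback(x_train, y_train)])
        model.summary()
        model.save_weights(model_file)  # with Keras 3, this path must end in .weights.h5
    return model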
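3. tf.keras.callbacks.ModelCheckpoint
This callback is mainly used to resume interrupted training. It saves the weights during training, and the next run loads them before continuing. Usage: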
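import os
import tensorflow as tf

# model, x_train, y_train, x_test and y_test are assumed to be defined already.
# With tf.keras 2.x, a path without ".h5" is saved in TensorFlow checkpoint format,
# which writes my_checkpoint.ckpt.index plus data files; hence the ".index" check below.
# (Keras 3 / TF 2.16+ requires the path to end in ".weights.h5" when save_weights_only=True.)
checkpoint_save_path = "./checkpoint/my_checkpoint.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)  # resume from the saved weights

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,  # save weights only, not the whole model
                                                 save_best_only=True,     # overwrite only when val_loss improves
                                                 monitor='val_loss')
history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()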
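The resume-training pattern can be sketched end to end as follows, assuming TensorFlow 2.x. `build_model()`, `train()` and the random data are hypothetical. The checkpoint path uses the `.weights.h5` suffix so that it works with both Keras 2 and Keras 3. The first call to `train()` creates the checkpoint file. The second call finds it, loads the best weights saved so far and continues training from them.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical toy regression model, rebuilt and recompiled on every run.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(loss="mse", optimizer="adam")
    return model


# Random stand-in data; replace with real training/validation sets.
x_train = np.random.rand(64, 4).astype("float32")
y_train = np.random.rand(64, 1).astype("float32")
x_test = np.random.rand(16, 4).astype("float32")
y_test = np.random.rand(16, 1).astype("float32")

# ".weights.h5" is required by Keras 3 when save_weights_only=True and is
# saved as HDF5 by older tf.keras as well.
checkpoint_save_path = os.path.join(tempfile.mkdtemp(), "my_checkpoint.weights.h5")


def train(epochs):
    model = build_model()
    if os.path.exists(checkpoint_save_path):
        # A previous run left a checkpoint: continue from its weights.
        model.load_weights(checkpoint_save_path)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_save_path,
        save_weights_only=True,
        save_best_only=True,  # write only when this callback sees a new best val_loss
        monitor="val_loss",
    )
    history = model.fit(x_train, y_train, batch_size=16, epochs=epochs,
                        validation_data=(x_test, y_test), verbose=0,
                        callbacks=[cp_callback])
    return model, history


first_model, first_history = train(epochs=3)    # creates the checkpoint
second_model, second_history = train(epochs=3)  # resumes from it
```

One detail worth knowing: a new `ModelCheckpoint` starts with no record of the previous run's best `val_loss`, so with `save_best_only=True` the first epoch of a resumed run always overwrites the file. If that matters, recent versions of the callback accept an `initial_value_threshold` argument that can be set to the best value from the earlier run.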