在保存时,自己写的call函数要加
@tf.function(input_signature=[ [tf.TensorSpec(shape=[None, None], dtype=tf.int32), tf.TensorSpec(shape=[None, None], dtype=tf.int32)]])
这个input_signature只和输入有关,training这种状态码,不需要写在这里,使用时直接加上training=True/False即可
使用
tf.summary.trace_on()
训练语句
with summary_writer.as_default():
tf.summary.trace_export(name="model_trace", step=0)
将模型图写入tensorboard的graph中
使用model.save('model', save_format="tf")来保存自定义模型
使用model = keras.models.load_model("model")来恢复自定义模型