保存整个自定义模型
最近由于自己电脑跑不动定义的模型,所以到kaggle上跑自己的模型
- 何为自定义模型
只要你的模型继承了tf.keras.Model,那么你的就算是自定义模型了
class D_cnn(tf.keras.Model):
- 如何保存训练好的模型
通过tf.saved_model.save(netwok, path)
其中network代表你的模型的实例化,path自己定义路径,(记住path保存没有文件格式,只需要给出路径,例如“./model\my_model”
tf.saved_model.save(network, 'mymodel/')
- 这个地方需要注意
需要在自定义的模型call()方法处,利用@tf.function修饰,原因在于是图运算
@tf.function
def call(self, inputs, training=None):
pre = self.decision(inputs)
return pre
- 最后加载模型
network = tf.saved_model.load("mymodel")
5.调用模型
这个地方调用的时候,不能像以前一样用魔法方法
而是需要显示调用call().
pre = network.call()