一、报错提示
NotImplementedError: Layer XX has arguments in __init__
and therefore must override get_config
.
(XX表示自定义的CLASS)
二、错误原因
使用save方法后,未重新在Class中自定义属性
三、原因
模型保存有两种接口,save和save_weights方法。
区别如下:
save:保存网络模型图结构和参数。
save_weights:仅保存网络模型的参数。
如果使用save方法,自定义的Class里面的声明需要get_config重新配置声明一遍,否则Tensorflow无法保存模型的图结构(至于具体为什么,我就不深究了,毕竟用的人家的接口,只是知道了这个特点)。
如果不想重写,可以用save_weights暂时解决参数的保存的问题,也就不会有上述错误了。但毕竟save()更全面,为了后期的省事,还是推荐使用它(虽然稍占内存)。
四、解决方案
解决save()报错的具体操作如下:
在定义的Class里,增加一个get_config函数用于配置的更新(具体操作如下)。
其中__init__中的声明过的所有属性,需要在get_config函数中update一下。见下图圈中部分。
get_config模板如下,替换掉再跑就好了。
def get_config(self):
config = super().get_config().copy()
config.update({
'属性1': self.属性1,
'属性2': self.属性2,
'属性3': self.属性3,
})
return config
再次运行,模型保存成功!