引言
最近在使用keras自定义网络层时,保存模型时出现了很多bug,困扰了好几天。现在把他总结如下。
keras作为一个封装好的文件直接可以被调用,一直受到初学者的追捧,但是,封装好的模块可能满足不了需求,keras自定义网络层时必须遵守‘八股文’规则所以很受人诟病,下面总结一下在自定义网络层时的bug。
解决方案
- 在使用keras的回调函数checkpoint方式保存模型时要设置参数save_weights_only = False,否则默认只保存权重,不保存网络结构,在调用时还要重新加载网络结构才能使用.hdf5文件,详情参考
https://zhuanlan.zhihu.com/p/86886620 - 这里默认你已经创造好自己的模型,并已经训练好了,保存了’.hdf5’文件,在重新调用模型时(加载hdf5文件)需要引入你自定义的网络层(from xxx import xxx),再者使用
load_model(filepath,custom_objects={'TCN':TCN,'Attention_layer':Attention_layer(W_regularizer=None, b_regularizer=None,W_constraint=None, b_constraint=None,bias=True))
- 需要注意的是:要加上网络层自定义的参数值