存储模型train.py: model.save('model_weight.h5')
在predict.py中,使用model = load_model("model_weight.h5")对模型进行加载的时报错信息:
- Unknown Layer: LayerName。此处的LayerName代指自定义的layer。
- global name 'tf' is not defined
正确加载方式:
- 声明自定义的类,并创建实例。
- model = load_model("model_weight.h5", custom_objects={'tf': tf, 'Self_Attention': Self_Attention_shili, "local_Attention":local_Attention_shili}) ;将自己定义的类的名称和实例传进去。
- 如果自定义的类中,存在参数没有设置初始默认值,则会报错TypeError: init() missing 1 required positional argument: 'XXX'。解决方法:给一个初始值,需要和训练时候的参数维度一致。
参考链接: