最近在做一个文本分类工具,功能包括上传样本,使用样本训练model,save训练好的model并且使用model对文本进行分类。
用到框架有Keras和Django。
训练阶段将训练好的模型保存到指定目录。预测阶段加载训练好的模型进行预测(每一次预测都需要加载模型)。
第一次预测的时候是没有问题的,可以正常预测,第二次预测的时候报出了如下错误:
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor(“Placeholde r:0”, shape=(40001,50), dtype=float32) is not an element of this graph.
在网上找了一下,发现很多人都遇到过这个问题,解决方案大体分为两种:
1. 在加载模型前加上keras.backend.clear_session()
clear_session()的作用是结束当前的TF计算图并新建一个。
经过尝试这种的方法并不能解决我的问题。
2.在初始化加载模型之后,就随便生成一个向量让 model 执行一次 predict 函数
经过尝试这种方法是可行的。
但是,我的项目逻辑是在训练出模型后才能进行加载,并不能在初始化的时候加载模型,
因此,这种方法并不解决的我的问题。
最后是我的解决方法,在训练阶段保存模型后和预测阶段预测结束后手动清空内存。
如果还是不行,建议加上keras.backend.clear_session() (上述第一种方法)
添加如下代码:
import gc # training model = "自己的模型" model.save('model/model.h5') del model #删除model gc.collect() #手动清理内存 # prediction model = load_model('model/model.h5') result = model.predict(test) del model gc.collect()