最近在使用TFLearn来载入AffectNet的TrainedNetwork,采用深度学习提取Valence & Arousal。关于AffectNet是IEEE Transactions on Affective Computing, 2017的论文成果,全名是《AffectNet: A Database for Facial Expression, Valence, and Arousal Computing in the Wild》,作者公开了database和model,可以到项目网页来获取信息和申请数据库和模型,论文可以在arXiv或网页上获取到。
在通过model.load()的方法载入时报了一个非常长的错,看起来非常吓人,核心来说就是:
NotFoundError: Key ResNeXtBlock/BatchNormalization/is_training not found in checkpoint
查了许多资料,发现出现这样的原因,主要就是因为模型和代码不匹配。
通过inspect_checkpoint查看checkpoint模型中的网络结构,代码:
import tensorflow as tf
from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file("./model_resnet_-332000", tensor_name=None, all_tensors=True)
刚开头就显示出来关于is_training的信息
这说明之前报错中的is_training在模型中是False的状态,那在代码中是什么样的呢?
找到代码中的关键一句函数,也是报错中有显示到的batch_normalization,这两者关联在一起,出错在这里的概率最大
net = tflearn.batch_normalization(net)
查看TFLearn关于batch_normalization的文档说明
tflearn.layers.normalization.batch_normalization (incoming, beta=0.0, gamma=1.0, epsilon=1e-05, decay=0.9, stddev=0.002, trainable=True, restore=True, reuse=False, scope=None, name='BatchNormalization')
Normalize activations of the previous layer at each batch.
Arguments
- trainable:
bool
. If True, weights will be trainable.
可以看出,有trainable参数,且默认值为True,这就跟之前模型中的is_training为False有冲突了。
把这一句话修改一下
net = tflearn.batch_normalization(net, trainable=False)
其他不修改,load的部分就没有报错执行成功了!!
另外,TFLearn存下的模型最好还是用TFLearn来载入,用TensorFlow的saver.restore()方法也会报错,目前还不知道如何解决,如果有思路的欢迎留言!(参考TensorFlow: NotFoundError: Key not found in checkpoint)
KeyError: "The name 'SGD' refers to an Operation not in the graph.
最后,把载入部分的完整代码结构放出来,出于对AffectNet的版权考虑,如需要完整数据库和模型,请联系作者,这里也将核心代码略去,只保留对解决这个问题最关键的一行代码。
from __future__ import division, print_function, absolute_import
import tensorflow as tf
import tflearn
import numpy as np
# Core part of training code of AffectNet. Please contact the author to request the official version.
# Important change of code to solve the problem.
net = tflearn.batch_normalization(net, trainable=False)
# The rest of the training code...
# Loading trained model
model.load("./Valence_Arousal/meta-data/model_resnet_-332000")
# Model loaded and you can do whatever you want now!