TFLearn载入模型NotFoundError: Key ...BatchNormalization/is_training not found in checkpoint解决方法

最近在使用TFLearn来载入AffectNet的TrainedNetwork,采用深度学习提取Valence & Arousal。关于AffectNet是IEEE Transactions on Affective Computing, 2017的论文成果,全名是《AffectNet: A Database for Facial Expression, Valence, and Arousal Computing in the Wild》,作者公开了database和model,可以到项目网页来获取信息和申请数据库和模型,论文可以在arXiv网页上获取到。

在通过model.load()的方法载入时报了一个非常长的错,看起来非常吓人,核心来说就是:

NotFoundError: Key ResNeXtBlock/BatchNormalization/is_training not found in checkpoint

查了许多资料,发现出现这样的原因,主要就是因为模型和代码不匹配。

通过inspect_checkpoint查看checkpoint模型中的网络结构,代码:

import tensorflow as tf
from tensorflow.python.tools import inspect_checkpoint as chkp

chkp.print_tensors_in_checkpoint_file("./model_resnet_-332000", tensor_name=None, all_tensors=True)

刚开头就显示出来关于is_training的信息

这说明之前报错中的is_training在模型中是False的状态,那在代码中是什么样的呢?

找到代码中的关键一句函数,也是报错中有显示到的batch_normalization,这两者关联在一起,出错在这里的概率最大

net = tflearn.batch_normalization(net)

查看TFLearn关于batch_normalization的文档说明

tflearn.layers.normalization.batch_normalization (incoming, beta=0.0, gamma=1.0, epsilon=1e-05, decay=0.9, stddev=0.002, trainable=True, restore=True, reuse=False, scope=None, name='BatchNormalization')

Normalize activations of the previous layer at each batch.

Arguments

  • trainablebool. If True, weights will be trainable.

可以看出,有trainable参数,且默认值为True,这就跟之前模型中的is_training为False有冲突了。

把这一句话修改一下

net = tflearn.batch_normalization(net, trainable=False)

其他不修改,load的部分就没有报错执行成功了!!

另外,TFLearn存下的模型最好还是用TFLearn来载入,用TensorFlow的saver.restore()方法也会报错,目前还不知道如何解决,如果有思路的欢迎留言!(参考TensorFlow: NotFoundError: Key not found in checkpoint

KeyError: "The name 'SGD' refers to an Operation not in the graph.

最后,把载入部分的完整代码结构放出来,出于对AffectNet的版权考虑,如需要完整数据库和模型,请联系作者,这里也将核心代码略去,只保留对解决这个问题最关键的一行代码。

from __future__ import division, print_function, absolute_import
import tensorflow as tf
import tflearn

import numpy as np

# Core part of training code of AffectNet. Please contact the author to request the official version.

# Important change of code to solve the problem.
net = tflearn.batch_normalization(net, trainable=False)

# The rest of the training code...

# Loading trained model
model.load("./Valence_Arousal/meta-data/model_resnet_-332000")

# Model loaded and you can do whatever you want now!

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值