TensorFlow slim fine-tune "BatchNorm/gamma/RMSProp_1 not found in checkpoint"

When fine-tuning from slim's model zoo, you often want to tune the fc layers first and only then tune the conv layers.

But if you finish tuning the fc layers and then immediately unfreeze all the conv layers, you hit "BatchNorm/gamma/RMSProp_1 not found in checkpoint", meaning the RMSProp optimizer's own parameters cannot be found in the checkpoint. One workaround is to switch to an optimizer with no extra state, such as plain SGD, but that only sidesteps the problem.
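To make the error message concrete, here is a minimal TF 1.x sketch (toy variable, not taken from the slim script) of which variables RMSProp adds: it keeps two slot variables per trained variable, and their names are exactly the checkpoint keys the restore complains about.

import tensorflow as tf  # TF 1.x style, matching the slim scripts

# Hypothetical toy variable; in the real model this is a BatchNorm gamma.
with tf.variable_scope('BatchNorm'):
    gamma = tf.get_variable('gamma', shape=[8])
loss = tf.reduce_sum(tf.square(gamma))

# RMSProp creates two slot variables ("rms" and "momentum") per trained variable.
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.01, momentum=0.9)
train_op = optimizer.minimize(loss)

for var in tf.global_variables():
    print(var.op.name)
# Prints, roughly:
#   BatchNorm/gamma
#   BatchNorm/gamma/RMSProp      <- the "rms" accumulator
#   BatchNorm/gamma/RMSProp_1    <- the momentum accumulator

A Saver only writes these slots for variables the optimizer actually touched, which is why a checkpoint from a run with frozen conv layers does not contain them.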


The cause is explained in https://github.com/tensorflow/models/issues/1836: while the fc layers were being tuned, the conv layers were frozen, so the optimizer never created (or saved) its state for them. When you later unfreeze the conv layers but keep the same train_dir, slim automatically loads the checkpoint it finds there, i.e. a model whose conv layers have no RMSProp slot variables; since all layers are now being trained, the new graph expects those variables, and the restore fails with "not found in checkpoint".
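You can confirm this by listing what the fc-only checkpoint actually contains (the path below is made up):

import tensorflow as tf

# Hypothetical path to the checkpoint written by the fc-only run.
reader = tf.train.NewCheckpointReader('/path/to/fc_only_train_dir/model.ckpt-10000')
for name in sorted(reader.get_variable_to_shape_map()):
    if 'RMSProp' in name:
        print(name)
# Only variables that were actually trained (the fc scopes) show up with
# /RMSProp and /RMSProp_1 suffixes; the frozen conv/BatchNorm variables have
# no optimizer slots in this checkpoint, which is what the later restore
# fails to find.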

Switching to a different (empty) train_dir fixes it, presumably because it changes how the parameters get loaded: with no checkpoint in train_dir, slim warm-starts from --checkpoint_path via _get_init_fn() (snippet below), which restores only the model variables it lists and leaves the new RMSProp slots at their fresh initial values; with an existing checkpoint in train_dir, that path is skipped and the trainer tries to restore every variable in the graph, optimizer slots included.

One reply in that issue makes the same point:
"Make sure that the --train_dir path is the one you initialized when you did your training. I had the same problem before; it was because there was a checkpoint file of a different model in my --train_dir, so when you run with the current model, this wrong checkpoint file will be found and used instead."


def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.

  Note that the init_fn is only run when initializing the model during the very
  first global step.

  Returns:
    An init function run by the supervisor.
  """
  if FLAGS.checkpoint_path is None:
    return None

  # Warn the user if a checkpoint exists in the train_dir. Then we'll be
  # ignoring the checkpoint anyway.
  if tf.train.latest_checkpoint(FLAGS.train_dir):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % FLAGS.train_dir)
    return None

  exclusions = []
  if FLAGS.checkpoint_exclude_scopes:
    exclusions = [scope.strip()
                  for scope in FLAGS.checkpoint_exclude_scopes.split(',')]

  # TODO(sguada) variables.filter_variables()
  variables_to_restore = []
  for var in slim.get_model_variables():
    excluded = False
    for exclusion in exclusions:
      if var.op.name.startswith(exclusion):
        excluded = True
        break
    if not excluded:
      variables_to_restore.append(var)

  if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
  else:
    checkpoint_path = FLAGS.checkpoint_path

  tf.logging.info('Fine-tuning from %s' % checkpoint_path)

  return slim.assign_from_checkpoint_fn(
      checkpoint_path,
      variables_to_restore,
      ignore_missing_vars=FLAGS.ignore_missing_vars)




