在用 SSD-Tensorflow 训练自己的数据时,遇到了如下错误:
InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [8] rhs shape= [84]
[[Node: save/Assign_15 = Assign[T=DT_FLOAT, _class=["loc:@ssd_300_vgg/block10_box/conv_cls/biases"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](ssd_300_vgg/block11_box/conv_cls/biases, save/RestoreV2_15)]]
原因:
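在 train_ssd_network.py 的 slim.learning.train 调用中,init_fn 由 tf_utils.get_init_fn(FLAGS) 提供,用于从 --checkpoint_path 初始化(微调)网络参数。但如果 train_dir 文件夹中已经存在 checkpoint,该函数会直接返回 None,后面的 checkpoint_exclude_scopes 过滤操作根本不会执行。此时 slim.learning.train 会用 saver 从 train_dir 中恢复最新的 checkpoint,而且恢复的是全部变量(报错中的 save/RestoreV2、save/Assign 节点就是这个恢复操作)。如果这个 checkpoint 是用不同配置训练得到的,分类层 conv_cls 的形状就会不匹配,从而出现这个错误。例如报错中 checkpoint 里的 rhs=84 = 4 个 anchor × 21 类(VOC 的 20 类 + 背景),而当前模型的 lhs=8 = 4 个 anchor × 2 类(1 个目标类 + 背景)。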
SSD-Tensorflow部分训练代码:
# Session options: cap the fraction of GPU memory this process may allocate.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
config = tf.ConfigProto(log_device_placement=False,
                        gpu_options=gpu_options)
# Checkpoint saver. It writes checkpoints into train_dir; it is also the saver
# whose restore ops (save/RestoreV2_*, save/Assign_*) appear in the shape
# mismatch error when train_dir already holds an incompatible checkpoint.
saver = tf.train.Saver(max_to_keep=5,
                       keep_checkpoint_every_n_hours=1.0,
                       write_version=2,
                       pad_step_number=False)
slim.learning.train(
    train_tensor,
    logdir=FLAGS.train_dir,
    master='',
    is_chief=True,
    # Warm-start from --checkpoint_path. get_init_fn returns None when
    # train_dir already contains a checkpoint.
    init_fn=tf_utils.get_init_fn(FLAGS),
    summary_op=summary_op,  # merged tf.summary op
    number_of_steps=FLAGS.max_number_of_steps,  # total number of training steps
    log_every_n_steps=FLAGS.log_every_n_steps,  # steps between training-log lines
    save_summaries_secs=FLAGS.save_summaries_secs,  # seconds between summary writes
    saver=saver,  # tf.train.Saver used for checkpointing
    save_interval_secs=FLAGS.save_interval_secs,  # seconds (not steps) between checkpoint saves
    session_config=config,  # tf.ConfigProto for the training session
    sync_optimizer=None)
调用的初始化函数代码:
def get_init_fn(flags):
    """Returns a function run by the chief worker to warm-start the training.

    The returned function restores model variables from
    ``flags.checkpoint_path``, skipping every variable whose op name starts
    with one of the comma-separated prefixes in
    ``flags.checkpoint_exclude_scopes``.

    Note that the init_fn is only run when initializing the model during the
    very first global step.

    Args:
      flags: Parsed command-line flags. Reads ``checkpoint_path``,
        ``train_dir``, ``checkpoint_exclude_scopes``,
        ``checkpoint_model_scope``, ``model_name`` and
        ``ignore_missing_vars``.

    Returns:
      An init function run by the supervisor, or None when no
      ``checkpoint_path`` is given or when ``train_dir`` already contains a
      checkpoint.

    Raises:
      ValueError: if ``checkpoint_path`` is a directory that contains no
        checkpoint.
    """
    if flags.checkpoint_path is None:
        return None

    # Warn the user if a checkpoint exists in the train_dir. Then ignore.
    # In that case training resumes from the train_dir checkpoint: the
    # training loop's saver restores *all* variables from it, so
    # checkpoint_exclude_scopes has no effect. A stale checkpoint written with
    # a different model configuration (e.g. another number of classes) then
    # fails with a shape mismatch; start from an empty train_dir instead.
    if tf.train.latest_checkpoint(flags.train_dir):
        tf.logging.info(
            'Ignoring --checkpoint_path because a checkpoint already exists in %s'
            % flags.train_dir)
        return None

    # Drop empty entries: a trailing or doubled comma would otherwise yield
    # the prefix '', which every name starts with, silently excluding the
    # whole model from the restore.
    exclusions = ()
    if flags.checkpoint_exclude_scopes:
        exclusions = tuple(
            scope.strip()
            for scope in flags.checkpoint_exclude_scopes.split(',')
            if scope.strip())

    # TODO(sguada) variables.filter_variables()
    # NOTE: plain prefix match, so 'ssd_300_vgg/block1' also excludes
    # 'ssd_300_vgg/block10...'; end a scope with '/' to match it exactly.
    variables_to_restore = [var for var in slim.get_model_variables()
                            if not var.op.name.startswith(exclusions)]

    # Change model scope if necessary: map each variable to its name inside
    # the checkpoint. Only the first occurrence of model_name (the leading
    # scope) is rewritten.
    if flags.checkpoint_model_scope is not None:
        variables_to_restore = {
            var.op.name.replace(flags.model_name,
                                flags.checkpoint_model_scope, 1): var
            for var in variables_to_restore}

    if tf.gfile.IsDirectory(flags.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
        if checkpoint_path is None:
            # Fail early with a clear message instead of passing None on to
            # assign_from_checkpoint_fn.
            raise ValueError('No checkpoint found in --checkpoint_path=%s'
                             % flags.checkpoint_path)
    else:
        checkpoint_path = flags.checkpoint_path
    tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' % (checkpoint_path, flags.ignore_missing_vars))
    return slim.assign_from_checkpoint_fn(
        checkpoint_path,
        variables_to_restore,
        ignore_missing_vars=flags.ignore_missing_vars)
解决方法:
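删除 train_dir 文件夹(或者换一个新的空目录作为 --train_dir)。这样 train_dir 中就没有旧的 checkpoint,get_init_fn 会从 --checkpoint_path 加载预训练权重,并按 checkpoint_exclude_scopes 跳过形状不匹配的层(如 conv_cls)。注意:这样做会丢弃 train_dir 中已有的训练进度。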