slim.batch_norm无法更新以及保存参数

1、更新参数

当我们使用batch_norm时,slim.batch_norm中的moving_mean和moving_variance是无法更新的,当is_training = True时,意味着创建Update ops,利用当前batch的均值和方差去更新moving averages(即某层累计的平均均值和方差)。这里提供两种方式创建update_ops,

 

一是自己显式的创建update_ops,手动更新。update_ops默认放置在tf.GraphKeys.UPDATE_OPS中,因此这里在执行train_ops的同时更新均值方差即可,对于单卡来说很容易理解,对于多卡来说,相当于collection所有卡的batch的均值方差后统一更新,也可以只collection第一块卡的均值方差(理论上需要积累其他卡,但是由于这操作积累得很快,所以只取第一块卡也不影响性能,在TensorFlow高阶API的样例代码cifar10_main.py中如是说)。代码如下:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss)

二是自动的更新, 只需在初始化前 bn = BatchNorm(update_ops_collection=None)即可。不过这种方式下,会在完成更新前阻塞网络的forward,因此会带来时间上的成本。具体而言,这时bn的参数mean,var是立即更新的,也是计算完当前layer的mean,var就更新,然后进行下一个layer的操作。这在单卡下没有问题的, 但是多卡情况下就会写等读的冲突,因为可能存在GPU0更新(写)mean但此时GPU1还没有计算到该层,所以GPU0就要等GPU1读完mean才能写。

update_ops = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
train_op = tf.group(train_op, update_ops)

 

 

2、保存参数

 

当我们使用batch_norm时,slim.batch_norm中的moving_mean和moving_variance不是trainable的,所以使用saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)无法保存,应该改为:

var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=3)

 

 

 

 

 

 

  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值