How to process Conv weight during the model trainning

Tensorflow如何在训练的时候对卷积的权重进行一些特殊处理

比如对卷积的权重进行求均值,取最大值等等自定义的操作。

步骤如下

  1. 获取想要操作的的卷积核权重,保存到一个list中
  2. 申请一个同该卷核相同shape的占位符;
  3. 使用tf.assign()创建一个更新卷积核权重的操作op;

整体代码

# step 1 and step 2
originalConvList = []
processConvList = []
for weight in tf.trainable_variables():
	# conv2d/kernel:0 You want to process conv name
	if weight.name.endswith('conv2d/kernel:0'): 
		originalConvList.append(weight)
		# Init the weight placeholer same to the original weight
		newConvWeight = tf.placeholder(dtype=tf.float32, shape = weight.shape)
		# Creat the update conv weight operation by the tf.assign()
		processConvList.append([tf.assign(weight, newConvWeight), newConvWeight])
"""
# The normal model training part
...
"""
# step 3
for convIndex in range(len(originalConvList)):
	# Get the corresponding conv weight from the originalConvList
	convWeight = sess.run(originalConvList[convIndex])
	"""
	# You can process the convWeight
	# Make sure keep the convWeight dimentions
	...
	ie. convTmp = np.mean(convWeight[:,:,:,:], axis=(0,1), keepdims = True)
	"""
	# Run the assign operation to update the convWeight
	see.run(processConvList[convIndex][0], feed_dict={newConvWeight[convIndex][1]:convWeight})
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值