How to process Conv weight during the model trainning

Tensorflow如何在训练的时候对卷积的权重进行一些特殊处理

比如对卷积的权重进行求均值,取最大值等等自定义的操作。

步骤如下

  1. 获取想要操作的的卷积核权重,保存到一个list中
  2. 申请一个同该卷核相同shape的占位符;
  3. 使用tf.assign()创建一个更新卷积核权重的操作op;

整体代码

# step 1 and step 2
originalConvList = []
processConvList = []
for weight in tf.trainable_variables():
	# conv2d/kernel:0 You want to process conv name
	if weight.name.endswith('conv2d/kernel:0'): 
		originalConvList.append(weight)
		# Init the weight placeholer same to the original weight
		newConvWeight = tf.placeholder(dtype=tf.float32, shape = weight.shape)
		# Creat the update conv weight operation by the tf.assign()
		processConvList.append([tf.assign(weight, newConvWeight), newConvWeight])
"""
# The normal model training part
...
"""
# step 3
for convIndex in range(len(originalConvList)):
	# Get the corresponding conv weight from the originalConvList
	convWeight = sess.run(originalConvList[convIndex])
	"""
	# You can process the convWeight
	# Make sure keep the convWeight dimentions
	...
	ie. convTmp = np.mean(convWeight[:,:,:,:], axis=(0,1), keepdims = True)
	"""
	# Run the assign operation to update the convWeight
	see.run(processConvList[convIndex][0], feed_dict={newConvWeight[convIndex][1]:convWeight})
发布了49 篇原创文章 · 获赞 14 · 访问量 19万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 大白 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览