中兴捧月算法大赛——剪枝部分代码

由于比赛过程中不停地尝试新的想法，代码比较乱，有时间会抽空整理一下。

# coding:utf-8
import caffe
import numpy as np
import collections

class Pruner(object):
	"""Prune channels of a loaded Caffe network and save the slimmed weights.

	Pruned parameters are collected in ``conv_data`` (an OrderedDict keyed by
	layer name) and copied into a smaller network by ``save``.  Entries created
	by the convolution helpers are ``(weight, bias, del_filters, origin_channels)``;
	entries created by ``prune_fc`` / ``fc_concat`` are ``(weight, bias)``.
	"""

	def __init__(self, net):
		# assumes ``net`` is a caffe.Net; anything whose ``params`` maps layer
		# names to a sequence of blobs exposing ``.data`` works the same way
		self._net = net
		self.conv_data = collections.OrderedDict()

	def _prune(self, conv_param, del_kernels=None, not_del_filters=False, th=0.0001):
		"""Drop weak output filters and the given input kernels of one conv layer.

		Args:
			conv_param: ``(weight_blob, bias_blob)`` of the layer; the weight is
				indexed on four axes, so it is expected to be 4-D.
			del_kernels: input-channel indices to remove (axis 1), normally the
				filters deleted from the bottom layer; ``None`` keeps all inputs.
			not_del_filters: if true, keep every output filter.
			th: output filters whose mean absolute weight is below ``th`` are removed.

		Returns:
			``(weight, bias, del_filters, origin_channels)``.  ``del_filters`` is
			always an integer index array so later layers can use it as an index.
		"""
		weight, bias = conv_param
		weight = weight.data
		bias = bias.data
		origin_channels = weight.shape[0]

		if not not_del_filters:
			abs_mean = np.abs(weight).mean(axis=(1, 2, 3))
			del_filters = np.where(abs_mean < th)[0]
			weight = np.delete(weight, del_filters, axis=0)
			bias = np.delete(bias, del_filters, axis=0)
		else:
			# Integer dtype on purpose: an empty float array is rejected as an
			# index by np.delete / fancy indexing in current NumPy.
			del_filters = np.array([], dtype=np.intp)

		if del_kernels is not None:
			weight = np.delete(weight, del_kernels, axis=1)

		return weight, bias, del_filters, origin_channels

	def prune_conv(self, name, bottom=None, not_del_filters=False, th=0.0001):
		"""Prune conv layer ``name``; when ``bottom`` is given, also remove the
		input kernels matching the filters already deleted from ``bottom``."""
		del_kernels = None if bottom is None else self.conv_data[bottom][2]
		self.conv_data[name] = self._prune(self._net.params[name], del_kernels, not_del_filters=not_del_filters, th=th)

	def prune_fc(self, name, th=3):
		"""Drop input columns of fully connected layer ``name``.

		A column is removed when its summed absolute weight over all outputs is
		below ``th``.  The bias is kept as-is because only inputs are removed.
		Prints the bias shape, the removed column indices and their count.
		"""
		weight, bias = self._net.params[name]
		weight = weight.data
		bias = bias.data
		print(bias.shape)

		abs_sum = np.sum(np.abs(weight), axis=0)
		del_filters = np.where(abs_sum < th)[0]
		print(del_filters)
		print(len(del_filters))
		weight = np.delete(weight, del_filters, axis=1)
		self.conv_data[name] = weight, bias

	def prune_copy(self, name, name_, bottom=None):
		"""Prune layer ``name`` by removing the same output filters that were
		removed from layer ``name_``, so both layers keep matching channels.
		When ``bottom`` is given, its deleted filters are also removed from this
		layer's inputs.
		"""
		weight, bias = self._net.params[name]
		weight = weight.data
		bias = bias.data
		origin_channels = weight.shape[0]

		del_filters = self.conv_data[name_][2]
		weight = np.delete(weight, del_filters, axis=0)
		bias = np.delete(bias, del_filters, axis=0)
		# Only look up the bottom when one was given (``conv_data[None]`` would
		# raise KeyError for the default argument).
		if bottom is not None:
			weight = np.delete(weight, self.conv_data[bottom][2], axis=1)
		self.conv_data[name] = weight, bias, del_filters, origin_channels

	def prune_bias(self, name, ind):
		"""Return the bias of layer ``name`` with entries ``ind`` removed.

		Nothing is stored in ``conv_data``.
		"""
		bias = self._net.params[name][1].data
		return np.delete(bias, ind, axis=0)

	def prune_id(self, name, ind, bottom=None, bias_l=None):
		"""Prune explicitly chosen output filters ``ind`` of layer ``name``.

		Args:
			ind: output-filter indices to delete.
			bottom: optional layer whose deleted filters are removed from this
				layer's inputs.
			bias_l: optional iterable of bias vectors added element-wise to the
				pruned bias; each must already have the pruned length.
		"""
		weight, bias = self._net.params[name]
		weight = weight.data
		bias = bias.data
		origin_channels = weight.shape[0]

		del_filters = ind
		weight = np.delete(weight, del_filters, axis=0)
		bias = np.delete(bias, del_filters, axis=0)
		if bias_l is not None:
			for extra in bias_l:
				bias = bias + extra
		if bottom is not None:
			weight = np.delete(weight, self.conv_data[bottom][2], axis=1)
		self.conv_data[name] = weight, bias, del_filters, origin_channels

	def prune_concat(self, name, bottoms, not_del_filters=False, is_delete=False, th=0.0001, delete_range=(192, 256)):
		"""Prune a layer whose input is the channel-wise concat of ``bottoms``.

		Each bottom's deleted filters are shifted by that bottom's original
		channel offset inside the concatenation and removed from this layer's
		inputs.  When ``is_delete`` is true, input channels
		``range(*delete_range)`` are removed as well.
		"""
		offsets = [0]
		for b in bottoms:
			offsets.append(offsets[-1] + self.conv_data[b][3])
		# asarray(int) tolerates plain lists (e.g. from prune_id) and legacy
		# float-typed empty arrays.
		del_filters = [np.asarray(self.conv_data[b][2], dtype=np.intp) + offsets[i] for i, b in enumerate(bottoms)]
		if is_delete:
			del_filters.append(np.arange(*delete_range))
		del_kernels = np.concatenate(del_filters)
		self.conv_data[name] = self._prune(self._net.params[name], del_kernels, not_del_filters=not_del_filters, th=th)

	def fc_concat(self, name, bottoms, spatial=9 * 9, channels=512):
		"""Drop input columns of FC layer ``name`` that is fed by a flattened concat.

		Assumes the FC input is laid out as consecutive blocks of ``channels``
		feature maps with ``spatial`` values each (channel-major); TODO confirm
		against the prototxt.  Columns belonging to channels deleted from
		``bottoms[0]`` are removed from the first block, and the second and
		third blocks are removed entirely.
		"""
		assert ('fc' in name)
		weight, bias = self._net.params[name]
		weight = weight.data
		bias = bias.data

		block = spatial * channels
		cols = [np.arange(spatial * int(c), spatial * (int(c) + 1)) for c in self.conv_data[bottoms[0]][2]]
		cols.append(np.arange(block, 3 * block))
		weight = np.delete(weight, np.concatenate(cols), axis=1)
		self.conv_data[name] = weight, bias

	def save(self, new_model, output_weights):
		"""Copy parameters into the network defined by ``new_model`` and save them.

		Layers present in ``conv_data`` receive their pruned weight and bias.
		Every other layer receives all of its blobs from the source network, so
		layers with a single blob or with more than two blobs are copied
		completely.  ``new_model`` must already declare the pruned shapes.
		"""
		net2 = caffe.Net(new_model, caffe.TEST)
		for key, dst_blobs in net2.params.items():
			if key in self.conv_data:
				# only weight and bias; the other tuple fields are bookkeeping
				src_arrays = self.conv_data[key][:2]
			else:
				src_arrays = [blob.data for blob in self._net.params[key]]
			for dst, src in zip(dst_blobs, src_arrays):
				dst.data[...] = src
		net2.save(output_weights)

	def test(self, net, input_image=" ", output_tensor="fc5_"):
		"""Forward one image through ``net`` up to ``output_tensor`` and return that blob.

		Pixels are scaled as ``(x * 255 - 127.5) * 0.0078125`` (0.0078125 == 1/128).
		Assuming ``caffe.io.load_image`` yields values in [0, 1], this maps them
		to roughly [-1, 1].  The data blob is reshaped to ``(1, 3, 128, 128)``.
		"""
		transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
		transformer.set_transpose('data', (2, 0, 1))  # (H, W, C) -> (C, H, W)
		transformer.set_channel_swap('data', (2, 1, 0))  # presumably RGB -> BGR

		image = caffe.io.load_image(input_image)
		transformed_image = transformer.preprocess('data', (image * 255 - 127.5) * 0.0078125)
		net.blobs['data'].reshape(1, 3, 128, 128)
		net.blobs['data'].data[...] = transformed_image

		output = net.forward(end=output_tensor)
		return output[output_tensor]

	def get_feature_id(self, output_tensor, image_path='./images/5.JPEG', th=7.98):
		"""Find channels of ``output_tensor`` that are nearly inactive on a probe image.

		Returns:
			A one-element list ``[indices]``.  ``indices`` holds the channels whose
			summed absolute activation on ``image_path`` is below ``th``.
		"""
		output = self.test(self._net, image_path, output_tensor)
		channels, h, w = output.shape[1], output.shape[2], output.shape[3]
		# Drops the batch axis; the reshape fails unless the batch size is 1.
		output = output.reshape(channels, h, w)
		l1 = np.abs(output).sum(axis=(1, 2))
		return [np.where(l1 < th)[0].tolist()]


def main():
	"""Prune the SVD test model stage by stage and save the slimmed weights.

	Order matters: a layer's input kernels are pruned against the filters
	already removed from its bottom layer, so every bottom is processed
	before the layers that read it.
	"""
	caffe.set_mode_gpu()
	# Alternative source: ./models/merge/merge_bn.prototxt + merge_bn.caffemodel
	source_net = caffe.Net("./models/svd/TestModel.prototxt", "./models/svd/TestModel.caffemodel", caffe.TEST)
	pruner = Pruner(source_net)

	# Stage 1: three parallel chains conv1_<d>_1 .. conv1_<d>_<d> for d = 1, 2, 3.
	for depth in (1, 2, 3):
		previous = None
		for step in range(1, depth + 1):
			current = "conv1_%d_%d" % (depth, step)
			pruner.prune_conv(current, previous)
			previous = current

	# Stage 2: reads the concat of the three chain outputs, then a plain chain.
	pruner.prune_concat("conv2_1", ("conv1_1_1", "conv1_2_2", "conv1_3_3"))
	for idx in range(2, 7):
		pruner.prune_conv("conv2_%d" % idx, "conv2_%d" % (idx - 1))

	# Stage 3 entry: conv3_1_1 and its "b" twin read the same concat.
	stage2_taps = ("conv2_2", "conv2_4", "conv2_6")
	pruner.prune_concat("conv3_1_1", stage2_taps, is_delete=True)
	pruner.prune_concat("conv3_1_1b", stage2_taps, True, True)

	# Stages 3 and 4: pairs where the first conv loses filters and the second
	# only loses the matching input kernels (its output filters are kept).
	pruner.prune_conv("conv3_1_2", "conv3_1_1", not_del_filters=True)
	for stage, blocks in ((3, range(2, 7)), (4, range(1, 7))):
		for blk in blocks:
			first = "conv%d_%d_1" % (stage, blk)
			pruner.prune_conv(first)
			pruner.prune_conv("conv%d_%d_2" % (stage, blk), first, not_del_filters=True)

	# Stage 5: drop the channels of conv5_1_1b selected by get_feature_id, trim
	# the matching bias entries of conv5_1_1 / conv5_1_2 and add them to it.
	inactive = pruner.get_feature_id(output_tensor='conv5_1_1b')
	print(len(inactive[0]))
	extra_biases = [pruner.prune_bias(layer, inactive[0]) for layer in ("conv5_1_1", "conv5_1_2")]
	pruner.prune_id("conv5_1_1b", inactive[0], bias_l=extra_biases)
	pruner.fc_concat('fc5_svd', ("conv5_1_1b",))

	print([(layer, entry[0].shape[0]) for layer, entry in pruner.conv_data.items() if entry[0] is not None])

	# The target prototxt must already declare the reduced channel counts.
	pruner.save("./models/prune_fc/TestModel.prototxt", "./models/prune_fc/TestModel.caffemodel")


if __name__ == '__main__':
	main()




  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值