1. 背景
这里的任务是要在TF框架下运用现在已经较为成熟的剪枝算法,由于TF中对于可训练参数的管理与Pytorch、Caffe并不相同,这篇文章对其TF实现卷积参数剪枝进行了探究。
2. TF中可训练参数的存储
TF中是将网络构建为一个计算图,在一个计算图中可以通过集合(collection
)来管理不同类别的资源。比如通过tf.add_to_collection
函数可以将资源加入一个或多个集合中,然后通过tf.get_collection
获取一个集合里面的所有资源(如张量,变量,或者运行TensorFlow程序所需的队列资源等等)。下表就是TF中的一些集合及其释义:
集合名称 | 集合内容 | 使用场景 |
---|---|---|
tf.GraphKeys.VARIABLES |
所有变量 | 持久化 TensorFlow 模型 |
tf.GraphKeys.TRAINABLE_VARIABLES |
可学习的变量(一般指神经网络中的参数) | 模型训练、生成模型可视化内容 |
tf.GraphKeys.SUMMARIES |
日志生成相关的张量 | TensorFlow 计算可视化 |
tf.GraphKeys.QUEUE_RUNNERS |
处理输入的 | QueueRunner |
tf.GraphKeys.MOVING_AVERAGE_VARIABLES |
所有计算了滑动平均值的变量</ |