Trained Ternary Quantization
ICLR 2017
https://github.com/TropComplique/trained-ternary-quantization (PyTorch)
https://github.com/buaabai/Ternary-Weights-Network (PyTorch)
Conventional binary networks quantize the weights W to {+1, −1}. Ternary weight networks (TWN) instead quantize the weights W to {−W_l, 0, +W_l}: weights whose magnitude falls below a per-layer threshold become zero. TWN picks this threshold with a simple heuristic, Δ_l ≈ 0.7 · E(|W_l|), i.e. roughly 0.7 times the mean absolute weight of the layer, and sets W_l to the mean magnitude of the weights above the threshold.
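A minimal PyTorch sketch of this TWN-style quantization, assuming the threshold heuristic Δ ≈ 0.7 · E(|W|) and taking W_l as the mean magnitude of the above-threshold weights; the function name `twn_ternarize` is illustrative, not from the repositories above.

```python
import torch

def twn_ternarize(w: torch.Tensor) -> torch.Tensor:
    """Quantize a full-precision weight tensor to {-W_l, 0, +W_l} (TWN-style sketch)."""
    delta = 0.7 * w.abs().mean()                             # per-layer threshold heuristic
    mask = (w.abs() > delta).float()                         # 1 where the weight stays non-zero
    w_l = (w.abs() * mask).sum() / mask.sum().clamp(min=1)   # scaling factor W_l
    return w_l * torch.sign(w) * mask                        # ternary weights {-W_l, 0, +W_l}
```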
This paper proposes a new ternary network, Trained Ternary Quantization (TTQ): the weights are quantized to three values {−W_l^n, 0, +W_l^p}, and the positive and negative scaling factors W_l^p and W_l^n are learned during training rather than computed from weight statistics.
The corresponding gradient computation: each learned scale receives the accumulated gradient of the quantized weights in its region (∂L/∂W_l^p is summed over I_l^p, and ∂L/∂W_l^n over I_l^n), while the gradient passed back to a latent full-precision weight w̃_l is the quantized-weight gradient scaled by W_l^p when w̃_l > Δ_l, by 1 when |w̃_l| ≤ Δ_l, and by W_l^n when w̃_l < −Δ_l.
The threshold is chosen relative to the largest weight in each layer, Δ_l = t · max(|w̃_l|), and t is set to 0.05 in the experiments on the CIFAR-10 and ImageNet datasets.
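A hedged sketch of the TTQ quantization step under these choices: `w_p` and `w_n` stand for the learned per-layer scaling factors (kept as trainable tensors elsewhere), and the threshold is t · max(|w̃|). The names are illustrative and do not come from the repositories above.

```python
import torch

def ttq_quantize(w_fp: torch.Tensor, w_p: torch.Tensor, w_n: torch.Tensor,
                 t: float = 0.05) -> torch.Tensor:
    """Quantize latent full-precision weights to {-W_n, 0, +W_p}."""
    delta = t * w_fp.abs().max()            # threshold relative to the largest weight
    pos = (w_fp > delta).float()            # weights quantized to +W_p
    neg = (w_fp < -delta).float()           # weights quantized to -W_n
    return w_p * pos - w_n * neg            # everything in between becomes 0
```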
The quantization roughly proceeds as follows (a PyTorch sketch of this loop is given after the list):

- Train a model of your choice as usual (or take a trained model).
- Copy all full precision weights that you want to quantize. Then do the initial quantization: in the model, replace them by ternary values {-1, 0, +1} using some heuristic.
- Repeat until convergence:
  1. Make the forward pass with the quantized model.
  2. Compute gradients for the quantized model.
  3. Preprocess the gradients and apply them to the copy of full precision weights.
  4. Requantize the model using the changed full precision weights.
- Throw away the copy of full precision weights and use the quantized model.
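A minimal single-layer sketch of this loop, assuming a softmax classifier, a `loader` that yields (inputs, labels) batches, and the scaled straight-through update described above; learning rates, shapes, and names are placeholders, not the exact API of either repository.

```python
import torch
import torch.nn.functional as F

w_fp = 0.1 * torch.randn(10, 784)                # full-precision copy of one layer's weights
w_p = torch.tensor(1.0, requires_grad=True)      # learned positive scaling factor
w_n = torch.tensor(1.0, requires_grad=True)      # learned negative scaling factor
scale_opt = torch.optim.SGD([w_p, w_n], lr=1e-2)
lr_w, t = 1e-2, 0.05

for x, y in loader:                              # `loader` yields (inputs, labels), assumed given
    # requantize with the current full-precision copy, then do the forward pass
    delta = t * w_fp.abs().max()
    pos = (w_fp > delta).float()                 # region quantized to +W_p
    neg = (w_fp < -delta).float()                # region quantized to -W_n
    w_t = w_p * pos - w_n * neg                  # ternary weights used in the forward pass
    w_t.retain_grad()                            # keep dL/dw_t for the full-precision update
    loss = F.cross_entropy(x @ w_t.t(), y)

    # compute gradients for the quantized model (fills w_p.grad, w_n.grad, w_t.grad)
    scale_opt.zero_grad()
    loss.backward()

    # preprocess dL/dw_t (scale it per region) and apply it to the full-precision copy
    scale = w_p.detach() * pos + w_n.detach() * neg + (1.0 - pos - neg)
    with torch.no_grad():
        w_fp -= lr_w * scale * w_t.grad
    scale_opt.step()                             # update the learned scaling factors

# after training, requantize once more and discard the full-precision copy
delta = t * w_fp.abs().max()
w_ternary = (w_p * (w_fp > delta).float() - w_n * (w_fp < -delta).float()).detach()
```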