There are relatively few TF examples of pruning and quantization around; since I happen to be working in this area, I'm porting over some applications from the official docs.
The code below combines the official MNIST example with the guide to see how the pruning step of optimization is done in TF's API.
tensorflow/model-optimization--comprehensive_guide
The overall approach: build a baseline model → add pruning → compare changes in model size, accuracy, etc.
Along the way, the focus is on how to customize your own pruning case and follow up with quantization.
Contents
1. Import the dependencies; tensorboard doesn't seem to be used later, so it is commented out for now
2. Load the MNIST dataset and apply simple normalization
3. Build a baseline model and save its weights for later performance comparisons
4. Apply prune_low_magnitude to the whole model to build a pruned model, and compare the model before and after
5. Apply prune_low_magnitude to a selected layer (here the Dense layer) to build a pruned model, and look at how the model changes
import tempfile
import os
import zipfile
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
from tensorflow import keras
#%load_ext tensorboard
1. Import the dependencies; tensorboard doesn't seem to be used later, so it is commented out for now
# Load the MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the image pixel values to [0, 1]
train_images = train_images / 255.0
test_images = test_images / 255.0
2. Load the MNIST dataset and apply simple normalization
# Build the model
def setup_model():
    model = keras.Sequential([
        keras.layers.InputLayer(input_shape=(28, 28)),
        keras.layers.Reshape(target_shape=(28, 28, 1)),
        keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(10)
    ])
    return model
# Train the classifier and save its pretrained weights
def setup_pretrained_weights():
    model = setup_model()
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    model.fit(train_images,
              train_labels,
              epochs=4,
              validation_split=0.1)
    _, pretrained_weights = tempfile.mkstemp('.tf')
    model.save_weights(pretrained_weights)
    return pretrained_weights
3. Build a baseline model and save its weights for later performance comparisons
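As a quick sanity check on the model defined above, its parameter count can be worked out by hand from the layer shapes (a pure-Python sketch; the helper names here are mine, not from the guide):

```python
def conv2d_params(k_h, k_w, in_ch, filters):
    # kernel weights plus one bias per filter
    return k_h * k_w * in_ch * filters + filters

def dense_params(in_units, out_units):
    # weight matrix plus one bias per output unit
    return in_units * out_units + out_units

conv = conv2d_params(3, 3, 1, 12)   # Conv2D on a 1-channel input: 120 params
# 28x28x1 -> 3x3 'valid' conv -> 26x26x12 -> 2x2 max pool -> 13x13x12
flat = 13 * 13 * 12                 # Flatten: 2028 units
dense = dense_params(flat, 10)      # Dense(10): 20290 params
total = conv + dense                # 20410 in total
```

This matches the total the `summary()` calls below report, and is a useful reference point when the pruning wrappers later add their own non-trainable mask and threshold variables.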
setup_model()
pretrained_weights = setup_pretrained_weights()
#
Train on 54000 samples, validate on 6000 samples
Epoch 1/4
54000/54000 [==============================] - 7s 133us/sample - loss: 0.2895 - accuracy: 0.9195 - val_loss: 0.1172 - val_accuracy: 0.9685
Epoch 2/4
54000/54000 [==============================] - 5s 99us/sample - loss: 0.1119 - accuracy: 0.9678 - val_loss: 0.0866 - val_accuracy: 0.9758
Epoch 3/4
54000/54000 [==============================] - 5s 100us/sample - loss: 0.0819 - accuracy: 0.9753 - val_loss: 0.0757 - val_accuracy: 0.9787
Epoch 4/4
54000/54000 [==============================] - 6s 103us/sample - loss: 0.0678 - accuracy: 0.9797 - val_loss: 0.0714 - val_accuracy: 0.9815
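To compare model sizes later, it helps to have a helper that measures the compressed size of a saved weights file; pruning zeroes out weights, so a pruned model should compress noticeably better. A minimal sketch using only the already-imported stdlib modules (the name `get_gzipped_model_size` follows the official example's helper; treat the exact shape as an assumption):

```python
import os
import tempfile
import zipfile

def get_gzipped_model_size(file_path):
    # Zip the saved weights file and return the compressed size in bytes.
    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file_path)
    return os.path.getsize(zipped_file)
```

Calling this on `pretrained_weights` now, and again on the pruned model's saved weights later, gives the size comparison the intro mentions.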
4. Apply prune_low_magnitude to the whole model to build a pruned model, and compare the model before and after
# Compare the baseline model with the pruned model
base_model = setup_model()
base_model.summary()
base_model.load_weights(pretrained_weights)
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
model_for_pruning.summary()
#
Model: "sequential_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
reshape_4 (Reshape) (None, 28, 28, 1) 0
_________________________________________________________________
conv2d_4 (Conv2D)
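On the "customize your own pruning case" point from the intro: instead of the default constant schedule, `prune_low_magnitude` also accepts a `pruning_schedule` argument such as `tfmot.sparsity.keras.PolynomialDecay`, which ramps sparsity up during training. A pure-Python sketch of what that schedule computes (my reading of the default `power=3` cubic decay, so treat the exact formula as an assumption):

```python
def polynomial_decay_sparsity(step, initial_sparsity=0.5, final_sparsity=0.8,
                              begin_step=0, end_step=1000, power=3):
    # Sparsity ramps from initial_sparsity at begin_step to final_sparsity
    # at end_step, following a (1 - progress)**power decay curve.
    progress = min(max((step - begin_step) / (end_step - begin_step), 0.0), 1.0)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power

# The wrapped call would then look roughly like (parameter values are examples):
# pruning_params = {
#     'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
#         initial_sparsity=0.5, final_sparsity=0.8,
#         begin_step=0, end_step=1000)
# }
# model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)
```

Ramping up sparsity gradually tends to hurt accuracy less than pruning everything at once, since the remaining weights get retrained after each pruning step.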