# Copyright 2020 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Quantization aware training comprehensive guide
Welcome to the comprehensive guide for Keras quantization aware training.
This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the
API docs.
- If you want to see the benefits of quantization aware training and what’s supported, see the overview.
- For a single end-to-end example, see the quantization aware training example.
The following use cases are covered:
- Deploy a model with 8-bit quantization with these steps.
  - Define a quantization aware model.
  - For Keras HDF5 models only, use special checkpointing and deserialization logic. Training is otherwise standard.
  - Create a quantized model from the quantization aware one.
- Experiment with quantization.
  - Anything for experimentation has no supported path to deployment.
  - Custom Keras layers fall under experimentation.
## Setup
If you are only looking for the APIs you need, you can run this section but skip reading it; it just sets up helpers used in the examples below.
! pip uninstall -y tensorflow
! pip install -q tf-nightly
! pip install -q tensorflow-model-optimization
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
import tempfile
input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)
def setup_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(20, input_shape=input_shape),
      tf.keras.layers.Flatten()
  ])
  return model
def setup_pretrained_weights():
  model = setup_model()
  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )
  model.fit(x_train, y_train)
  _, pretrained_weights = tempfile.mkstemp('.tf')
  model.save_weights(pretrained_weights)
  return pretrained_weights
def setup_pretrained_model():
  model = setup_model()
  pretrained_weights = setup_pretrained_weights()
  model.load_weights(pretrained_weights)
  return model
setup_model()
pretrained_weights = setup_pretrained_weights()
1/1 [==============================] - 0s 1000us/step - loss: 1.1921e-07 - accuracy: 0.0000e+00
## Define quantization aware model
Defining models in the following ways creates supported paths to deployment to the backends listed in the overview page. By default, 8-bit quantization is used.
Note: a quantization aware model is not actually quantized. Creating a quantized model is a separate step.
### Quantize whole model
Your use case:
- Subclassed models are not supported.
Tips for better model accuracy:
- Try “Quantize some layers” to skip quantizing the layers that reduce accuracy the most.
- It’s generally better to finetune with quantization aware training as opposed to training from scratch.
To make the whole model aware of quantization, apply tfmot.quantization.keras.quantize_model to the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
quant_aware_model = tfmot.quantization.keras.quantize_model(base_model)
quant_aware_model.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
quantize_layer (QuantizeLaye (None, 20) 3
_________________________________________________________________
quant_dense_2 (QuantizeWrapp (None, 20) 425
_________________________________________________________________
quant_flatten_2 (QuantizeWra (None, 20) 1
=================================================================
Total params: 429
Trainable params: 420
Non-trainable params: 9
_________________________________________________________________
### Quantize some layers
Quantizing a model can have a negative effect on accuracy. You can selectively quantize layers of a model to explore the trade-off between accuracy, speed, and model size.
Your use case:
- To deploy to a backend that only works well with fully quantized models (e.g. EdgeTPU v1, most DSPs), try “Quantize whole model”.
Tips for better model accuracy:
- It’s generally better to finetune with quantization aware training as opposed to training from scratch.
- Try quantizing the later layers instead of the first layers.
- Avoid quantizing critical layers (e.g. attention mechanism).
In the example below, quantize only the Dense layers.
# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy