1.首先安装tensorflow_model_optimization
pip install tensorflow_model_optimization
2.其次代码中添加两句即可:
import tensorflow_model_optimization as tfmot

# Build the recognizer from the CLI options, then print its architecture.
model_kwargs = dict(
    image_height=args.image_height,
    image_width=args.image_width,
    channels=args.channels,
    nb_blocks=1,
    filters=64,
    hidden_nums=128,
    num_classes=args.num_classes,
    rnn_cell=args.rnn_cell,
    rd_block=args.reduce_block,
    dynamic=args.dynamic,
)
model = build_model(**model_kwargs)
model.summary()
if args.qunt:
    print("quantized model:")
    ######################### quantization-aware training setup #########################
    # FIX: the original used tfmot.quantization.keras.quantize_annotate_model,
    # which only *annotates* layers and never applies quantization -- the model
    # trained below would not have been quantization-aware.  quantize_model()
    # annotates AND applies in a single call.
    quantize_model = tfmot.quantization.keras.quantize_model
    quant_model = quantize_model(model)
    quant_model.compile(optimizer=keras.optimizers.Adam(args.learning_rate),
                        loss=CTCLoss(), metrics=[WordAccuracy()])
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            "densenet_crnn_qunt_{epoch}.h5",
            # NOTE(review): metrics=[WordAccuracy()] -- confirm the metric is
            # actually reported as 'val_accuracy', otherwise this monitor
            # silently never fires.
            monitor='val_accuracy',
            save_best_only=False,
            verbose=2,
            save_weights_only=False,
            period=1),  # 'period' is deprecated; newer TF prefers save_freq='epoch'
        tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                         patience=5,
                                         mode='auto',
                                         restore_best_weights=True),
        LearningRateLogger(),
        tf.keras.callbacks.TensorBoard(log_dir='logs/', histogram_freq=0),
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                             patience=5, min_lr=0.00000001,
                                             verbose=2),
    ]
    quant_model.fit(train_ds,
                    epochs=args.epochs,
                    validation_data=val_ds,
                    callbacks=callbacks)
3.两种感知量化方式对比:
if args.qunt:
    print("quantized model:")
    ############## QAT approach 1: quantize the whole model ##############
    if args.qunt_type == 0:
        # FIX: quantize_annotate_model only annotates layers; without a
        # subsequent quantize_apply() the trained model is not quantized.
        # quantize_model() annotates AND applies in one call.
        # NOTE(review): if the graph contains layers tfmot cannot quantize
        # (e.g. GRU), this call raises -- approach 2 below handles that case.
        quant_model = tfmot.quantization.keras.quantize_model(model)
        qt = 'qunt'
        quant_model.summary()
    ############## QAT approach 2: skip unsupported layers ##############
    else:
        adv_model = model
        import tensorflow_model_optimization as tfmot
        from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs
        # FIX: `tf.python.keras...` attribute access is not part of the public
        # TF2 namespace (AttributeError); use public layer aliases and an
        # explicit import for the op-layer class instead.
        from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer

        NoOpQuantizeConfig = default_8bit_quantize_configs.NoOpQuantizeConfig

        def apply_quantization(layer):
            """Annotate `layer` for QAT; unsupported types get a no-op config."""
            ####### layer types that tfmot cannot convert go here #####
            if isinstance(layer, (tf.keras.layers.BatchNormalization,
                                  TensorFlowOpLayer,
                                  tf.keras.layers.Concatenate,
                                  tf.keras.layers.GRU)):
                return tfmot.quantization.keras.quantize_annotate_layer(
                    layer, quantize_config=NoOpQuantizeConfig())
            return tfmot.quantization.keras.quantize_annotate_layer(layer)

        # Rebuild the model layer-by-layer, annotating each layer on the way.
        annotated_adv_model = tf.keras.models.clone_model(
            adv_model, clone_function=apply_quantization,
        )
        # quantize_model needs to deserialize NoOpQuantizeConfig, so register it.
        with tf.keras.utils.custom_object_scope({"NoOpQuantizeConfig": NoOpQuantizeConfig}):
            quant_model = tfmot.quantization.keras.quantize_model(annotated_adv_model)
            # To quantize only specific layers as in the previous example,
            # tfmot.quantization.keras.quantize_apply(annotated_adv_model) also works.
        qt = 'qunt_noOp'
        quant_model.summary()

    quant_model.compile(optimizer=keras.optimizers.Adam(args.learning_rate),
                        loss=CTCLoss(), metrics=[WordAccuracy()])
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(args.save_path + args.train_path[0].split('-')[0] + '/',
                         "densenet_crnn_" + qt + "_{epoch}.h5"),
            # NOTE(review): metrics=[WordAccuracy()] -- confirm the metric is
            # actually reported as 'val_accuracy'.
            monitor='val_accuracy',
            save_best_only=False,
            verbose=2,
            save_weights_only=False,
            period=1),  # 'period' is deprecated; newer TF prefers save_freq='epoch'
        tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                         patience=10,
                                         mode='auto',
                                         restore_best_weights=True),
        LearningRateLogger(),
        tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join('/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow_2/logs/',
                                 args.train_path[0].split('-')[0]) + '_' + qt,
            histogram_freq=0),
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                             patience=5, min_lr=0.00000001,
                                             verbose=2),
    ]
    quant_model.fit(train_ds,
                    epochs=args.epochs,
                    validation_data=val_ds,
                    callbacks=callbacks)
4.加载模型一定要添加with tfmot.quantization.keras.quantize_scope(),否则会报错:
ValueError: Unknown layer: QuantizeLayer
# FIX: the load call must be indented inside the `with` block (the pasted
# version was a syntax error).  quantize_scope() registers tfmot's custom
# objects (QuantizeLayer, QuantizeWrapper, ...) so deserialization succeeds;
# without it load_model raises "ValueError: Unknown layer: QuantizeLayer".
with tfmot.quantization.keras.quantize_scope():
    loaded_model = tf.keras.models.load_model(keras_file)
5.h5转换成tflite并进行预测,需要tensorflow版本为2.3及以上,否则会报错:
ValueError: Unknown layer: Functional
ValueError: Did not get operators, tensors, or buffers in subgraph 1