tensorflow2.x的模型输入默认格式是NHWC格式,但是某些应用场景下需要在模型中将NHWC输入格式转换为输入NCHW格式,操作代码如下
import tensorflow as tf
model_path = "./xxx.h5"
output_path = "./yyy.h5"
model.load_model(model_path) # 当前输入尺寸是128*128
new_model = tf.keras.models.Sequential([Input((3, 128, 128)), tf.keras.layers.Lambda(lambda x: tf.transpose(x,[0,2,3,1])), model])
new_model.summary()
new_model.save(output_path)