数据增强
tensorflow提供了对图片数据进行增强的函数,在小数据量时可以增加模型泛化性:
image_gen_train = tf.keras.preprocessing.image.ImageDataGenerator(
rescale = 所有数据乘以该数值
rotation_range = 随机旋转角度数范围
width_shift_range = 随机宽度偏移量
height_shift_range = 随机高度偏移量
horizontal_flip = 是否随机水平翻转
zoom_range = 随即缩放的范围[1-n, 1+n]
)
## 由于这里输入的x_train需要是四维数据,因此需要把原始数据reshape成28行28列单通道的数据
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
## (60000, 28, 28)->(60000, 28, 28, 1)
image_gen_train.fit(x_train)
## model.fit(x_train, y_train, batch_size=32,...)
model.fit(image_gen_train.flow(x_train, y_train, batch_size=32),...)
## model.fit同步更新为.flow()形式
示例:
断点续训
断点续训可以存储模型
- 读取模型:
# load_weights(路径文件名)
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'): ## 如果模型存在会有该模型对应的索引
print('---------------------load the model----------------------')
model.load_weights(checkpoint_save_path)
- 保存模型:
cp_callback = tf.keras.callbacks.ModelCheckpoint