神经网络基础模型功能扩展

数据增强

tensorflow提供了对图片数据进行增强的函数,在小数据量时可以增加模型泛化性:

image_gen_train = tf.keras.preprocessing.image.ImageDataGenerator(
	rescale = 所有数据乘以该数值
	rotation_range = 随机旋转角度数范围
	width_shift_range = 随机宽度偏移量
	height_shift_range = 随机高度偏移量
	horizontal_flip = 是否随机水平翻转
	zoom_range = 随即缩放的范围[1-n, 1+n]
)
## 由于这里输入的x_train需要是四维数据,因此需要把原始数据reshape成28行28列单通道的数据
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
## (60000, 28, 28)->(60000, 28, 28, 1)
image_gen_train.fit(x_train)

## model.fit(x_train, y_train, batch_size=32,...)
model.fit(image_gen_train.flow(x_train, y_train, batch_size=32),...)
## model.fit同步更新为.flow()形式

示例:
在这里插入图片描述

断点续训

断点续训可以存储模型

  • 读取模型:
# load_weights(路径文件名)
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):  ## 如果模型存在会有该模型对应的索引
	print('---------------------load the model----------------------')
	model.load_weights(checkpoint_save_path)
  • 保存模型:
cp_callback = tf.keras.callbacks.ModelCheckpoint
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值