Tensorflow2.0学习笔记-数据增强,断点续训

数据增强

在小数据模型中,数据增强可以起到明显的效果,本次使用的是mnist数据集单靠准确率去证明数据增强的效果是不可行的,需要自己在实际运用中体会。

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 给数据增加一个维度,从(60000, 28, 28)reshape为(60000, 28, 28, 1)

image_gen_train = ImageDataGenerator(
    rescale=1. / 1.,  # 所有数据将乘以该数值,如为图像,分母为255时,可归至0~1
    rotation_range=45,  # 随机旋转角度范围。随机45度旋转
    width_shift_range=.15,  # 随机宽度偏移量
    height_shift_range=.15,  # 随机高度偏移
    horizontal_flip=False,  # 是否随机水平翻转
    zoom_range=0.5  # 调整缩放范围。将图像随机缩放阈量50%
)
image_gen_train.fit(x_train)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(image_gen_train.flow(x_train, y_train, batch_size=32),
          # 将输入的x_train, y_train和数据打乱输入,但是二者之间数据的一一对应性不变
          epochs=5, validation_data=(x_test, y_test),
          validation_freq=1)
model.summary()


断点续训

断点续训可以接着之前的训练模型进行训练。代码如下:

import tensorflow as tf
import os

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    # 读取模型函数
    model.load_weights(checkpoint_save_path)
# 保存模型的函数
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,  # 保存的路径名称
                                                 save_weights_only=True,  #  是否只保存模型参数
                                                 save_best_only=True)     #  是否只保存最优结果  

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

运行成功后,会在工程文件夹下生成checkpoint文件。
在这里插入图片描述
再次运行程序会得到相应的输出。

G:\anaconda\envs\tensorflow-2.0\python.exe G:/Pycharmprojects/tf2_notes-master/class4/MNIST_FC/p16_mnist_train_ex3.py
2021-01-07 17:16:46.259132: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll
-------------load the model-----------------
......
#  因为是之前训练好的模型所以得到的准确率先由大到小,再由小到大
   32/60000 [..............................] - ETA: 47:16 - loss: 0.0069 - sparse_categorical_accuracy: 1.0000
  352/60000 [..............................] - ETA: 4:25 - loss: 0.0603 - sparse_categorical_accuracy: 0.9830 
  672/60000 [..............................] - ETA: 2:23 - loss: 0.0491 - sparse_categorical_accuracy: 0.9851
  992/60000 [..............................] - ETA: 1:39 - loss: 0.0422 - sparse_categorical_accuracy: 0.9879
 1280/60000 [..............................] - ETA: 1:19 - loss: 0.0395 - sparse_categorical_accuracy: 0.9883
 1568/60000 [..............................] - ETA: 1:06 - loss: 0.0434 - sparse_categorical_accuracy: 0.9872
 1888/60000 [..............................] - ETA: 56s - loss: 0.0443 - sparse_categorical_accuracy: 0.9857 
  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值