提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
前言
图像增强—图像的简单形变,应对因拍照角度不同引起的图片形变。
x_train.shape(数目,长,宽,通道)
一.数据增强使用
代码如下:
# 1.导入模块---import
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 2.指定数据集和训练集---(x_train,y_train)
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255, x_test/255
# 给数据增加一个维度,使数据和网络结构匹配
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1) # x_train.shape:(60000,28,28)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
# 数据增强函数参数设置----数据增强函数输入为4维,升维
image_gen_train = ImageDataGenerator(rescale=1./1,
rotation_range=45,
width_shift_range=.15,
height_shift_range=.15,
horizontal_flip=True,
zoom_range=0.5
)
# 输入数据(图像像素)进行数据增强
image_gen_train.fit(x_train)
# 3.搭建网络模型----class
class MnistModel(Model):
def __init__(self):
super(MnistModel, self).__init__()
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
def call(self, x):
x = self.flatten(x)
x = self.d1(x)
y = self.d2(x)
return y
model = MnistModel()
# 4。为网络模型配置训练方法----compile
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy']
)
# 5.网络模型传入数据集开始训练---fit
model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5,
validation_data=(x_test, y_test),
validation_freq=1)
# 6.打印网络莫模型和结构----summary
model.summary()