tf.keras搭建网络八股:
六步法：0 → 1 → 3 → 4 → 6 → 7
其余步骤（2. 数据增强、5. 断点续训）属于扩展操作，可按需添加。
目录
3. model = tf.keras.models.Sequential #搭建网络结构
0. import ....#导入相关的库
1. train, test #训练数据，测试数据
2. ImageDataGenerator #数据增强
# Data augmentation (generic template): random geometric transforms applied
# on the fly to each training batch.
image_gen_train = ImageDataGenerator(
    rotation_range=45,        # random rotation, degree range 45
    width_shift_range=0.15,   # random horizontal shift (fraction of width)
    height_shift_range=0.15,  # random vertical shift (fraction of height)
    zoom_range=0.5,           # random zoom by up to 50%
    horizontal_flip=True,     # random left-right mirroring
    rescale=1.0,              # pixel multiplier; 1.0 leaves values unchanged
)
# fit() computes dataset statistics; it is only required for the
# featurewise/ZCA options, but is harmless here.
image_gen_train.fit(x_train)
3. model = tf.keras.models.Sequential #搭建网络结构
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(), #拉直层
tf.keras.layers.Dense(神经元个数, activation= '激活函数',kernel_regularizer = 哪种正则化), #全连接层
#activation 可选relu, softmax, sigmoid, tanh
#kernel_regularizer 可选：tf.keras.regularizers.l1(), tf.keras.regularizers.l2()
tf.keras.layers.Conv2D(filters= 卷积核个数, kernel_size = 卷积核尺寸, strides =卷积步长,
padding = 'valid' or 'same' ), #卷积层
tf.keras.layers.MaxPool2D(2, 2), #池化层
tf.keras.layers.LSTM(记忆体个数) #LSTM层
])
4. model.compile #配置环境
model.compile(optimizer = 优化器, loss = 损失函数, metrics = ['准确率'])
#optimizer可选：'sgd' or tf.keras.optimizers.SGD(learning_rate = 学习率, momentum = 动量参数)
# 'adagrad' or tf.keras.optimizers.Adagrad(learning_rate = 学习率)
# 'adadelta' or tf.keras.optimizers.Adadelta(learning_rate = 学习率)
# 'adam' or tf.keras.optimizers.Adam(learning_rate = 学习率, beta_1 = 0.9, beta_2 = 0.999)
# loss可选: 'mse' or tf.keras.losses.MeanSquaredError()
# 'sparse_categorical_crossentropy' or tf.keras.losses.SparseCategoricalCrossentropy(from_logits = False)（输出层已经过 softmax 时取 False）
# metrics 可选：'accuracy'：y_ 和 y 都是数值
# 'categorical_accuracy'：y_ 和 y 都是独热码（概率分布）
# 'sparse_categorical_accuracy'：y_ 是数值，y 是独热码（概率分布）
5. 断点续训
# Resume training from a checkpoint: load previously saved weights if present.
checkpoint_save_path = './checkpoint/fashion_mnist.ckpt'
# TF-format weight checkpoints write a "<path>.index" file, so its existence
# signals that an earlier run saved weights at this path.
if os.path.exists(checkpoint_save_path + '.index'):
    print('----------------------------load the model--------------------------------')
    model.load_weights(checkpoint_save_path)
# Save only the weights, and only when the monitored metric improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)
# 4. Feed the training and test data.
# An ImageDataGenerator instance is not callable; .flow() returns an iterator
# of augmented (x, y) batches. NOTE: .flow() expects x_train of rank 4
# (N, H, W, C) - reshape grayscale images first.
history = model.fit(
    image_gen_train.flow(x_train, y_train, batch_size=32),
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)
6. model.fit #喂入数据
model.fit(训练集的输入特征, 训练集标签, batch_size = ..., epochs = ...,
validation_data = (测试集的输入特征,测试集标签),
validation_split = 从训练集中划分多少比例作为验证集（与 validation_data 二选一）,
validation_freq = 多少次epochs测试一次)
7. model.summary #打印网络结构和参数统计
8.示例代码
# 0.导入模块
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 1. Load the training and test sets.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values into [0, 1] (load_data returns 0-255 integers).
x_train, x_test = x_train / 255.0, x_test / 255.0
# Data augmentation.
# ImageDataGenerator.fit()/.flow() expect rank-4 input (N, H, W, C), so add a
# single channel axis to the grayscale training images.
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
image_gen_train = ImageDataGenerator(
    rescale=1.0 / 1.0,        # identity: pixels were already scaled above
    rotation_range=45,        # random rotation, degree range 45
    width_shift_range=0.15,   # random horizontal shift (fraction of width)
    height_shift_range=0.15,  # random vertical shift (fraction of height)
    # Mirrored digits (2, 3, 5, 7, ...) are not valid digits, so flipping
    # would feed mislabeled samples to the model; keep it off for MNIST.
    horizontal_flip=False,
    zoom_range=0.5,           # random zoom by up to 50%
)
image_gen_train.fit(x_train)
# 2. Define the network model.
class MnistModel(Model):
    """Ten-class classifier: Flatten -> Dense(128, relu) -> Dense(10, softmax).

    NOTE: the attribute names ``flatten``, ``d1`` and ``d2`` are kept as-is;
    TF-format weight checkpoints of subclassed models are keyed by attribute
    name, so renaming them would likely break loading existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        """Forward pass: return per-class softmax probabilities for batch *x*."""
        hidden = self.d1(self.flatten(x))
        return self.d2(hidden)
model = MnistModel()
# 3. Configure training: optimizer, loss and metrics.
# from_logits=False because MnistModel's last layer already applies softmax.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)
# Resume training: load previously saved weights if a checkpoint exists.
checkpoint_save_path = './checkpoint/mnist.ckpt'
# TF-format weight checkpoints write a "<path>.index" file next to the data.
if os.path.exists(checkpoint_save_path + '.index'):
    print('----------------------------load the model--------------------------------')
    model.load_weights(checkpoint_save_path)
# Save only the weights, and only when the monitored metric improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)
# 4. Feed the training and test data.
# An ImageDataGenerator instance is not callable; .flow() returns an iterator
# of augmented (x, y) batches built from the rank-4 x_train.
history = model.fit(
    image_gen_train.flow(x_train, y_train, batch_size=32),
    epochs=5,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)
# 5. Print the model structure and parameter counts.
model.summary()
# ---------------- Plot the loss and accuracy curves ----------------
# The val_* keys are present because fit() above received validation_data.
curves = history.history
loss = curves['loss']
val_loss = curves['val_loss']
acc = curves['sparse_categorical_accuracy']
val_acc = curves['val_sparse_categorical_accuracy']
# One panel per metric: (train series, train label, val series, val label, title).
panels = (
    (acc, 'training accuracy', val_acc, 'validation accuracy',
     'training and validation accuracy curve'),
    (loss, 'training loss', val_loss, 'validation loss',
     'training and validation loss curve'),
)
for position, (train_curve, train_label, val_curve, val_label, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(train_curve, label=train_label)
    plt.plot(val_curve, label=val_label)
    plt.title(title)
    plt.legend()
plt.show()
# ------------------------------------------ Prediction -----------------------
def img_trans(img_path):
    """Load an image file and convert it into a 28x28 MNIST-style input array.

    The image is resized to 28x28, converted to 8-bit grayscale and binarized.
    Pixels darker than 200 are treated as ink and become 1.0; all other pixels
    become 0.0. This produces a bright digit on a dark background, like the
    MNIST training data. It assumes the photo shows dark ink on light paper.

    Args:
        img_path: path of the image file to read.

    Returns:
        A float64 numpy array of shape (28, 28) containing only 0.0 and 1.0.
    """
    # `with` closes the file handle; resize() loads the pixels before that.
    with Image.open(img_path) as src:
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        resized = src.resize((28, 28), Image.LANCZOS)
    gray = np.array(resized.convert('L'))
    # A single inversion (dark -> 1.0, light -> 0.0). Applying an extra
    # `255 - x` on top would undo it and leave a dark digit on a bright
    # background, the opposite of what the model was trained on.
    return np.where(gray < 200, 1.0, 0.0)
# Interactively classify user-supplied image files.
num = int(input('the number of images: '))
for _ in range(num):
    img_path = input('the path of image:')
    img_arr = img_trans(img_path)
    # Add a leading batch axis: (28, 28) -> (1, 28, 28).
    x_predict = np.expand_dims(img_arr, axis=0)
    result = model.predict(x_predict)
    print(result)
    # Predicted class = index of the highest probability in each row.
    pred = tf.argmax(result, axis=1)
    print('success predict\n')
    print(pred)
拓展:
Fashion-MNIST 数据集：只需把数据导入部分替换为下面的代码，其余步骤不变。
# Fashion-MNIST has the same 28x28 grayscale format as MNIST; only the
# dataset module differs. Pixel values are scaled into [0, 1].
fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0