Tensorflow2.0 keras模块 实现mnist数据集识别

23 篇文章 0 订阅
20 篇文章 0 订阅

tf.keras搭建网络八股:

六步法: 0 -》1-》3-》4-》6-》7

其他的属于扩展操作,例如数据增强,断点续训

目录

0. import  ....#导入相关的库

1. train, test  #测试数据,训练数据

2. ImageDataGenerator #数据增强

3. model = tf.keras.model.Sequential  #搭建网络结构

4. model.compile   #配置环境

5. 断点续训

6. model.fit  #喂入数据

7. model.summary  #打印参数

8.示例代码

0. import  ....#导入相关的库

1. train, test  #测试数据,训练数据

2. ImageDataGenerator #数据增强

# 数据增强
image_gen_train = ImageDataGenerator (
    rescale =1.0/1.0,   #
    rotation_range = 45, #随机旋转45度
    width_shift_range = 0.15, #宽度偏移
    height_shift_range = 0.15, #高度偏移
    horizontal_flip = True , #水平翻转
    zoom_range = 0.5# 缩放0.5
)
image_gen_train.fit(x_train)

3. model = tf.keras.model.Sequential  #搭建网络结构

model =tf.keras.model.Sequential([
    tf.keras.layers.Flatten(),  #拉直层
    tf.keras.layers.Dense(神经元个数, activation= '激活函数',kernel_regularizer = 哪种正则化), #全连接层
    #activation 可选relu, softmax, sigmoid, tanh
    #kernel_regularizer 可选:tf.keras.regularizer.l1(), tf.keras.regularizer.l2()
    tf.keras.layers.Conv2D(filters= 卷积核个数, kernel_size = 卷积核尺寸, strides =卷积步长,
    padding = 'valid' or 'same' ),  #卷积层
    MaxPool2D(2,2), #池化层
    tf.keras.layers.LSTM()    #LSTM层
])

4. model.compile   #配置环境

model.compile(optimizer =优化器, loss = 损失函数, metrices = ['准确率'])
#optimizer可选:'sgd' or tf.keras.optimizers.SGD(lr= 学习率, momentum = 动量参数)
# 'adagrad' or tf.keras.optimizers.Adagrad(lr=学习率)
# 'adadelta' or tf.keras.optimizers.Adadelta(lr =学习率)
# 'adm' or tf.keras.optimizers.Adam(lr = 学习率, beta_2 = 0.9, beta_2 = 0.999)
# loss可选: 'mse' or tf.keras.losses.MeanSquaredError()
# 'sparse_categorical_crossentropy' or tf.keras.losses.SparseCategoricalCrossentropy (from logits = False)
# metrices 可选:'accuracy':y_和y都是数值
# 'categorical_accuracy':y_和y都是独热编码
# 'sparse_categorical_accuracy' y_数值,y是独热编码

5. 断点续训

#断点续存之读取模型
checkpoint_save_path = './checkpoint/fashion_mnist.ckpt'
if os.path.exists(checkpoint_save_path +'.index'):
    print('----------------------------load the model--------------------------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath =checkpoint_save_path,
    save_weights_only= True,
    save_best_only= True)

#4. 喂入训练书记和测试数据
history = model.fit(image_gen_train(x_train, y_train, batch_size = 32), epochs =5, 
    validation_data = (x_test, y_test), validation_freq = 1,
    callbacks = [cp_callback])

6. model.fit  #喂入数据

model.fit(训练集的输入特征, 训练集标签, batch_size = ..., epochs = ..., 
validation_data = (测试集的输入特征,测试集标签),
validation_split = 从训练集划分多少比例给测试集,
validation_freq = 多少次epochs测试一次)

7. model.summary  #打印参数

8.示例代码

# 0.导入模块
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

#1. 导入训练集、测试集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255.0, x_test/255.0

# 数据增强
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
image_gen_train = ImageDataGenerator (
    rescale =1.0/1.0,   #
    rotation_range = 45, #随机旋转45度
    width_shift_range = 0.15, #宽度偏移
    height_shift_range = 0.15, #高度偏移
    horizontal_flip = True , #水平翻转
    zoom_range = 0.5# 缩放0.5
)
image_gen_train.fit(x_train)

#2. 定义网络模型
class MnistModel(Model):
    def __init__(self):#初始化
        super(MnistModel, self).__init__()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation = 'relu')
        self.d2 = Dense(10, activation = 'softmax')
    #定义前向传播过程
    def call (self, x):
        x = self.flatten(x)
        x = self.d1(x) 
        y = self.d2(x)
        return y

model = MnistModel()

#3. 配置训练参数
model.compile(optimizer ='adam',
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics = ['sparse_categorical_accuracy'])

#断点续存之读取模型
checkpoint_save_path = './checkpoint/mnist.ckpt'
if os.path.exists(checkpoint_save_path +'.index'):
    print('----------------------------load the model--------------------------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath =checkpoint_save_path,
    save_weights_only= True,
    save_best_only= True)

#4. 喂入训练书记和测试数据
history = model.fit(image_gen_train(x_train, y_train, batch_size = 32), epochs =5, validation_data = (x_test, y_test), validation_freq = 1,
callbacks = [cp_callback])
#5. 打印参数和模型结构
model.summary()
#-----------------------------显示loss和acc的曲线-----------------------------------------------
loss = history.history['loss']
val_loss = history.history['val_loss']
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']

plt.subplot (1,2,1)
plt.plot(acc, label = 'training accuracy')
plt.plot(val_acc, label = 'validation accuracy')
plt.title('training and validation accuracy curve')
plt.legend()

plt.subplot (1,2,2)
plt.plot(loss, label = 'training loss')
plt.plot(val_loss, label = 'validation loss')
plt.title('training and validation loss curve')
plt.legend()
plt.show()
#------------------------------------------预测-----------------------
def img_trans(img_path):
    img = Image.open(img_path)
    img = img.resize((28,28), Image.ANTIALIAS)
    img_arr = np.array (img.convert('L'))

    for i in range(28):
        for j in range(28):
            if img_arr[i][j]<200:
                img_arr [i][j] =255
            else:
                img_arr[i][j] =0
    img_arr = 255-img_arr
    img_arr = img_arr/255.0
    return img_arr

num  = int(input('the number of images: '))
for i in range(num):
    img_path = input('the path of image:')
    img_arr = img_trans(img_path)
    x_predict = img_arr [tf.newaxis, ...]
    result = model.predict(x_predict)
    print(result)
    pred = tf.argmax(result, axis =1)
    print('success predict\n')
    print(pred)

拓展:

fashion数据库,就是导入数据不一样

mnist = tf.keras.datasets.fashion_mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train, x_test = x_train/255.0, x_test/255.0

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值