Study Notes: Implementing a CNN (Convolutional Neural Network) for the MNIST Dataset

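This post describes how to train a CNN (convolutional neural network) on Google Colab with Tensorflow 2.5 and Python 3.7, using the EfficientNetB1 model on the MNIST handwritten-digit dataset. The author first sets the training parameters, then imports and preprocesses the MNIST dataset, builds a model based on EfficientNetB1, trains it, and plots the training loss and accuracy curves. The model reaches 99.33% accuracy on the validation set, with a total training time of 4574.06 seconds.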


Study date: 2021.07.25
Topic: Implementing a CNN (convolutional neural network) for the MNIST dataset
Environment: Tensorflow2.5, Python3.7, Google Colab
Model: EfficientNetB1
Code link: https://colab.research.google.com/drive/1M4xzJA8mnxMS63mDrSPlkXJdziMzYhCq?usp=sharing

1. Imports

# Grant permission: access all files on Google Drive
from os.path import join
from google.colab import drive

ROOT = "/content/drive"
drive.mount(ROOT)
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as ticker
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
import os, time
import pandas as pd
import numpy as np
!gdown --id 1fsKERl26TNTFIY25PhReoCujxwJvfyHn
zhfont = mpl.font_manager.FontProperties(fname='SimHei .ttf')
zhfont2 = mpl.font_manager.FontProperties(fname='New Times Roman .ttf')
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

2. Loading the Dataset

# Parameter settings
batch_size = 128
class_num = 10
epochs = 100
learning_rate_original = 0.01

# Data type conversion: scale pixels to [0, 1] and one-hot encode the labels
def preprocess(x, y):
    
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=class_num)
    return x,y

# Load the MNIST data
(x, y), (x_test, y_test) = datasets.mnist.load_data()

x = np.expand_dims(x, axis=3)
x = np.pad(x, ((0,0), (2,2), (2,2),(0,2)), 'edge')
x_test = np.expand_dims(x_test, axis=3)
x_test = np.pad(x_test, ((0,0), (2,2), (2,2),(0,2)), 'edge')

train_num = int(x.shape[0])
val_num = int(x_test.shape[0])

print(x.shape, y.shape, train_num, val_num)

db = tf.data.Dataset.from_tensor_slices((x,y))
db = db.map(preprocess).repeat().shuffle(10000).batch(batch_size)

db_test = tf.data.Dataset.from_tensor_slices((x_test,y_test))
db_test = db_test.map(preprocess).repeat().batch(batch_size)

# Print the batch shapes
db_iter = iter(db)
sample = next(db_iter)
print('batch:', sample[0].shape, sample[1].shape)

db_iter = iter(db_test)
sample = next(db_iter)
print('batch:', sample[0].shape, sample[1].shape)
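
To make the padding step above concrete, here is a small NumPy-only sketch on a dummy batch. The array `fake` and its size are made up for illustration; only the pad widths and the 'edge' mode come from the code above. It checks that the padding grows 28x28 images to 32x32 and copies the single gray channel into three identical channels, which is the input shape the EfficientNetB1 backbone is later built with.

```python
import numpy as np

# Hypothetical stand-in for a few MNIST images: 4 grayscale 28x28 arrays.
rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Add a trailing channel axis: (4, 28, 28) -> (4, 28, 28, 1)
expanded = np.expand_dims(fake, axis=3)

# Pad 2 pixels on every spatial side and 2 extra entries after the channel axis.
# 'edge' mode repeats the nearest existing value instead of filling with zeros.
padded = np.pad(expanded, ((0, 0), (2, 2), (2, 2), (0, 2)), 'edge')

print(expanded.shape, '->', padded.shape)

# All three channels are identical copies of the padded gray image.
print(np.array_equal(padded[..., 0], padded[..., 1]),
      np.array_equal(padded[..., 1], padded[..., 2]))

# The central 28x28 region still holds the original pixels.
print(np.array_equal(padded[:, 2:30, 2:30, 0], fake))
```
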
# Learning rate (lr) schedule
import math
def lr_schedule(epoch):

    learning_rate = learning_rate_original*math.exp(0.01)

    if epoch > epochs*0.8:
        learning_rate = learning_rate_original*math.exp(-0.69)
    elif epoch > epochs*0.5:
        learning_rate = learning_rate_original*math.exp(-0.51)
    elif epoch > epochs*0.3:
        learning_rate = learning_rate_original*math.exp(-0.35)

    return learning_rate

lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule)

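# Preview the schedule; use names that do not shadow the dataset variables x and y
epoch_range = list(range(epochs))
lr_values = [lr_schedule(e) for e in epoch_range]
plt.plot(epoch_range, lr_values)
print("Learning rate schedule: {:.3g} to {:.3g} to {:.3g}".format(lr_values[0], max(lr_values), lr_values[-1]))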

Figure: learning rate schedule curve
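
The schedule is piecewise constant, which is easier to see from a few sampled values than from the curve alone. The sketch below restates the same rule as a standalone function (`step_lr` is a name chosen here for illustration, with the post's values `0.01` and `100` as defaults). With 100 epochs the thresholds fall at 30, 50 and 80, and because the comparisons are strict the rate changes at epochs 31, 51 and 81. The multipliers are roughly 1.010, 0.705, 0.600 and 0.502.

```python
import math

def step_lr(epoch, base_lr=0.01, total_epochs=100):
    """Piecewise-constant schedule mirroring lr_schedule above."""
    if epoch > total_epochs * 0.8:
        factor = math.exp(-0.69)   # last 20% of training
    elif epoch > total_epochs * 0.5:
        factor = math.exp(-0.51)
    elif epoch > total_epochs * 0.3:
        factor = math.exp(-0.35)
    else:
        factor = math.exp(0.01)    # initial phase, slightly above base_lr
    return base_lr * factor

# Sample the schedule around each threshold.
for epoch in (0, 30, 31, 50, 51, 80, 81, 99):
    print(f"epoch {epoch:3d}: lr = {step_lr(epoch):.5f}")
```
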

3. Building the Model

# model_01: EfficientNetB1
from tensorflow.keras.applications import EfficientNetB1

feature = EfficientNetB1(include_top=False, weights='imagenet', input_shape=(32,32,3))
model_01 = tf.keras.Sequential([feature,
                tf.keras.layers.GlobalAvgPool2D(),
                tf.keras.layers.Dropout(rate=0.1),
                tf.keras.layers.Dense(512),
                tf.keras.layers.Dropout(rate=0.1),
                tf.keras.layers.Dense(128),
                tf.keras.layers.Dense(class_num)])

model_01.summary()
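
As a rough cross-check of what `model_01.summary()` reports for the layers stacked on top of the backbone, the Dense-layer parameter counts can be worked out by hand. This sketch assumes the EfficientNetB1 backbone outputs 1280 feature channels, so that `GlobalAvgPool2D` yields a 1280-dimensional vector; that width is not stated in the post. Pooling and Dropout layers add no trainable parameters.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

# Assumption: the backbone ends with 1280 feature channels.
backbone_features = 1280
class_num = 10

# Widths along the head: pooled features -> Dense(512) -> Dense(128) -> Dense(class_num)
layer_sizes = [backbone_features, 512, 128, class_num]

head_params = sum(dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))
print("head parameters:", head_params)
print("head parameters / M: {:.2f}".format(head_params / 1e6))
```

Under that assumption the head accounts for well under one million parameters, so most of the model's size comes from the pretrained backbone.
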


4. Training the Model

start_1 =time.perf_counter()
model_01.compile(optimizer=tf.keras.optimizers.SGD(learning_rate = learning_rate_original,momentum=0.9),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])

callback = [tf.keras.callbacks.ModelCheckpoint(filepath='/content/save_weights/modle_01/my_model_{epoch}.h5',
                        save_best_only=True,
                        save_weights_only=True,
                        monitor='val_accuracy'),
                        lr_callback]
history = model_01.fit(x = db,
           steps_per_epoch = train_num // batch_size,
           epochs = epochs,
           validation_data = db_test,
           validation_steps = val_num // batch_size,
           callbacks = callback)
end_1 = time.perf_counter()

history_dict = history.history
train_loss = history_dict["loss"]
train_accuracy = history_dict["accuracy"]
val_loss = history_dict["val_loss"]
val_accuracy = history_dict["val_accuracy"]

print("Accuracy:", max(val_accuracy))
print('Training time:\t %s'%(end_1 - start_1))
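
Because both datasets are built with `.repeat()`, `fit` needs `steps_per_epoch` and `validation_steps` to know where an epoch ends. The arithmetic behind those two arguments is sketched below, assuming the standard MNIST split of 60000 training and 10000 test images (the post prints `train_num` and `val_num` but does not show their values).

```python
# Assumed standard MNIST split sizes; batch_size and epochs are the post's settings.
train_num, val_num = 60000, 10000
batch_size, epochs = 128, 100

steps_per_epoch = train_num // batch_size    # full training batches per epoch
validation_steps = val_num // batch_size     # full validation batches per pass

# Integer division leaves a remainder that does not fit into a full batch.
train_remainder = train_num - steps_per_epoch * batch_size
val_remainder = val_num - validation_steps * batch_size

print("steps_per_epoch:", steps_per_epoch, "remainder:", train_remainder)
print("validation_steps:", validation_steps, "remainder:", val_remainder)
print("total optimizer steps:", steps_per_epoch * epochs)
```

Each training epoch therefore draws slightly fewer samples than the full training set from the repeated, shuffled stream, and each validation pass evaluates slightly fewer than the full test set.
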

5. Training Results

for i in range(0,epochs):
    val_accuracy[i] = val_accuracy[i] * 100
    train_accuracy[i] = train_accuracy[i] * 100
# figure 1
plt.figure(dpi = 100)
plt.gca().yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))  # set y-axis tick precision
plt.plot(range(epochs), train_loss, label='Train', linestyle='--',linewidth = '0.8')
plt.plot(range(epochs), val_loss, color = 'red', label='Test',linewidth = '0.8')
plt.xlabel(u"训练迭代次数",fontproperties=zhfont,fontsize=14)  # x label: "Training epochs"
plt.ylabel(u"损失值",fontproperties=zhfont,fontsize=14)  # y label: "Loss"

plt.xticks(np.arange(0, epochs+1, 5),fontproperties = 'Times New Roman', size = 12)
plt.legend(fontsize=12)
plt.show()

# figure 2
plt.figure(dpi = 100)
plt.plot(range(epochs), train_accuracy, label='Train', linestyle='--',linewidth = '0.8')
plt.plot(range(epochs), val_accuracy,color = 'red', label='Test',linewidth = '0.8')
plt.xlabel(u"训练迭代次数",fontproperties=zhfont,fontsize=14)  # x label: "Training epochs"
plt.ylabel(u"准确率/%",fontproperties=zhfont,fontsize=14)  # y label: "Accuracy/%"

# Set the legend position
plt.legend(loc='lower right',fontsize=12) 

plt.xticks(np.arange(0, epochs+1, 5),fontproperties = 'Times New Roman', size = 12)
plt.yticks(np.arange(0, 101,10),fontproperties = 'Times New Roman', size = 12)

plt.show()

Figure: training and test loss curves
Figure: training and test accuracy curves

Training Summary

Dataset   Model            Parameters/M   GPU               Training time/s   Accuracy/%
MNIST     EfficientNetB1   7.30           NVIDIA Tesla T4   4574.06           99.33
