MNIST手写数字识别准确度提升最全、最实用的方法

MNIST手写数字识别是所有学习AI同学的入门必经过程,MNIST识别准确率提升修炼是精通AI模型的必经课程,MNIST识别准确率开刚始大家一般都能达到90%左右,再往上提高还需要费较大的精力去修改模型、调优参数,MNIST识别率究竟能达到多少,对于初学者还是很难搞清楚,刚开始也没有经验去提升得很高,我在第一遍学习时,通过参数和训练次数调整,利用了很多模型,达到了99.3%的精度,再往上提升时当时那台电脑的计算能力不够,也没有找到新的模型,没有再做研究了,就学习其它内容去了。

上个月旧电脑出现故障,不能启动,硬盘使用了Bitlocker,数据也不能恢复,学习代码也丢了,这次更换了一台性能非常好的电脑后重新学习,就把MNIST的Lenet-5算法重新学习了一遍,结合Kaggle网站上的方案,做了个模型和方案,最终把MNIST识别准确度提升到了99.96%的水平,据了解由于MNIST数据本身有缺陷,除非将评估数据集加入训练准确率应该是达不到100%的。我想能达到3个9的准度也应该足够了,特将实现方案总结出来,分享新学习的同学们,希望对大家有所帮助。

一、Lenet-5通用模型方案

下面方案是Lenet-5最通用的方案,初始是用3X3卷积盒,模型使用Sequential构建,训练也是用自己写的循环代码一步一步构建,特别适合于新手引用。此外,该案例还有日志保存、模型可在代码。

"""
LetNet-5 实战1:网上使用最多的模型, 测试用例精确度能达到99%
"""
import datetime
import tensorflow as tf
from tensorflow.keras import Sequential, layers, losses, datasets


def preprocess(x, y):
    """
    预处理函数
    """
    # [b, 28, 28], [b]
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


# 设置LOG与模型保存相关参数
current_time = datetime.datetime.now().strftime(('%Y%m%d-%H%M%S'))
log_dir = 'logs/' + current_time
summary_writer = tf.summary.create_file_writer(log_dir)
(x, y), (x_test, y_test) = datasets.mnist.load_data()  # 加载手写数据集数据
batchsz = 128   # 此模型下batch size在128下比较好
train_db = tf.data.Dataset.from_tensor_slices((x, y))  # 转化为Dataset对象
train_db = train_db.shuffle(100000)  # 随机打散
train_db = train_db.batch(batchsz)  # 批训练
train_db = train_db.map(preprocess)  # 数据预处理

test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(1000).batch(batchsz).map(preprocess)
# 通过Sequnentia容器创建LeNet-5
network = Sequential([
    layers.Conv2D(6, kernel_size=3, strides=1),  # 第一个卷积核,6个3X3的卷积核,
    layers.MaxPooling2D(pool_size=2, strides=2),  # 高宽各减半的池化层
    layers.ReLU(),  # 激活函数
    layers.Conv2D(16, kernel_size=3, strides=1),  # 第二个卷积核,16个3X3的卷积核,
    layers.MaxPooling2D(pool_size=2, strides=2),  # 高宽各减半的池化层
    layers.ReLU(),  # 激活函数
    layers.Flatten(),  # 打平层,方便全连接层处理
    layers.Dense(120, activation='relu'),  # 全连接层,120个结点
    layers.Dense(84, activation='relu'),  # 全连接层,84个结点
    layers.Dense(10, activation='relu')  # 全连接层,10个结点
])
# build 一次网络模型,给输入X的形状
network.build(input_shape=(batchsz, 28, 28, 1))
# 统计网络信息
network.summary()

# 创建损失函数的类,在实际计算时直接调用类实例
criteon = losses.CategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.95)  # batch,128,lr=0.01,acc:0.9914
# optimizer = tf.keras.optimizers.Nadam(learning_rate=0.002)  # batch,128,0,89
# optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.01)

# 训练30个epoch
epoch = 30
steps = 0
for n in range(epoch):
    for step, (x, y) in enumerate(train_db):
        with tf.GradientTape() as tape:
            # 插入通道维度 =》[b,28,28,1]
            x = tf.expand_dims(x, axis=3)
            # 前向计算,获取10类别的概率分布 [b,784]=>[b,10]
            out = network(x)
            # 计算交叉熵损失函数,标量
            loss = criteon(y, out)

        # 自动计算梯度
  • 5
    点赞
  • 56
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值