训练mnist

1.用卷积训练mnist

自己照着网上写了个用keras卷积模型训练mnist数据集的代码,熟悉一下keras的一些知识。

import tensorflow as tf

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        if (logs.get('accuracy') > 0.998):
            print("\n准确度达到99.8%,停止训练")
            self.model.stop_training = True


callbacks = myCallback()

mnist = tf.keras.datasets.mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()
training_images, test_images = training_images.reshape(60000, 28, 28, 1) / 255.0, test_images.reshape(10000, 28, 28,1) / 255.0  #reshape,将输出调整为特定形状

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu', input_shape=(28, 28, 1)),#当该层作为第一层时,需要提供input_shape参数(整数元组,不包含样本表示的轴)
    #64指的是卷积核的数量,(3,3)指的是卷积核的大小
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),#将输入展平。不影响批量大小。
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# model fitting
history = model.fit(
    training_images,
    training_labels,
    epochs=20,
    callbacks=[callbacks])

跑了好长时间,训练速度有待优化....keras的Conv层只提供自带的卷积核,自定义卷积核需要一些麻烦的操作,还没学会。Tensorflow中的tensorflow.nn模块是TensorFlow用于深度学习计算的核心模块。

跑了好长时间,训练速度有待优化…
如果只看keras的知识的话,基本就是看着忘着,或者留下个大致印象,只有不断打代码巩固才能掌握的更牢固。

2.普通网络训练mnist

import numpy as np
import tensorflow.keras as keras
import tensorflow as tf
import matplotlib.pyplot as plt
data=keras.datasets.mnist
(train_images,train_labels),(test_images,test_labels)=data.load_data()
class_name=['0','1','2','3','4','5','6','7','8','9','10']
train_images=train_images/255.0
test_images=test_images/255.0
model=keras.Sequential([keras.layers.Flatten(input_shape=(28,28)),
                        keras.layers.Dense(128,activation='relu'),
                        keras.layers.Dense(10,activation='softmax')])
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
model.fit(train_images,train_labels,epochs=5)
#这上边的代码大家应该都熟悉,下边用plt将预测的图片显示出来
prediction=model.predict(test_images)

for i in range(5):
    plt.grid(False)
    plt.imshow(test_images[i])
    plt.xlabel('Actual:'+class_name[test_labels[i]])
    plt.title("prediction:"+class_name[np.argmax(prediction[i])])
    plt.show()

在这里插入图片描述在这里插入图片描述预测的与结果一致。

3.零碎的函数

3.1显示数据集信息

print(train_images.shape)#会显示(60000,28,28)

3.2显示模型信息

看视频的时候,有个函数model.summary(),print的话,就会显示出来网络的信息。当然,model名字是自定的。
在这里插入图片描述

3.3 matplotlib方面的

来,给个链接

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值