MLP

MNIST数据集介绍

MNIST是一个入门级CV数据集,包含手写数字图片。其中60000张用来训练,10000张用来测试。
一个MINST数据单元包含一张手写数字图片(2828)以及一个图片所对应的label;
将每张2828的图片展开成向量,长度为784。那么mnist.train.images就是一个[60000,784]的张量;
mnist.train.labels是一个[60000,10]的矩阵。
##模型
无论是机器学习还是深度学习,都绕不过模型.深度学习中的模型主要是各种神经网络.
但只有模型是不够的,前提条件其实是数据,然后,后置的操作是训练,再之后是测试.

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop
import matplotlib.pyplot as plt

batch_size = 128
epochs = 5
num_classes = 10

# 导入数据
(x_train, y_train), (x_test,y_test) = mnist.load_data()


# 归一化
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# 改变数据形状,格式为(n_samples,vector)
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)

# 控制台打印输出样本数量信息
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# one-hot 十分类标签转化为二进制
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# create Models
model = Sequential()
model.add(Dense(units=512, activation='relu', input_shape=(784, )))
model.add(Dropout(rate=0.2))
model.add(Dense(units=512, activation='relu'))
model.add(Dropout(rate=0.2))
model.add(Dense(num_classes, activation='softmax'))
model.summary()                                  # 在控制台输出模型参数信息


model.compile(loss='categorical_crossentropy', optimizer=RMSprop(), metrics=['accuracy'])


model.fit(x_train,y_train,
          batch_size=batch_size,
          epochs = epochs,
          verbose=1,
          validation_data=(x_test,y_test)
)

n = 10   # 给出需要预测的图片数量,为了方便,只取前5张图片
predicted_number = model.predict(x_test[:n], n)

plt.figure(figsize=(10, 5))
for i in range(n):
    plt.subplot(1, n, i + 1)
    t = x_test[i].reshape(28, 28)   # 向量需要reshape为矩阵
    plt.imshow(t, cmap='gray')      # 以灰度图显示
    plt.subplots_adjust(wspace=2)   # 调整子图间的间距,挨太紧了不好看
    # 第一个数字是真实标签,第二个数字是预测数值
    # 如果预测正确,绿色显示,否则红色显示
    # 预测结果是one-hot编码,需要转化为数字
    if y_test[i].argmax() == predicted_number[i].argmax():
        plt.title('%d,%d' % (y_test[i].argmax(), predicted_number[i].argmax()), color='green')
    else:
        plt.title('%d,%d' % (y_test[i].argmax(), predicted_number[i].argmax()), color='red')
    plt.xticks([])  # 取消x轴刻度
    plt.yticks([])
plt.show()

epochs设置为5显示前十个数字

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值