安装Tensorflow教程
请移步到这篇博客 ☛ https://blog.csdn.net/weixin_38283428/article/details/84201733
下载的tensorflow版本要和自己的python版本一致。
一、Minist手写体数据集的特点
1.训练数据和测试数据各有多少
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
print('训练数据数量',mnist.train.num_examples)
print('验证数据数量',mnist.validation.num_examples)
print('测试数据数量',mnist.test.num_examples)
2.train集合数据及标签的形状
print('train集合数据矩阵形状:',mnist.train.images.shape)
print('train集合数据标签矩阵形状:',mnist.train.labels.shape)
print('train集合第一个数据标签长度、内容:',len(mnist.train.labels[0]),mnist.train.labels[0])
从上面的运行结果可以看出,在train集合数据中总共有55000个样本,每个样本有784个特征。原图片形状为28*28=784,每个图片样本展平后则有784维特征。在train集合数据标签中总共有55000个样本,每个样本有10维特征,根据所显示的内容能够推断出这幅图片为数字为7,因为它在第8个位置响应度为1,其他位置为0.
3.查看一部分图片内容
import os
import numpy as np
import matplotlib.pyplot as plt
import math
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
def drawdigit(position,image,title):
plt.subplot(*position)
plt.imshow(image,cmap='gray_r')
plt.axis('off')
plt.title(title)
def batchDraw(batch_size):
images,labels = mnist.train.next_batch(batch_size)
row_num = math.ceil(batch_size ** 0.5)
column_num = row_num
plt.figure(figsize=(row_num,column_num))
for i in range(row_num):
for j in range(column_num):
index = i * column_num + j
if index < batch_size:
position = (row_num,column_num,index+1)
image = images[index].reshape(-1,28)
title = 'actual:%d'%(np.argmax(labels[index]))
drawdigit(position,image,title)
if __name__ == '__main__':
batchDraw(196)
plt.show()
观察train集合中的数据发现
0类:有普通正常的0,未闭合成圈的0,容易被误认为6和8的0,圈极扁的0,朝着左右方向倾斜的0
1类:打印体的1(容易被误认为7),朝着左右方向倾斜的1,弯曲的1,粗细不同的1
2类:第一笔弯曲弧度不够的2,最后结束那一笔写得过长的2(容易被误认为是3)或者过于短的2,粗细不同的2
3类:朝着左右方向倾斜的3,最后一笔向上勾的弧度不够的3,写得像烟圈的3,整体弯曲弧度不够的3,粗细不同的3
4类:写得像Y、A、H的4,像飘扬的旗帜的4,写得太过紧凑的4,粗细不同4
5类:写得像S的5,因为最后一笔而容易被认为是6的5,粗细不同的5
6类:下半部分的圆圈被涂满不是空心的6,横躺的6,下半部分未闭合成圆圈的6,最后一笔写出头容易被认成4的6
7类:打印体形状的7,多加一条横杠的7,横与竖形成的角度过小的7,粗细不同的7
8类:最上面是开口的8,朝着左右方向倾斜的8,最下面是开口的8,上下两个圈被涂满的8,手写习惯为先画上面一个圈再画下面一个圈的8
9类:打印体形状的9,圆圈过大容易被误认为0的9,粗细不同的9,上方未闭合成圆圈的9,像g