一、环境
本系列文章主要基于windows7,Anaconda。
Anaconda是个很有用的工具,安装各种库文件都非常方便,除了网络卡顿导致安装失败,目前都没发现其他问题
二、写在前面
本文主要基于TensorFlow中文社区的一系列文章进行学习和记录,对mnist数据集和神经网络相关原理进行介绍。
http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
上图显示了,神经网络训练的基本步骤,接下来将按照图中的几步来讲解。
三、mnist数据集
MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字,其中数字的范围从0到9.
该数据集主要包含了以下四个部分:
文件 | 内容 |
train-images-idx3-ubyte.gz | 训练集图片 - 55000 张 训练图片, 5000 张 验证图片 |
train-labels-idx1-ubyte.gz | 训练集图片对应的数字标签 |
t10k-images-idx3-ubyte.gz | 测试集图片 - 10000 张 图片 |
t10k-labels-idx1-ubyte.gz | 测试集图片对应的数字标签 |
可以通过以下的代码获取mnist数据集:
from tensorflow.examples.tutorials.mnist import input_data #导入模块
mnist = input_data.read_data_sets("mnist_dd/", one_hot=True) #下载数据集,读取数据
获取数据集之后,我们就可以进行开始神经网络的构建了。以下
在构建网络之前,我们先来看一下这几个压缩文件中到底有什么文件呢,解压完之后我们得到的是一个.idx3-ubyte文件,用notepad++打开,开到的是一系列的16进制数据。如下图所示(t10k-images.idx3-ubyte):
前面的16byte(每4个byte)分别表示:
参数 | 十六进制 | 十进制 |
魔数 | 0x00000803 | 2051 |
图片数 | 0x00002710 | 10000 |
行像素点 | 0x0000001c | 28 |
列像素点 | 0x0000001c | 28 |
后面的数据是按照每张图片28*28个像素点照顺序排列的。
我们可以用下面的代码去读一下数据并显示
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 13 14:38:18 2018
@author: ZCH
"""
import numpy as np
import struct
import matplotlib.pyplot as plt
filename = 'MNIST_data/t10k-images.idx3-ubyte'
binfile = open(filename,'rb')#以二进制方式打开
buf = binfile.read()
index = 0
magic, numImages, numRows, numColums = struct.unpack_from('>IIII',buf,index)#读取4个32 int 大端存取
print (magic,' ',numImages,' ',numRows,' ',numColums )
index += struct.calcsize('>IIII')
im = struct.unpack_from('>784B',buf,index)#每张图是28*28=784Byte,这里只显示第一张图
index += struct.calcsize('>784B' )
im = np.array(im)
im = im.reshape(28,28)
print( im )
fig = plt.figure()
plt.imshow(im,cmap = 'binary')#黑白显示
plt.show()
运行结果如下:
白色的为0,黑色的为255
代码讲解:
struct.unpack_from('>IIII',buf,index)
调用struct模块对数据进行解析,'I'表示四个字节,四个则表示一次性读取16字节;‘B’表示一个字节,‘784B’表示读取784个字节;‘>’表示以大端格式存储数据,‘<’表示以小端格式存储数据;buf指文件内容;index指读取文件的起始位。
我们也可以用以下的方式将图片数据读取出来,并且保存为图片。
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 29 14:38:57 2018
@author: ZCH
"""
from PIL import Image
import struct
filename = 'MNIST_data/t10k-images.idx3-ubyte'
def readfile(file):
fd = open(filename,'rb')#以二进制方式打开
buf = fd.read()
fd.close()
index = 0
magic, numImages, numRows, numColums = struct.unpack_from('>IIII',buf,index)#读取4个32 int 大端存取
print (magic,' ',numImages,' ',numRows,' ',numColums )
index += struct.calcsize('>IIII')
for i in range(numImages):
image = Image.new('L',(numColums,numRows))
for x in range(numRows):
for y in range(numColums):
image.putpixel((y,x),int(struct.unpack_from('>B',buf,index)[0]))
index += struct.calcsize('>B')
image.save('test/'+str(i)+'.png')
readfile(filename)