源码
简介
Mnist是深度学习界的"Hello World! ",是深度学习中很好的入门范例,Mnist数据集主要是由手写数字的图片和标签组成,图片一共有10类,由0-9组成。本文会以Mnist数据库为基础,通过卷积神经网络(cnn)建立一个手写字体识别模型。在过程中能帮助大家更好的学习理解cnn的原理。
使用cnn训练模型一共可分为3个步骤:读取数据,初始化cnn网络结构,开始训练。
当然也可以直接使用tensorflow里面的api,这里是为了让大家更加清楚tensorflow中关于cnn API里面具体在做什么以及是如何完成的,这样对后期的学习会有很大的帮助。
读取数据
下面是处理Mnist数据的方法,首先read_images read_labels是分别用来读取图片内容和标签,从Mnist_data类的属性中我们可以看到,我们将Mnist数据分为训练集和测试集。因为后面我们采用的是小批量梯度下降的方法来优化梯度下降,所以依次从mnist数据集中向前读取一个batch容量的数据用来训练和测试模型,每个batch中训练图片为600张,测试图片为100张,其他属性包括当前batch位置,图片的长宽等。数据读取实质是将图片信息存放在一个矩阵中,以方便后面的计算,这里的forward和backward就是神经网络里面的向前传播和反向传播,这里大家可以先不要纠结,结合后面的网络结构的定义和讲解,大家就会明白这两个函数,forward函数是为了将数据转换为下一层网络可处理的格式,也就是下一层网络输入的矩阵的结构,而输入层是不用传递误差梯度的(不用定义权重,也就无需利用梯度更新权重)。
import os
import cv2
import struct
import numpy as np
def read_images(bin_file_name):
binfile = open(bin_file_name, 'rb')
buffers = binfile.read()
head = struct.unpack_from('>IIII', buffers, 0)
offset = struct.calcsize('>IIII')
img_num = head[1]
img_width = head[2]
img_height = head[3]
bits_size = img_num * img_height * img_width
raw_imgs = struct.unpack_from('>' + str(bits_size) + 'B', buffers, offset)
binfile.close()
imgs = np.reshape(raw_imgs, head[1:])
return imgs
def read_labels(bind_file_name):
binfile = open(bind_file_name, 'rb')
buffers = binfile.read()
head = struct.unpack_from('>II', buffers, 0)
img_num = head[1]
offset = struct.calcsize('>II')
raw_labels = struct.unpack_from('>' + str(img_num) + 'B', buffers, offset)
binfile.close()
labels = np.reshape(raw_labels, [img_num, 1])
return labels
class Mnist_data:
TRAIN = 'TRAIN'
TEST = 'TEST'
def __init__(self, data_dir):
self.mode = self.TRAIN
self.train_epoch = 0
self.test_eopch = 0
self.train_num = 600#训练时,一个batch的容量大小
self.test_num = 100
self.num = self.train_num # batch_size
self.num_output = 1 # channels mnist数据集为黑白照片,所以为1
self.train_images = read_images(os.path.join(data_dir, 'train-images-idx3-ubyte')) / 256.
self.train_labels = read_labels(os.path.join(data_dir, 'train-labels-idx1-ubyte'))
self.test_images = read_images(os.path.join(data_dir, 't10k-images-idx3-ubyte')) / 256.
self.test_labels = read_labels(os.path.join(data_dir, 't10k-labels-idx1-ubyte'))
self.train_img_num, self.output_h, self.output_w = self.train_images.shape
self.test_img_num, _, _ = self.test_images.shape
self.train_cur_index = 0
self.test_cur_index = 0
def next_batch_train_data(self):
if self.train_cur_index + self.num >= self.train_img_num:
t1 = np.arange(self.train_cur_index, self.train_img_num)
t2 = np.arange(0, self.train_cur_index + self.num - self.train_img_num)
self.output_train_index = np.append(t1, t2)
self.train_epoch = self.train_epoch + 1
self.train_cur_index = self.train_cur_index + self.num - self.train_img_num
else:
self.output_train_index = np.arange(self.train_cur_index, self.train_cur_index + self.num)
self.train_cur_index = self.train_cur_index + self.num
def next_batch_test_data(self):
if self.test_cur_index + self.num >= self.test_img_num:
t1 = np.arange(self.test_cur_index, self.test_img_num)
t2 = np.arange(0, self.test_cur_index + self.num - self.test_img_num)
self.output_test_index = np.append(t1, t2)
self.test_epoch = self.test_eopch = + 1
self.test_cur_index = self.test_cur_index + self.num - self.test_img_num
else:
self.output_test_index = np.arange(self.test_cur_index, self.test_cur_index + self.num)
self.test_cur_index = self.test_cur_index + self.num
def forward(self):
if self.mode == self.TRAIN:
self.output_images = self.train_images[self.output_train_index].reshape(self.num, 1, self.output_h, self.output_w)
self.output_labels = self.train_labels[self.output_train_index].reshape(-1)
elif self.mode == self.TEST:
self.output_images = self.test_images[self.output_test_index].reshape(self.num, 1, self.output_h, self.output_w)
self.output_labels = self.test_labels[self.output_test_index].reshape(-1)
else:
return None
return self.output_images
def backward(self, diff):
pass
def get_data(self):
return self.output_images
def get_label(self):
return self.output_labels
def get_mode(self):
return self.mode
def set_mode(self, mode):
self.mode = mode
if self.mode == self.TRAIN:
self.num = self.train_num
elif self.mode == self.TEST:
self.num = self.test_num
输入层的工作准备好以后,接下来的内容我们学习如何 初始化cnn网络结构以及定义cnn中的卷积层和池化层。