读代码0 DBN_MNIST
(本文写于对这篇博客的阅读) 本文代码来源博客
简介
任务:图像分类
模型:DBN
数据集:MNIST
环境:Tensorflow
MNIST | 图片数 | 标签数 |
---|---|---|
训练集 | 60000 | 60000 |
测试集 | 10000 | 10000 |
模块分析:
两大模块:(1)数据集处理 (2)搭建DBN
说明:
(1)以下将用np代表numpy
(2)训练集和测试集的图片、标签处理方法一致,以下以训练集数据
1.数据集处理 input_data.py
1.1 def maybe_download()
数据集在网站上以压缩包形式存在,该函数功能是从网站下载数据
def maybe_download(filename, work_directory):
"""Download the data from Yann's website, unless it's already here."""
if not os.path.exists(work_directory): # 如果存压缩包文件路径不存在,产生新路径
os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename) # 文件路径拼接用os.path.join,用python中string的+拼接是不可以的
if not os.path.exists(filepath): # 如果文件不存在,下载
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath)
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
return filepath # 返回文件的路径(MNIST_data/train-images-idx3-ubyte.gz/)
1.2 def _read32()
该函数功能是读取4个字节,以大端的方式转化成无符号整型,每次返回一个单值数组
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)
关于numpy.frombuffer():函数介绍
1.3 def extract_images()
该函数功能是解压并返回numpy数组
以训练集图片为例(60000张,28 * 28),最后返回(60000, 28, 28, 1)的numpy数组。
def extract_images(filename):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
''' _read32(bytestream)每次返回一个单值的数组,多次返回的数据如下:
[2051] [60000] [28] [28] [0] [0] [0] [0] [0] [0] [0]
[0] [0] [0] [0] [0] [0] [0] [0] [0] [0]
[0] [0] [0] [0] [0] [0] [0] [0] [0] [0]
[0] [0] [0] [0] [0] [0] [0] [0] [0] [0]
[0] [51515922] [2122886938] [2801792895] ...
,2051不知道是什么,60000是图片个数,28和28是图片的长和宽,接下来是各个像素值
'''
magic = _read32(bytestream) # 2051
if magic != 2051:
raise ValueError(
'Invalid magic number %d in MNIST image file: %s' %
(magic, filename))
num_images = _read32(bytestream) # 60000
rows = _read32(bytestream) # 28
cols = _read32(bytestream) # 28
buf = bytestream.read(rows[0] * cols[0] * num_images[0]) # buf是28x28x60000=47040000大小的一个bytes字节流
data = numpy.frombuffer(buf, dtype=numpy.uint8) # data是与buf一样大的numpy.ndarray类型的(47040000,)数组,一维,存放的是6万张图片的像素值,值范围0~255
data = data.reshape(num_images[0], rows[0], cols[0], 1) # 转成(60000, 28, 28, 1)
return data
1.4 def extract_labels()
以训练集标签为例,该函数返回(60000, 10)的numpy数组
def dense_to_one_hot(labels_dense, num_classes=10):
"""Convert class labels from scalars to one-hot vectors."""
num_labels = labels_dense.shape[0] # 60000
index_offset = numpy.arange(num_labels) * num_classes # arange返回从0到59999num_labels-1的numpy.ndarray, *10是每个元素都*10, 此句得到数组[ 0 10 20 ... 599970 599980 599990]
labels_one_hot = numpy.zeros((num_labels, num_classes)) # 大小为(60000, 10)的全0数组
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 # ravel()将多维展成一维的, flat返回的是一个迭代器,可以用for访问数组每一个元素
return labels_one_hot # 这里labels_one_hot得到了(60000, 10)的numpy类型的数组
def extract_labels(filename, one_hot=False):
"""Extract the labels into a 1D uint8 numpy array [index]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2049:
raise ValueError(
'Invalid magic number %d in MNIST label file: %s' %
(magic, filename))
num_items = _read32(bytestream)
buf = bytestream.read(num_items[0]) # 60000
labels = numpy.frombuffer(buf, dtype=numpy.uint8) # 得到numpy.ndarray类型的一维数组(60000,),是0~9的标签值
if one_hot:
return dense_to_one_hot(labels)
return labels # 这里labels得到了(60000, 10)的numpy类型的one_hot数组
def dense_to_one_hot() 计算方法:
1.5 变量说明
划分验证集
validation_images = train_images[:VALIDATION_SIZE] # (5000, 28, 28, 1)numpy数组
validation_labels = train_labels[:VALIDATION_SIZE]
train_images = train_images[VALIDATION_SIZE:] # (55000, 28, 28, 1)numpy数组
train_labels = train_labels[VALIDATION_SIZE:]
为后续容易理解,给出各变量含义:
trX:训练集图片,(55000, 28, 28, 1)
trY:训练集标签,(55000, 10)
teX:测试集图片,(10000, 28, 28, 1)
teY:测试集标签,(10000, 10)
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
2. 搭建DBN
2.1 参数设置–opts.py
参数设置为 opts = DLOption(10, 1., 100, 0.0, 0., 0.)
class DLOption(object):
def __init__(self, epoches, learning_rate, batchsize, momentum, penaltyL2, dropoutProb):
self._epoches = epoches
self._learning_rate = learning_rate
self._batchsize = batchsize
self._momentum = momentum
self._penaltyL2 = penaltyL2
self._dropoutProb = dropoutProb
2.2 tile_raster_images() 数据集可视化
效果:
3. 搭建rmb rmb_tf.py
功能:
(1)定义rbm类
(2)前向传播
4. 搭建dbn dbn_tf.py
功能:设置DBN结构,当X.shape = 55000x784,以 sizes = [400, 100]为例,搭建了784->400, 400->100两个RBM
from rbm_tf import RBM
class DBN(object):
def __init__(self, sizes, opts, X):
self._sizes = sizes
self._opts = opts
self._X = X
self.rbm_list = []
input_size = X.shape[1] # X shape = 55000x784
for i, size in enumerate(self._sizes):
self.rbm_list.append(RBM("rbm%d" % i, input_size, size, self._opts))
input_size = size
def train(self):
X = self._X
for rbm in self.rbm_list:
rbm.train(X)
X = rbm.rbmup(X)