# -*- coding: utf-8 -*-
"""
Created on Fri Dec 15 10:07:41 2017
@author: acer
"""
from __future__ import print_function

import gzip
import os
import urllib  # urllib: Python's built-in HTTP request library

import numpy
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


def maybe_download(filename, work_directory):
    """Return the local path of ``filename``, downloading it first if absent.

    The file is looked up at ``os.path.join(work_directory, filename)``; if it
    does not exist yet it is fetched from ``SOURCE_URL + filename``.

    Args:
        filename: Name of the file on the server, e.g.
            'train-images-idx3-ubyte.gz'.
        work_directory: Local directory the file is cached in. It is created
            (including any missing parent directories) if it does not exist.

    Returns:
        The local path of the file.
    """
    if not os.path.exists(work_directory):
        # makedirs (rather than mkdir) so nested cache paths such as
        # 'data/mnist' work even when the parent directory is missing.
        os.makedirs(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        # urlretrieve moved to urllib.request in Python 3; plain
        # ``urllib.urlretrieve`` only exists on Python 2.
        try:
            from urllib.request import urlretrieve
        except ImportError:
            from urllib import urlretrieve
        # urlretrieve saves the remote resource to ``filepath`` on disk.
        filepath, _ = urlretrieve(SOURCE_URL + filename, filepath)
        statinfo = os.stat(filepath)
        print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
    return filepath


# Backward-compatible alias for the original (misspelled) function name.
maybe_downloadd = maybe_download
def _read32(bytestream):
    """Read one big-endian unsigned 32-bit integer from ``bytestream``.

    Used for the magic number and dimension sizes at the start of the
    MNIST files.

    Args:
        bytestream: A readable binary stream (anything with ``read(n)``).

    Returns:
        The decoded value as a Python ``int``. Returning a plain int means
        later products such as ``rows * cols * num_images`` use unbounded
        integer arithmetic instead of wrapping numpy uint32 arithmetic.

    Raises:
        ValueError: If fewer than 4 bytes remain in the stream.
    """
    raw = bytestream.read(4)
    if len(raw) != 4:
        raise ValueError('Truncated stream: expected 4 bytes, got %d' % len(raw))
    # uint32 with '>' (big-endian) byte order.
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    # frombuffer views the raw bytes as an ndarray; take its single element.
    return int(numpy.frombuffer(raw, dtype=dt)[0])
def extract_images(f):
    """Extract MNIST images into a 4-D uint8 array ``[index, y, x, depth]``.

    Args:
        f: Either a path to a gzip-compressed image file or an open binary
            file object containing it. Any object with a ``read`` method is
            treated as a file object.

    Returns:
        A ``numpy.uint8`` array of shape ``(num_images, rows, cols, 1)``.

    Raises:
        ValueError: If the magic number is not 2051 or the pixel data is
            shorter than the header announces.
    """
    if hasattr(f, 'read'):
        name = getattr(f, 'name', repr(f))
        # GzipFile(fileobj=...) decompresses an already-open stream; closing
        # the GzipFile does not close the caller's file object.
        stream = gzip.GzipFile(fileobj=f)
    else:
        name = f
        stream = gzip.open(f, 'rb')
    print('Extracting', name)
    # 'with' guarantees the decompressor is closed even if parsing fails.
    with stream as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MMIST image file: %s' % (magic, name))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        expected = rows * cols * num_images
        buf = bytestream.read(expected)
        if len(buf) != expected:
            raise ValueError('Truncated MNIST image file %s: expected %d bytes, got %d'
                             % (name, expected, len(buf)))
        data = numpy.frombuffer(buf, dtype=numpy.uint8)
        data = data.reshape(num_images, rows, cols, 1)
    return data
def dense_to_one_hot(labels_dense, numclasses=10):
    """Convert a 1-D vector of class labels into a one-hot matrix.

    Row ``i`` of the result holds 1 in column ``labels_dense[i]`` and 0
    elsewhere. For example, a label of 3 sets column 3 of that row.

    Args:
        labels_dense: 1-D array-like of integer class labels.
        numclasses: Number of classes, i.e. the number of columns. Defaults
            to 10 (the MNIST digits 0-9).

    Returns:
        A float array of shape ``(len(labels_dense), numclasses)``.

    Raises:
        ValueError: If any label lies outside ``[0, numclasses)``. Such a
            label would otherwise silently mark a cell in a different row.
    """
    labels_dense = numpy.asarray(labels_dense)
    num_labels = labels_dense.shape[0]
    # ravel() flattens to 1-D; use a signed dtype so the range check and the
    # index arithmetic below are not affected by unsigned (uint8) wrap-around.
    flat_labels = labels_dense.ravel().astype(numpy.int64)
    if flat_labels.size and (flat_labels.min() < 0 or flat_labels.max() >= numclasses):
        raise ValueError('labels must lie in [0, %d)' % numclasses)
    # Offset of the first element of each row in the row-major flattened matrix.
    index_offset = numpy.arange(num_labels) * numclasses
    labels_one_hot = numpy.zeros((num_labels, numclasses))
    # .flat views the matrix as 1-D, so offset + label addresses cell (i, label).
    labels_one_hot.flat[index_offset + flat_labels] = 1
    return labels_one_hot
def extract_labels(filename, one_hot=False, num_classes=10):
    """Extract MNIST labels into a 1-D uint8 array (or a one-hot matrix).

    Args:
        filename: Path to the gzip-compressed label file.
        one_hot: If true, return a ``(num_items, num_classes)`` one-hot matrix
            built by ``dense_to_one_hot`` instead of the raw label vector.
        num_classes: Width of the one-hot matrix. Only used when ``one_hot``
            is true. Defaults to 10.

    Returns:
        A ``numpy.uint8`` array of shape ``(num_items,)``, or the one-hot matrix.

    Raises:
        ValueError: If the magic number is not 2049 or the label data is
            shorter than the header announces.
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError('Invalid magic number %d in MINST label file:%s' % (magic, filename))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        if len(buf) != num_items:
            raise ValueError('Truncated MNIST label file %s: expected %d bytes, got %d'
                             % (filename, num_items, len(buf)))
        # One label per byte, stored as a 1-D uint8 array.
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
        return dense_to_one_hot(labels, num_classes)
    return labels
class DataSet(object):
    """In-memory data split that serves mini-batches and reshuffles each epoch.

    Real images are flattened to ``[num_examples, rows * cols]`` float32 values
    scaled into [0.0, 1.0]. Labels are stored unchanged.
    """

    def __init__(self, images, labels, fake_data=False, one_hot=False):
        """Build the data set.

        Args:
            images: Array of shape ``[num_examples, rows, columns, 1]``.
                Ignored when ``fake_data`` is true.
            labels: Array whose first dimension matches ``images``.
            fake_data: If true, skip validation and conversion and report
                10000 examples. Pair with ``next_batch(..., fake_data=True)``,
                which returns constant dummy batches.
            one_hot: Only consulted for fake batches. It selects whether the
                fake label is a 10-element one-hot list or the scalar 0.
        """
        if fake_data:
            self._num_examples = 10000
        else:
            # The number of images must equal the number of labels.
            assert images.shape[0] == labels.shape[0], (
                "images.shape:%s labels.shape:%s" % (images.shape, labels.shape))
            self._num_examples = images.shape[0]
            # Convert shape from [num examples, rows, columns, depth] to
            # [num examples, rows*columns]; only depth == 1 is supported.
            assert images.shape[3] == 1
            images = images.reshape(images.shape[0], images.shape[1] * images.shape[2])
            # Convert from [0, 255] -> [0.0, 1.0].
            images = images.astype(numpy.float32)
            images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self.one_hot = one_hot
        # Number of full passes over the data completed by next_batch().
        self._epochs_completed = 0
        # Position of the next unread example within the current epoch.
        self._index_in_epoch = 0

    def images(self):
        """Return the stored images."""
        return self._images

    def labels(self):
        """Return the stored labels."""
        return self._labels

    def num_examples(self):
        """Return the number of examples in the set."""
        return self._num_examples

    def epochs_completed(self):
        """Return how many times next_batch() has started a new epoch."""
        return self._epochs_completed

    def next_batch(self, batches_size, fake_data=False):
        """Return the next ``batches_size`` examples as ``(images, labels)``.

        If the current epoch has too few examples left, the epoch counter is
        incremented and the data is reshuffled. The batch is then taken from
        the start of the reshuffled data, and the unread tail of the previous
        epoch is skipped.

        Args:
            batches_size: Number of examples per batch. The parameter keeps its
                original name so keyword callers keep working.
            fake_data: If true, return ``batches_size`` copies of a constant
                784-pixel image and a constant label instead of real data.

        Raises:
            AssertionError: If a new epoch starts and ``batches_size`` exceeds
                the number of examples.
        """
        batch_size = batches_size
        if fake_data:
            fake_image = [1] * 784
            if self.one_hot:
                fake_label = [1] + [0] * 9
            else:
                fake_label = 0
            return ([fake_image for _ in range(batch_size)],
                    [fake_label for _ in range(batch_size)])
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch.
            self._epochs_completed += 1
            # Shuffle images and labels with the same permutation.
            perm = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            # Start next epoch.
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]
def read_data_sets(train_dir, fake_data=False, one_hot=False, validation_size=5000):
    """Download (if needed) and load MNIST into train/validation/test splits.

    The first ``validation_size`` training examples form the validation
    split. The remaining training examples form the train split.

    Args:
        train_dir: Directory the .gz files are cached in.
        fake_data: If true, return ``DataSet`` objects built with
            ``fake_data=True`` and skip all downloading and parsing.
        one_hot: Whether labels are returned as one-hot matrices.
        validation_size: Number of training examples held out for
            validation. Defaults to 5000.

    Returns:
        An object with ``train``, ``validation`` and ``test`` attributes,
        each a ``DataSet``.

    Raises:
        ValueError: If ``validation_size`` is negative or larger than the
            number of training examples.
    """
    class DataSets(object):
        """Plain container for the train/validation/test splits."""

    data_sets = DataSets()
    if fake_data:
        data_sets.train = DataSet([], [], fake_data=True, one_hot=one_hot)
        data_sets.validation = DataSet([], [], fake_data=True, one_hot=one_hot)
        data_sets.test = DataSet([], [], fake_data=True, one_hot=one_hot)
        return data_sets

    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

    # Image files are passed as open binary file objects, and label files
    # as paths.
    local_file = maybe_download(TRAIN_IMAGES, train_dir)
    with open(local_file, 'rb') as f:
        train_images = extract_images(f)
    local_file = maybe_download(TRAIN_LABELS, train_dir)
    train_labels = extract_labels(local_file, one_hot=one_hot)
    local_file = maybe_download(TEST_IMAGES, train_dir)
    with open(local_file, 'rb') as f:
        test_images = extract_images(f)
    local_file = maybe_download(TEST_LABELS, train_dir)
    test_labels = extract_labels(local_file, one_hot=one_hot)

    if not 0 <= validation_size <= len(train_images):
        raise ValueError('validation_size should be between 0 and %d, got %d'
                         % (len(train_images), validation_size))
    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]

    data_sets.train = DataSet(train_images, train_labels)
    data_sets.validation = DataSet(validation_images, validation_labels)
    data_sets.test = DataSet(test_images, test_labels)
    return data_sets