Explanation of the input_data code for the MNIST handwritten-digit dataset in TensorFlow


# -*- coding: utf-8 -*-
"""
Created on Fri Dec 15 10:07:41 2017

@author: acer
"""

from __future__ import print_function
import gzip
import os
# urllib is Python's built-in HTTP request library; urlretrieve moved to
# urllib.request in Python 3, so import it in a 2/3-compatible way.
try:
	from urllib.request import urlretrieve   # Python 3
except ImportError:
	from urllib import urlretrieve           # Python 2
import numpy

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
	# First check whether the working directory exists; if not, create it.
	if not os.path.exists(work_directory):
		os.mkdir(work_directory)
	# Path of the file we need to read.
	# os.path.join is the usual way to build paths; here it places the file name
	# under the given directory.
	filepath = os.path.join(work_directory, filename)
	if not os.path.exists(filepath):
		# urlretrieve downloads remote data to the local machine; the second
		# argument is the local path the file is saved to.
		filepath, _ = urlretrieve(SOURCE_URL + filename, filepath)
		# os.stat reads the file's attributes (size, timestamps, ...).
		statinfo = os.stat(filepath)
		print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
	return filepath
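# Usage sketch (my own illustration, not part of the original script; the
# directory name 'MNIST_data' is only an example): the first call downloads the
# file from SOURCE_URL, later calls just return the cached local path because
# the os.path.exists check skips the download.
#   path = maybe_download('train-images-idx3-ubyte.gz', 'MNIST_data')
#   print(path)   # -> 'MNIST_data/train-images-idx3-ubyte.gz'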


def _read32(bytestream):
	# Interpret the bytes as uint32 with big-endian byte order ('>').
	dt = numpy.dtype(numpy.uint32).newbyteorder('>')
	# numpy.frombuffer converts a buffer into an ndarray; read(4) reads four bytes at a time.
	return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
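# Quick illustration of _read32 (my own example): the IDX headers store 32-bit
# integers big-endian, so the four bytes 00 00 08 03 must decode to 2051 (the
# image-file magic number), not to the much larger little-endian value.
#   import io, struct
#   _read32(io.BytesIO(struct.pack('>I', 2051)))   # -> 2051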

def extract_images(filename):
	print('Extracting', filename)
	# The with statement gives a concise way to acquire a resource and guarantees
	# it is cleaned up even when an exception is raised.
	# gzip.GzipFile/gzip.open can read either from a file name or from an open
	# file object (fileobj); here the file name returned by maybe_download is
	# used, matching extract_labels and read_data_sets below.
	# Extract the images from the compressed IDX file.
	with gzip.open(filename) as bytestream:
		magic = _read32(bytestream)
		if magic != 2051:
			raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, filename))
		num_images = _read32(bytestream)
		rows = _read32(bytestream)
		cols = _read32(bytestream)
		buf = bytestream.read(rows * cols * num_images)
		data = numpy.frombuffer(buf, dtype=numpy.uint8)
		data = data.reshape(num_images, rows, cols, 1)
		return data
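# My own note, assuming the unmodified MNIST download: for the standard files
# this returns a uint8 array of shape (num_images, 28, 28, 1) with pixel values
# in [0, 255]; the trailing depth-1 axis is what DataSet.__init__ later checks
# with images.shape[3] == 1.
#   train = extract_images('MNIST_data/train-images-idx3-ubyte.gz')  # hypothetical path
#   train.shape   # -> (60000, 28, 28, 1)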


# Turn a dense label vector into a sparse (one-hot) label matrix.
# e.g. if row i of the original vector is 3, then row i of the one-hot matrix has
# a 1 at column index 3 and 0 everywhere else.
def dense_to_one_hot(labels_dense, num_classes=10):
	# num_classes defaults to 10, the number of MNIST digit classes, so that
	# extract_labels below can call this with a single argument.
	num_labels = labels_dense.shape[0]
	# numpy.arange returns an ndarray [0, 1, ..., num_labels-1].
	index_offset = numpy.arange(num_labels) * num_classes
	labels_one_hot = numpy.zeros((num_labels, num_classes))
	# labels_dense.ravel() flattens the array into one dimension.
	# labels_one_hot.flat[i] treats labels_one_hot as a flat 1-D array and takes its i-th element.
	labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
	return labels_one_hot
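# Worked example (my own illustration) of how the flat-index trick builds the
# one-hot matrix: for labels [3, 0, 1] and num_classes=10, index_offset is
# [0, 10, 20], so the flat positions 0*10+3, 1*10+0 and 2*10+1 are set to 1.
#   dense_to_one_hot(numpy.array([3, 0, 1]), 10)
#   # row 0: 1 at column 3, row 1: 1 at column 0, row 2: 1 at column 1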


def extract_labels(filename, one_hot=False):
	print('Extracting', filename)
	with gzip.open(filename) as bytestream:
		magic = _read32(bytestream)
		if magic != 2049:
			raise ValueError('Invalid magic number %d in MNIST label file: %s' % (magic, filename))
		num_items = _read32(bytestream)
		buf = bytestream.read(num_items)
		# Put the labels into a 1-D array of type uint8.
		labels = numpy.frombuffer(buf, dtype=numpy.uint8)
		if one_hot:
			return dense_to_one_hot(labels)
		return labels
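# Layout of the label file (my own note): after the magic number 2049 comes the
# item count, then num_items single bytes, one label (0-9) per example. With
# one_hot=True the 1-D label vector is expanded by dense_to_one_hot above.
#   labels = extract_labels('MNIST_data/train-labels-idx1-ubyte.gz', one_hot=True)  # hypothetical path
#   labels.shape   # -> (60000, 10)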



class DataSet(object):
	# The DataSet class; __init__ is its constructor.
	# fake_data=True means "do not use real images": the object pretends to hold
	# 10000 examples and next_batch returns dummy data, which is useful for
	# testing the pipeline without downloading anything.
	def __init__(self, images, labels, fake_data=False, one_hot=False):
		if fake_data:
			self._num_examples = 10000
			self.one_hot = one_hot
		else:
			# assert checks that a condition holds; if it is false an
			# AssertionError is raised, otherwise execution continues.
			# Here it checks that the number of images equals the number of labels.
			assert images.shape[0] == labels.shape[0], (
				"images.shape: %s labels.shape: %s" % (images.shape, labels.shape))
			# Number of examples.
			self._num_examples = images.shape[0]
			# Convert shape from [num examples, rows, columns, depth] to [num examples, rows*columns] (assuming depth == 1)
			assert images.shape[3] == 1
			images = images.reshape(images.shape[0], images.shape[1] * images.shape[2])
			# Convert from [0, 255] -> [0.0, 1.0]
			images = images.astype(numpy.float32)
			# Normalize by element-wise multiplication with 1/255.
			images = numpy.multiply(images, 1.0 / 255.0)
		self._images = images
		self._labels = labels
		# epochs_completed counts how many full passes over the data have been
		# made; index_in_epoch is the cursor position within the current pass.
		self._epochs_completed = 0
		self._index_in_epoch = 0
	def images(self):
		return self._images
	def labels(self):
		return self._labels
	def num_examples(self):
		return self._num_examples
	def epochs_completed(self):
		return self._epochs_completed
	def next_batch(self, batch_size, fake_data=False):
		if fake_data:
			fake_image = [1] * 784
			if self.one_hot:
				fake_label = [1] + [0] * 9
			else:
				fake_label = 0
			return [fake_image for _ in range(batch_size)], [fake_label for _ in range(batch_size)]
		start = self._index_in_epoch
		self._index_in_epoch += batch_size
		# If the read index has moved past the total number of images, reshuffle
		# and read batch_size examples from the start of the data instead.
		if self._index_in_epoch > self._num_examples:
			# Finished epoch
			self._epochs_completed += 1
			# Shuffle the data so the next epoch sees it in a random order.
			perm = numpy.arange(self._num_examples)
			numpy.random.shuffle(perm)
			self._images = self._images[perm]
			self._labels = self._labels[perm]
			# Start next epoch
			start = 0
			self._index_in_epoch = batch_size
			assert batch_size <= self._num_examples
		end = self._index_in_epoch
		return self._images[start:end], self._labels[start:end]
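	# Small illustration of the epoch logic (my own example, assuming a DataSet
	# with 100 examples and batch_size=32): the first three calls return slices
	# [0:32], [32:64], [64:96]; the fourth call would need index 128 > 100, so
	# epochs_completed is incremented, the data is reshuffled, and the slice
	# [0:32] of the shuffled data is returned instead. With the DataSet objects
	# built by read_data_sets below (one_hot=False):
	#   xs, ys = data_sets.train.next_batch(32)
	#   xs.shape, ys.shape   # -> ((32, 784), (32,))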
	
def read_data_sets(train_dir, fake_data=False, one_hot=False):
	class DataSets(object):
		pass
	data_sets = DataSets()
	if fake_data:
		data_sets.train = DataSet([], [], fake_data=True)
		data_sets.validation = DataSet([], [], fake_data=True)
		data_sets.test = DataSet([], [], fake_data=True)
		return data_sets
	TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
	TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
	TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
	TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
	VALIDATION_SIZE = 5000
	local_file = maybe_download(TRAIN_IMAGES, train_dir)
	train_images = extract_images(local_file)
	local_file = maybe_download(TRAIN_LABELS, train_dir)
	train_labels = extract_labels(local_file, one_hot=one_hot)
	local_file = maybe_download(TEST_IMAGES, train_dir)
	test_images = extract_images(local_file)
	local_file = maybe_download(TEST_LABELS, train_dir)
	test_labels = extract_labels(local_file, one_hot=one_hot)
	# Split the first VALIDATION_SIZE training examples off into a validation set.
	validation_images = train_images[:VALIDATION_SIZE]
	validation_labels = train_labels[:VALIDATION_SIZE]
	train_images = train_images[VALIDATION_SIZE:]
	train_labels = train_labels[VALIDATION_SIZE:]
	data_sets.train = DataSet(train_images, train_labels)
	data_sets.validation = DataSet(validation_images, validation_labels)
	data_sets.test = DataSet(test_images, test_labels)
	return data_sets
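A minimal sketch of how the whole pipeline could be exercised end to end, assuming network access to SOURCE_URL and a writable working directory ('MNIST_data' is only an example name):

if __name__ == '__main__':
	# Download (if needed), parse, and split the data, then draw one mini-batch.
	mnist = read_data_sets('MNIST_data', one_hot=True)
	print(mnist.train.num_examples())        # 55000 after the 5000-example validation split
	print(mnist.validation.num_examples())   # 5000
	print(mnist.test.num_examples())         # 10000
	batch_xs, batch_ys = mnist.train.next_batch(100)
	print(batch_xs.shape, batch_ys.shape)    # (100, 784) (100, 10)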
