Python 读取 MNIST 图像数据

本文介绍了三种不同的MNIST数据集加载方式:使用numpy和struct直接读取原始二进制文件;通过MnistReader类方便地进行数据集的批量加载,并支持数据维度调整和标签的one-hot编码;利用gzip和struct读取gz压缩文件。
摘要由CSDN通过智能技术生成

LeCun MNIST 数据集下载

import numpy as np
import struct

def _read_idx_payload(path, header_fmt):
    """Read an IDX file and return ``(header, payload)``.

    path: location of the uncompressed IDX file.
    header_fmt: ``struct`` format of the header, e.g. ``'>IIII'`` for image
        files or ``'>II'`` for label files.

    Returns the unpacked header tuple and a flat integer array holding one
    entry per payload byte. The payload length is the product of every
    header field after the first one (the magic number).

    Raises IOError/OSError if the file cannot be opened, ``struct.error`` if
    the header is truncated and ``ValueError`` if the payload is shorter than
    the header announces.
    """
    # ``with`` closes the handle even when parsing below raises.
    with open(path, 'rb') as binfile:
        buffers = binfile.read()

    head = struct.unpack_from(header_fmt, buffers, 0)
    offset = struct.calcsize(header_fmt)

    count = 1
    for dim in head[1:]:
        count *= dim

    # frombuffer decodes the bytes in one step instead of building a tuple
    # with millions of Python ints through struct.
    payload = np.frombuffer(buffers, dtype=np.uint8, count=count,
                            offset=offset)
    # Widen to the default integer dtype: that is what reshaping a tuple of
    # Python ints produced before, and it keeps later arithmetic from
    # wrapping around at 255.
    return head, payload.astype(int)


def loadImageSet(which=0, data_dir="..//dataset//"):
    """Load an MNIST image file into an array of shape [N, rows, cols].

    which: 0 loads the training images, any other value the t10k test images.
    data_dir: directory prefix (including the trailing separator) that is
        prepended to the file name.

    Returns an integer numpy array with one image per leading index.
    """
    print("load image set")
    if which == 0:
        name = "train-images-idx3-ubyte"
    else:
        name = "t10k-images-idx3-ubyte"

    head, pixels = _read_idx_payload(data_dir + name, '>IIII')
    print("head, %s" % (head,))

    # Header layout: magic, image count, rows, columns.
    imgNum, rows, cols = head[1:]
    imgs = pixels.reshape([imgNum, rows, cols])
    print("load imgs finished")
    return imgs


def loadLabelSet(which=0, data_dir="..//dataset//"):
    """Load an MNIST label file into an array of shape [N, 1].

    which: 0 loads the training labels, any other value the t10k test labels.
    data_dir: directory prefix (including the trailing separator) that is
        prepended to the file name.

    Returns an integer numpy array holding one label per row.
    """
    print("load label set")
    if which == 0:
        name = "train-labels-idx1-ubyte"
    else:
        name = "t10k-labels-idx1-ubyte"

    head, values = _read_idx_payload(data_dir + name, '>II')
    print("head, %s" % (head,))

    # Header layout: magic, label count.
    imgNum = head[1]
    labels = values.reshape([imgNum, 1])
    print('load label finished')
    return labels


if __name__ == "__main__":
    imgs = loadImageSet()
    # Optional visual check with a project-local helper:
    # import PlotUtil as pu
    # pu.showImgMatrix(imgs[0])
    loadLabelSet()

以及一个方便训练时按批次读取数据的 reader

import numpy as np
import struct
import gzip
import sys

try:
    import cPickle  # Python 2
except ImportError:
    # Python 3 merged cPickle into pickle; keep the old module name bound.
    import pickle as cPickle


class MnistReader():
    """Batch reader for the pickled MNIST archive (``mnist.pkl.gz``).

    The archive is expected to unpickle to three ``(images, labels)`` pairs:
    train, validation and test.
    """

    def __init__(self, mnist_path, data_dim=1, one_hot=True):
        '''
        mnist_path: the path of mnist.pkl.gz
        data_dim=1 keeps each image as stored in the archive ([N,784])
        data_dim=3 reshapes each image to [28,28,1]
        one_hot: if true, labels are one-hot encoded
                 (like: [0,1,0,0,0,0,0,0,0,0])
        '''
        self.mnist_path = mnist_path
        self.data_dim = data_dim
        self.one_hot = one_hot
        self.load_minist(mnist_path)
        # Materialise the pairs: a bare zip() is a one-shot iterator on
        # Python 3 and cannot be measured, sliced or shuffled.
        self.train_datalabel = list(zip(self.train_x, self.train_y))
        self.valid_datalabel = list(zip(self.valid_x, self.valid_y))
        # Index of the next training batch within the current epoch.
        self.batch_offset_train = 0

    def _make_batch(self, pairs):
        """Convert (image, label) pairs into the two lists handed to callers."""
        imgs = []
        labels = []
        for d, l in pairs:
            if self.data_dim == 3:
                d = np.reshape(d, [28, 28, 1])
            imgs.append(d)
            if self.one_hot:
                a = np.zeros(10)
                a[l] = 1
                # Append the encoded vector, not the raw label.
                labels.append(a)
            else:
                labels.append(l)
        return imgs, labels

    def next_batch_train(self, batch_size):
        '''
        return list of images with shape [N,784] or [N,28,28,1]
        depending on self.data_dim, and list of labels with shape [N] or
        [N,10] depending on self.one_hot

        Batches within an epoch do not overlap. Once every full batch has
        been served, the training pairs are shuffled and a new epoch starts.

        raises ValueError if batch_size is not in [1, number of samples]
        '''
        total = len(self.train_datalabel)
        if not 1 <= batch_size <= total:
            # Without this guard no full batch ever exists and the reader
            # would reshuffle forever.
            raise ValueError(
                "batch_size must be between 1 and %d, got %r"
                % (total, batch_size))

        if self.batch_offset_train >= total // batch_size:
            # Epoch finished: start over on a fresh permutation.
            self.batch_offset_train = 0
            np.random.shuffle(self.train_datalabel)

        # batch_offset_train counts batches, so scale it to a sample index.
        start = self.batch_offset_train * batch_size
        self.batch_offset_train += 1
        return self._make_batch(
            self.train_datalabel[start:start + batch_size])

    def next_batch_val(self, batch_size):
        '''
        return list of images with shape [N,784] or [N,28,28,1]
        depending on self.data_dim, and list of labels with shape [N] or
        [N,10] depending on self.one_hot

        The validation pairs are reshuffled on every call and the first
        batch_size of them are returned.
        '''
        np.random.shuffle(self.valid_datalabel)
        return self._make_batch(self.valid_datalabel[0:batch_size])

    def load_minist(self, dataset):
        """Unpickle the gzip archive at ``dataset`` into train/valid/test."""
        print("load dataset")
        # NOTE(security): unpickling can execute arbitrary code. Only load
        # archives obtained from a trusted source.
        with gzip.open(dataset, 'rb') as f:
            if sys.version_info[0] >= 3:
                # latin1 lets Python 3 read archives pickled by Python 2.
                train_set, valid_set, test_set = cPickle.load(
                    f, encoding='latin1')
            else:
                train_set, valid_set, test_set = cPickle.load(f)

        self.train_x, self.train_y = train_set
        self.valid_x, self.valid_y = valid_set
        self.test_x, self.test_y = test_set
        print("train image,label shape: %s %s"
              % (self.train_x.shape, self.train_y.shape))
        print("valid image,label shape: %s %s"
              % (self.valid_x.shape, self.valid_y.shape))
        print("test image,label shape: %s %s"
              % (self.test_x.shape, self.test_y.shape))
        print("load dataset end")


if __name__ == "__main__":
    mnist = MnistReader('../dataset/mnist.pkl.gz', data_dim=3)
    data, label = mnist.next_batch_train(batch_size=1)
    print(data)
    print(label)

第三种加载方式:使用 gzip 和 struct 直接读取 .gz 压缩文件(同样需要 numpy)

import gzip, struct

def _read(image,label): minist_dir = 'your_dir/' with gzip.open(minist_dir+label) as flbl: magic, num = struct.unpack(">II", flbl.read(8)) label = np.fromstring(flbl.read(), dtype=np.int8) with gzip.open(minist_dir+image, 'rb') as fimg: magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16)) image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols) return image,label def get_data(): train_img,train_label = _read( 'train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz') test_img,test_label = _read( 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz') return [train_img,train_label,test_img,test_label]

转载于:https://www.cnblogs.com/judejie/p/9143974.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值