版权声明:本文为博主原创文章,未经博主允许不得转载。
有很多机器学习的公开数据都需要手工编码读取,当然自己写代码读取是机器学习应用的基本能力,这里为了大家方便开发代码,避免重复发明轮子。
关于cifar数据集的介绍和官方下载地址,请参见其官方网站;因为官方下载速度比较慢,也可以使用CSDN提供的镜像地址下载cifar-10。
下载后将其解压,如路径为: /xxx/cifar-10-batches-py/
代码很简单,读取cifar-10的代码如下:
- import cPickle
- import numpy as np
- import os
class Cifar10DataReader():
    """Mini-batch reader for the extracted CIFAR-10 python pickle files.

    Expects ``cifar_folder`` to be the ``cifar-10-batches-py`` directory
    containing ``data_batch_1`` .. ``data_batch_5`` and ``test_batch``.
    Images are returned as numpy arrays of shape [32, 32, 3]; labels are
    either one-hot vectors of length 10 or plain ints (see ``onehot``).
    """

    def __init__(self, cifar_folder, onehot=True):
        # cifar_folder: path of the extracted cifar-10-batches-py directory.
        # onehot: when True, labels come back as length-10 one-hot vectors.
        self.cifar_folder = cifar_folder
        self.onehot = onehot
        self.data_index = 1        # which data_batch_N to read next (cycles 1..5)
        self.read_next = True      # True -> next_train_data must load a new file
        self.data_label_train = None   # list of (flat_image, label) for current file
        self.data_label_test = None    # lazily loaded test set
        self.batch_index = 0       # index of the next mini-batch inside the file

    def unpickle(self, f):
        """Load one pickled CIFAR batch file and return its dict."""
        # cPickle was merged into the stdlib pickle module in Python 3.
        # NOTE(review): the CIFAR files were pickled under Python 2, so
        # latin1 decoding is required to read them on Python 3.
        import pickle
        # 'with' guarantees the handle is closed even if load() raises;
        # the original leaked the file object on error.
        with open(f, 'rb') as fo:
            return pickle.load(fo, encoding='latin1')

    def next_train_data(self, batch_size=100):
        """Return (data, labels) for the next training mini-batch.

        data: list of batch_size arrays shaped [32, 32, 3]
        labels: one-hot vectors or ints depending on self.onehot
        Files data_batch_1..5 are cycled; each file is shuffled on load.
        """
        assert 10000 % batch_size == 0, "10000%batch_size!=0"
        if self.read_next:
            f = os.path.join(self.cifar_folder, "data_batch_%s" % (self.data_index))
            print('read: %s' % f)
            dic_train = self.unpickle(f)
            # list(...) is required: zip() is a lazy iterator on Python 3 and
            # np.random.shuffle cannot shuffle it in place.
            self.data_label_train = list(zip(dic_train['data'], dic_train['labels']))  # label 0~9
            np.random.shuffle(self.data_label_train)
            self.read_next = False
            # advance to the next batch file, wrapping 5 -> 1
            self.data_index = 1 if self.data_index == 5 else self.data_index + 1
        if self.batch_index < len(self.data_label_train) // batch_size:
            datum = self.data_label_train[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
            self.batch_index += 1
            return self._decode(datum, self.onehot)
        # current file exhausted: reset and recurse once to load the next file
        self.batch_index = 0
        self.read_next = True
        return self.next_train_data(batch_size=batch_size)

    def _decode(self, datum, onehot):
        """Convert [(flat_image, label), ...] into ([32x32x3 images], [labels])."""
        rdata, rlabel = [], []
        for d, l in datum:
            # The 3072-long vector is channel-major (3 x 1024); transposing
            # yields 1024 (r, g, b) rows, reshaped to a 32x32x3 HWC image.
            rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
            if onehot:
                hot = np.zeros(10)
                hot[int(l)] = 1
                rlabel.append(hot)
            else:
                rlabel.append(int(l))
        return rdata, rlabel

    def next_test_data(self, batch_size=100):
        """Return (data, labels) for a random test mini-batch of batch_size."""
        if self.data_label_test is None:
            f = os.path.join(self.cifar_folder, "test_batch")
            print('read: %s' % f)
            dic_test = self.unpickle(f)
            self.data_label_test = list(zip(dic_test['data'], dic_test['labels']))  # label 0~9
        # reshuffle on every call so successive batches differ
        np.random.shuffle(self.data_label_test)
        return self._decode(self.data_label_test[0:batch_size], self.onehot)
if __name__ == "__main__":
    # Smoke test: show one test image, then sweep the whole training set
    # (600 batches of 100 = 60000 samples, cycling the 5 batch files).
    dr = Cifar10DataReader(cifar_folder="/xxx/cifar-10-batches-py/")
    import matplotlib.pyplot as plt
    d, l = dr.next_test_data()
    # print() function and range() replace the Python-2-only
    # print statement and xrange of the original.
    print(np.shape(d), np.shape(l))
    plt.imshow(d[0])
    plt.show()
    for i in range(600):
        d, l = dr.next_train_data(batch_size=100)
        print(np.shape(d), np.shape(l))
cifar-100的数据读取代码如下(测试方法和cifar-10一样,就不再演示了;数据中还包含coarse_labels,即大类别标签,需要的话可以自行添加):
- import cPickle
- import numpy as np
- import os
class Cifar100DataReader():
    """Mini-batch reader for the extracted CIFAR-100 python pickle files.

    Expects ``cifar_folder`` to contain the files ``train`` and ``test``.
    Only the fine labels (0~99) are returned; the files also carry
    ``coarse_labels`` (20 super categories) and ``filenames`` if needed.
    """

    def __init__(self, cifar_folder, onehot=True):
        # cifar_folder: path of the extracted cifar-100-python directory.
        # onehot: when True, labels come back as length-100 one-hot vectors.
        self.cifar_folder = cifar_folder
        self.onehot = onehot
        self.data_label_test = None   # lazily loaded test set
        self.batch_index = 0          # index of the next training mini-batch
        f = os.path.join(self.cifar_folder, "train")
        print('read: %s' % f)
        # Bug fix: the original called a bare, undefined unpickle() here
        # (NameError at construction time); the method below provides it.
        dic_train = self.unpickle(f)
        # list(...) is required: zip() is lazy on Python 3 and
        # np.random.shuffle cannot shuffle it in place.
        self.data_label_train = list(zip(dic_train['data'], dic_train['fine_labels']))  # label 0~99
        np.random.shuffle(self.data_label_train)

    def unpickle(self, f):
        """Load one pickled CIFAR file and return its dict."""
        # cPickle was merged into pickle in Python 3; latin1 decoding is
        # required because the files were pickled under Python 2.
        import pickle
        with open(f, 'rb') as fo:   # closes the handle even if load() raises
            return pickle.load(fo, encoding='latin1')

    def next_train_data(self, batch_size=100):
        """Return (data, labels) for the next training mini-batch.

        cifar100 data content:
            {
                "coarse_labels": [0,...,19],   # 0~19 super category
                "filenames": ["volcano_s_000012.png", ...],
                "batch_label": "",
                "fine_labels": [0,1...99]      # 0~99 category
            }
        Returns a pair of lists of numpy arrays with the given batch_size.
        """
        if self.batch_index >= len(self.data_label_train) // batch_size:
            # epoch finished: reshuffle and start over (the original
            # duplicated the slice/decode code in both branches)
            self.batch_index = 0
            np.random.shuffle(self.data_label_train)
        datum = self.data_label_train[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
        self.batch_index += 1
        return self._decode(datum, self.onehot)

    def _decode(self, datum, onehot):
        """Convert [(flat_image, label), ...] into ([32x32x3 images], [labels])."""
        rdata, rlabel = [], []
        for d, l in datum:
            # 3072-long channel-major vector -> 1024 (r, g, b) rows -> 32x32x3
            rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
            if onehot:
                hot = np.zeros(100)
                hot[int(l)] = 1
                rlabel.append(hot)
            else:
                rlabel.append(int(l))
        return rdata, rlabel

    def next_test_data(self, batch_size=100):
        """Return (data, labels) for a random test mini-batch of batch_size."""
        if self.data_label_test is None:
            f = os.path.join(self.cifar_folder, "test")
            print('read: %s' % f)
            dic_test = self.unpickle(f)   # was also a bare undefined unpickle()
            labels = dic_test['fine_labels']  # 0~99
            self.data_label_test = list(zip(dic_test['data'], labels))
        np.random.shuffle(self.data_label_test)
        return self._decode(self.data_label_test[0:batch_size], self.onehot)