CIFAR 官网提供的 Python 版数据文件都是用 Python 的 cPickle 工具序列化(即 “pickled”)的,官网给出的读取例程如下:
Python 2:
def unpickle(file):
    """Load one pickled batch file and return the unpickled object.

    Python 2 variant: relies on the ``cPickle`` module.

    Args:
        file: Path of the pickled file; it is opened in binary mode.

    Returns:
        Whatever object ``cPickle.load`` reads from the file.
    """
    import cPickle

    with open(file, 'rb') as stream:
        return cPickle.load(stream)
Python 3:
def unpickle(file):
    """Load one pickled batch file and return the unpickled object.

    Python 3 variant: ``encoding='bytes'`` is passed to ``pickle.load``.

    Args:
        file: Path of the pickled file; it is opened in binary mode.

    Returns:
        Whatever object ``pickle.load`` reads from the file.
    """
    import pickle

    with open(file, 'rb') as stream:
        return pickle.load(stream, encoding='bytes')
这里给出Python3的示例代码:
import os
import pickle
import numpy as np
import sklearn
import sklearn.linear_model
import lmdb
import caffe
def unpickle(file):
    """Load one pickled batch file and return the unpickled object.

    The file is opened in a ``with`` block so the handle is released even
    when ``pickle.load`` raises.

    Args:
        file: Path of the pickled file; it is opened in binary mode.

    Returns:
        The object stored in the file, loaded with ``encoding='bytes'``.
    """
    # NOTE: unpickling can execute arbitrary code; only load files that come
    # from a trusted source.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')
def shuffle_data(data, labels):
    """Shuffle samples and labels with the same fixed-seed permutation.

    The previous implementation called
    ``sklearn.cross_validation.train_test_split``; that module is not
    imported by this file and no longer exists in current scikit-learn
    releases, so the shuffle is done with NumPy instead.

    Args:
        data: Array-like of samples, indexed along its first axis.
        labels: Array-like of labels with the same length as ``data``.

    Returns:
        A ``(data, labels)`` tuple of NumPy arrays in shuffled order; row
        ``i`` of the data still corresponds to label ``i``.

    Raises:
        ValueError: If ``data`` and ``labels`` differ in length.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)
    if len(data) != len(labels):
        raise ValueError(
            'data and labels must have the same length, got {} and {}'.format(
                len(data), len(labels)
            )
        )
    # Fixed seed (42, as before) keeps the shuffle reproducible across runs.
    order = np.random.RandomState(42).permutation(len(data))
    return data[order], labels[order]
def load_data(train_file):
    """Read one pickled batch file and return shuffled images and fine labels.

    Args:
        train_file: Path of the pickled batch file passed to ``unpickle``.

    Returns:
        A ``(images, labels)`` tuple: the shuffled data reshaped to
        ``(N, 3, 32, 32)`` and the shuffled fine labels, where ``N`` is the
        number of fine labels in the file.
    """
    batch = unpickle(train_file)
    # Keys seen in the loaded dict: b'batch_label', b'filenames', b'data',
    # b'coarse_labels', b'fine_labels'. Unlike Python 2, every key is a bytes
    # object (note the leading b).
    fine = batch[b'fine_labels']
    n_samples = len(fine)
    images, labels = shuffle_data(batch[b'data'], np.array(fine))
    return images.reshape(n_samples, 3, 32, 32), labels
if __name__ == '__main__':

    def _write_lmdb(lmdb_path, images, labels, map_size, batch_size=1000):
        """Write (image, label) pairs into an LMDB database as Caffe datums.

        Records are keyed by their zero-padded 8-digit index, ASCII-encoded.
        The write transaction is committed after every ``batch_size``
        records and once more at the end; the environment is closed even if
        writing fails part-way.

        Args:
            lmdb_path: Directory of the LMDB environment to create/open.
            images: Indexable sequence of image arrays for ``array_to_datum``.
            labels: Sequence of integer labels, parallel to ``images``.
            map_size: Maximum database size in bytes passed to ``lmdb.open``.
            batch_size: Number of records written per committed transaction.
        """
        env = lmdb.open(lmdb_path, map_size=map_size)
        try:
            txn = env.begin(write=True)
            for count, (image, label) in enumerate(zip(images, labels), start=1):
                # int() turns a NumPy integer into a plain Python int, which
                # every protobuf version accepts for the label field.
                datum = caffe.io.array_to_datum(image, int(label))
                str_id = '{:08}'.format(count - 1)
                # On Python 3 the key must be bytes, hence .encode('ascii').
                txn.put(str_id.encode('ascii'), datum.SerializeToString())
                if count % batch_size == 0:
                    print('already handled with {} pictures'.format(count))
                    txn.commit()
                    txn = env.begin(write=True)
            txn.commit()
        finally:
            env.close()

    # Path of the extracted cifar-100-python directory.
    cifar_python_directory = os.path.abspath(r'F:\Software_download\ChromeDownload\cifar-100-python.tar\cifar-100-python')
    print('Converting...')
    cifar_caffe_directory = os.path.abspath('cifar100_train_lmdb')
    # The existence of the train database directory is used as the
    # "already converted" marker for both databases.
    if not os.path.exists(cifar_caffe_directory):
        X, y_f = load_data(os.path.join(cifar_python_directory, 'train'))
        Xt, yt_f = load_data(os.path.join(cifar_python_directory, 'test'))
        print('Data is fully loaded,now truly convertung.')
        # Write both splits into their LMDB databases.
        _write_lmdb(cifar_caffe_directory, X, y_f, map_size=50000 * 1000 * 5)
        _write_lmdb('cifar100_test_lmdb', Xt, yt_f, map_size=10000 * 1000 * 5)
    else:
        print('Conversion was already done. ')
—————————————————————————————
参考博客:http://blog.csdn.net/u010165147/article/details/54176612