主函数:#测试集和训练集是分开整的
import random
import pickle
import numpy as np
from pickled import *
import os
def load_data_cifar(filename, mode='cifar10'):
    """Load image data, labels and file names from one CIFAR batch file.

    The batch is a pickled dict. Original CIFAR files, read with
    ``encoding='bytes'``, expose bytes keys:

    cifar10 keys(): dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
    cifar100 keys(): dict_keys([b'filenames', b'batch_label', b'fine_labels', b'coarse_labels', b'data'])

    Batches whose keys are plain ``str`` (e.g. ``'data'``, ``'labels'``) are
    accepted as well: each field is looked up by its bytes key first and by
    its str key otherwise.

    Args:
        filename: path of the pickled batch file.
        mode: ``'cifar10'`` (labels read from ``labels``) or ``'cifar100'``
            (labels read from ``fine_labels``).

    Returns:
        A ``(data, labels, img_names)`` tuple, or ``(None, None, None)``
        after printing a message when ``mode`` is not recognised.

    Raises:
        KeyError: if a required field is missing from the batch.
    """
    label_fields = {'cifar10': 'labels', 'cifar100': 'fine_labels'}

    # Validate the mode first so an invalid request never touches the disk.
    if mode not in label_fields:
        print("mode should be in ['cifar10', 'cifar100']")
        return None, None, None

    # NOTE(security): unpickling can execute arbitrary code; only load
    # batch files that come from a trusted source.
    with open(filename, 'rb') as f:
        dataset = pickle.load(f, encoding='bytes')

    def _field(name):
        # Prefer the bytes key; fall back to the str key.
        bytes_key = name.encode()
        return dataset[bytes_key] if bytes_key in dataset else dataset[name]

    return _field('data'), _field(label_fields[mode]), _field('filenames')
def load_cifar10(cifar10_path, mode='train'):
    """Load the CIFAR-10 training or test split from ``cifar10_path``.

    Args:
        cifar10_path: directory that holds the batch files.
        mode: ``'train'`` reads ``data_batch_1`` .. ``data_batch_5`` and
            concatenates them; ``'test'`` reads ``test_batch``.

    Returns:
        A ``(data, labels, img_names)`` tuple. In ``'train'`` mode ``data``
        is the row-wise stack of all five batches and the two lists are the
        concatenation of the per-batch lists. For any other ``mode`` a
        message is printed and ``(None, None, None)`` is returned, mirroring
        how ``load_data_cifar`` reports an unknown mode.
    """
    if mode == "train":
        # Seed with an empty (0, 3072) uint8 array so the stacked result has
        # the same dtype/shape rules as stacking the batches one at a time,
        # then stack once instead of re-copying the array on every batch.
        batches = [np.empty(shape=[0, 3072], dtype=np.uint8)]
        labels_all = []
        img_names_all = []
        for i in range(1, 6):
            filename = os.path.join(
                cifar10_path, 'data_batch_' + str(i)).replace('\\', '/')
            print("Loading {}".format(filename))
            data, labels, img_names = load_data_cifar(filename, mode='cifar10')
            batches.append(data)
            labels_all.extend(labels)
            img_names_all.extend(img_names)
        return np.vstack(batches), labels_all, img_names_all

    if mode == "test":
        filename = os.path.join(cifar10_path, 'test_batch').replace('\\', '/')
        print("Loading {}".format(filename))
        return load_data_cifar(filename, mode='cifar10')

    # Unknown mode: return a 3-tuple so callers can still unpack the result.
    print("mode should be in ['train', 'test']")
    return None, None, None
if __name__ == "__main__":
    # Script: randomly drop a fixed number of samples of one class from the
    # CIFAR-10 test split and save the reduced split with pickled().

    # Change this to the directory where your dataset is stored.
    cifar10_path = "../data/cifar-10-batches-py"
    save_path = 'data/cifar-10-1-batches-py/test'
    target_label = 0      # class whose samples are thinned out
    num_to_remove = 900   # how many samples of that class to drop

    # Extract image data, labels and file names. The training split can be
    # loaded the same way with:
    # data_cifar10_train, labels_cifar10_train, img_names_cifar10_train = \
    #     load_cifar10(cifar10_path, mode='train')
    data_cifar10_test, labels_cifar10_test, img_names_cifar10_test = \
        load_cifar10(cifar10_path, mode='test')

    # Positions of every sample carrying the target label.
    label0_test = [
        position
        for position, label in enumerate(labels_cifar10_test)
        if label == target_label
    ]
    print("label0_test.len(5000)", len(label0_test))

    # Pick the positions to delete; random.sample raises ValueError when
    # fewer than num_to_remove candidates exist.
    label0_500_test = sorted(random.sample(label0_test, num_to_remove))
    removed = set(label0_500_test)

    # Delete the chosen rows from the data array ...
    data_cifar10_test_del = np.delete(
        data_cifar10_test, label0_500_test, axis=0)
    # ... and filter the same positions out of the parallel lists. Filtering
    # by original position avoids shifting indices after every pop().
    labels_cifar10_test = [
        label
        for position, label in enumerate(labels_cifar10_test)
        if position not in removed
    ]
    img_names_cifar10_test = [
        name
        for position, name in enumerate(img_names_cifar10_test)
        if position not in removed
    ]
    print(len(labels_cifar10_test),
          len(img_names_cifar10_test), len(data_cifar10_test_del))

    # pickled() writes into this directory, so make sure it exists first.
    os.makedirs(save_path, exist_ok=True)
    pickled(save_path, data_cifar10_test_del,
            labels_cifar10_test, img_names_cifar10_test)
下面这个是pickled.py文件
import os
import pickle
# Default number of files the samples are split across.
BIN_COUNTS = 1


def pickled(savepath, data, label, fnames, bin_num=BIN_COUNTS, mode="test"):
    """Split a dataset into ``bin_num`` batches and pickle each to a file.

    Batch ``i`` (zero-based) is written to ``<savepath>/data_batch_<i>`` for
    both modes. Each file holds a dict with the str keys ``'data'``,
    ``'labels'``, ``'filenames'`` and ``'batch_label'``.

    Batch boundaries are computed with integer arithmetic so that the
    batches always cover every sample exactly once; float boundaries could
    round the final boundary down and silently drop trailing samples.

    Args:
        savepath (str): output directory; created if it does not exist.
        data (array): image data, a 2-D array with one row per sample
            (n x 3072 for CIFAR images).
        label (list): image labels, a list with length n.
        fnames (str list): image names, a list with length n.
        bin_num (int): number of files to split the samples across.
        mode (str): ``'train'`` labels batches "training batch ..."; any
            other value labels them "testing batch ...".

    Raises:
        AssertionError: if there are no samples to write (or ``bin_num`` is
            negative).
        ZeroDivisionError: if ``bin_num`` is 0.
    """
    total_num = len(fnames)
    samples_per_bin = total_num / bin_num  # average, kept for reporting
    print("samples_per_bin", samples_per_bin)
    assert samples_per_bin > 0

    os.makedirs(savepath, exist_ok=True)

    kind = "training" if mode == "train" else "testing"
    for i in range(bin_num):
        # Exact integer boundaries: the last batch always ends at total_num.
        start = i * total_num // bin_num
        end = (i + 1) * total_num // bin_num
        batch = {
            'data': data[start:end, :],
            'labels': label[start:end],
            'filenames': fnames[start:end],
            'batch_label': "{} batch {} of {}".format(kind, i, bin_num),
        }
        with open(os.path.join(savepath, 'data_batch_' + str(i)), 'wb') as fi:
            pickle.dump(batch, fi)