之前一直都不懂如何本地加载cifar10的数据集,联网下载太慢了
可将数据集放在当前项目的“F:/PU_prior/data/CIFAR10/” 目录下
# 本地加载CIFAR10数据集 def get_cifar10(datapath): train_filenames = ['F:/PU_prior/data/CIFAR10/data_batch_{}'.format(ii + 1) for ii in range(5)] eval_filename = 'F:/PU_prior/data/CIFAR10/test_batch' x_tr = np.zeros((50000, 32, 32, 3), dtype='uint8') y_tr = np.zeros(50000, dtype='int32') for ii, fname in enumerate(train_filenames): cur_images, cur_labels = _load_datafile(os.path.join(datapath, fname)) x_tr[ii * 10000: (ii + 1) * 10000, ...] = cur_images y_tr[ii * 10000: (ii + 1) * 10000, ...] = cur_labels x_te, y_te = _load_datafile(os.path.join(datapath, eval_filename)) return (x_tr, y_tr), (x_te, y_te) (X_train, Y_train), (X_test, Y_test) = get_cifar10('') 即可得到训练集与测试集