前言:我们在跑神经网络时,通常使用的都是别人已经整理好的数据集,如MNIST、CIFAR10、CIFAR100等,但是在实际的应用中,往往需要根据实际的问题创建对应的数据集,这就需要用自己的图片创建一个类似CIFAR10的数据集。如果我们创建的数据集格式和CIFAR10的格式一样,那么所创建的数据集就能很容易地输入到原有的神经网络中,而无需大幅改动网络结构。
1、首先查看CIFAR10数据集是长什么样子的?
CIFAR10经过解压后会得到cifar-10-batches-py文件夹,如下图所示:
![553752b054f1071a0e511930731b1530.png](https://img-blog.csdnimg.cn/img_convert/553752b054f1071a0e511930731b1530.png)
cifar10解压后的数据
通过dataset2Image.py文件可以将以上的数据batch解压成图片,其中读取batch文件内容的代码如下:
import pickle


def unpickle(file):
    """Load and return the object stored in a pickled CIFAR batch file.

    The official CIFAR-10 python batches were pickled under Python 2;
    passing ``encoding='iso-8859-1'`` (latin-1) lets Python 3 decode the
    Python 2 byte strings and NumPy arrays inside them.

    Warning: unpickling can execute arbitrary code, so only load batch
    files from a trusted source.

    Args:
        file: path of the batch file (e.g. ``'data_batch_1'``).

    Returns:
        Whatever object the file contains (a dict for CIFAR batches).
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='iso-8859-1')


if __name__ == '__main__':
    print(unpickle('data_batch_1'))
输出的结果如下:
![5e3cd31b30249bf23e68ab6a45ff2303.png](https://img-blog.csdnimg.cn/img_convert/5e3cd31b30249bf23e68ab6a45ff2303.png)
读取data_batch_1的数据
因此,我们只需要生成类似的数据格式即可。
2、工程文件结构
![1568a7814dda56112ef337f941ad5412.png](https://img-blog.csdnimg.cn/img_convert/1568a7814dda56112ef337f941ad5412.png)
工程中的py文件:其中image2dataset.py可直接生成数据集,其它文件可以忽略。
保存数据的文件夹在data目录下:
1、batch_save_train: 用于保存训练集的batch;
2、batch_save_val: 用于保存验证集或者测试集的batch;
3、figure_name_label_train: 用于保存生成的训练集图片名和标签的txt文件;
4、figure_name_label_val: 用于保存生成的验证集或者测试集图片名和标签的txt文件;
5、train: 训练集数据,里面包含了已分好类的数据;
image2dataset.py (代码用电脑端看比较好)
# -*- coding: UTF-8 -*-
"""Convert a folder of images into a CIFAR-10-style pickled batch.

Pipeline (see the ``__main__`` section at the bottom):
1. ``imagelist`` writes one "<filename> <label>" line per file found in
   ``figure_path`` to the text file ``figure_name_label``.
2. ``read_data`` loads the listed images, resizes them and flattens each one
   into a row laid out channel by channel (all first-channel values, then
   the second, then the third), like CIFAR-10.
3. ``pickled`` dumps the rows, labels and file names as a dict with the
   keys 'data', 'labels', 'filenames' and 'batch_label'.
"""
import os
import pickle

import cv2
import numpy as np

# Row length / channel length for the default 32x32 RGB images.
# read_data derives the actual sizes from the requested shape.
DATA_LEN = 3072  # 32x32x3 = 3072
CHANNEL_LEN = 1024  # 32x32 = 1024
SHAPE = (32, 32)  # (width, height) for cv2.resize, e.g. (160, 90)

# Edit these paths for each class folder (and the label used by imagelist()).
figure_path = './data/train/airbus'  # folder holding the images of one class
figure_name_label = './data/figure_name_label_train/image_train_airbus_list.txt'  # "<name> <label>" list file
batch_save = './data/batch_save_train'  # directory the batch files are written to

BIN_COUNTS = 1  # each class is saved as a single batch


def imread(im_path, shape=None, color="RGB", mode=cv2.IMREAD_UNCHANGED):
    """Read an image with OpenCV, optionally converting to RGB and resizing.

    Args:
        im_path: path of the image file.
        shape: ``(width, height)`` passed to ``cv2.resize``; ``None`` keeps
            the original size.
        color: ``"RGB"`` converts OpenCV's BGR channel order to RGB; any
            other value leaves the channels as loaded.
        mode: flag passed to ``cv2.imread``.

    Returns:
        The image as a numpy array.

    Raises:
        OSError: if OpenCV cannot read the file.
    """
    im = cv2.imread(im_path, mode)
    if im is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError("Cannot read image: {}".format(im_path))
    if color == "RGB":
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    if shape is not None:
        im = cv2.resize(im, shape)
    return im


def read_data(filename, data_path, shape=None, color='RGB'):
    """Load the images listed in ``filename`` into a CIFAR-10 style array.

    Args:
        filename: text file with one "<image_name> <label>" entry per line.
        data_path: folder containing the listed images.
        shape: ``(width, height)`` every image is resized to; defaults to
            the module-level ``SHAPE``.
        color: channel order passed to ``imread`` (``'RGB'`` by default).

    Returns:
        ``(data, label, lst)``:
        - ``data``: a ``uint8`` array of shape ``(n, 3 * width * height)``.
          Each row holds the first channel's pixels, then the second's,
          then the third's.
        - ``label``: a list of ``n`` int labels.
        - ``lst``: the list of ``n`` image names.

    Raises:
        FileNotFoundError: if ``filename`` is not an existing file.
        OSError: if one of the listed images cannot be read.
    """
    if shape is None:
        shape = SHAPE
    if not os.path.isfile(filename):
        raise FileNotFoundError("Can't find data file: {}".format(filename))
    with open(filename, encoding='utf-8') as f:
        # Skip blank lines so stray empty rows do not break parsing.
        lines = [ln.rstrip() for ln in f.read().splitlines() if ln.strip()]

    c = shape[0] * shape[1]  # number of values per colour channel
    data = np.zeros((len(lines), 3 * c), dtype=np.uint8)
    label = []
    lst = []
    for idx, ln in enumerate(lines):
        # Split on the last space only, so file names containing spaces survive.
        fname, lab = ln.rsplit(' ', 1)
        # IMREAD_COLOR always yields 3-channel 8-bit images, so grayscale,
        # alpha-channel or 16-bit files still fit the uint8 3-channel row.
        im = imread(os.path.join(data_path, fname), shape=shape,
                    color=color, mode=cv2.IMREAD_COLOR)
        # (h, w, 3) -> (3, h, w) -> flat: channel-major, row-major pixels.
        data[idx] = np.transpose(im, (2, 0, 1)).reshape(-1)
        lst.append(fname)
        label.append(int(lab))
    return data, label, lst


def py2bin(data, label, save_path=None):
    """Write images and labels in the CIFAR-10 *binary* record format.

    Each record is one label byte followed by the image's row from ``data``.

    Args:
        data: array with one flattened image per row (``uint8`` from
            ``read_data``).
        label: sequence of integer labels; each must fit in one byte (0-255).
        save_path: output file; defaults to ``batch_save``. If it names an
            existing directory, the records go to ``data_batch.bin`` inside it.

    Raises:
        ValueError: if a label does not fit in one byte.
    """
    if save_path is None:
        save_path = batch_save
    if os.path.isdir(save_path):
        # batch_save is a directory; opening it for writing would fail.
        save_path = os.path.join(save_path, 'data_batch.bin')
    label_arr = np.asarray(label).reshape(len(label), 1)
    if label_arr.size and (label_arr.min() < 0 or label_arr.max() > 255):
        # astype(np.uint8) would silently wrap such values.
        raise ValueError("Labels must be in the range 0-255 for the binary format")
    arr = np.hstack((label_arr.astype(np.uint8), data))
    with open(save_path, 'wb') as f:
        # One bulk write instead of one write call per element.
        f.write(arr.tobytes())


def pickled(savepath, data, label, fnames, bin_num=BIN_COUNTS, mode="train", name=None):
    """Save samples as CIFAR-10 style pickled batch file(s).

    Each file holds a dict with the keys 'data', 'labels', 'filenames' and
    'batch_label'.

    Args:
        savepath: existing directory the batch files are written to.
        data: image array with one flattened image per row.
        label: list of ``n`` labels.
        fnames: list of ``n`` image names.
        bin_num: number of files to split the samples into. The last file
            also receives the remainder when ``n`` is not divisible by it.
        mode: ``'train'`` labels the batch "training batch ..."; any other
            value labels it "testing batch ...".
        name: tag used in the file name ``data_batch_<name>`` and in the
            batch label.
            - With several bins, the 1-based bin index is appended
              (``data_batch_<name>_<i>``).
            - Without a name, the index alone is used (``data_batch_<i>``).

    Raises:
        NotADirectoryError: if ``savepath`` is not a directory.
        ValueError: if ``bin_num`` < 1 or there are fewer samples than bins.
    """
    if not os.path.isdir(savepath):
        raise NotADirectoryError(savepath)
    if bin_num < 1:
        raise ValueError("bin_num must be >= 1, got {}".format(bin_num))
    total_num = len(fnames)
    samples_per_bin = total_num // bin_num
    if samples_per_bin <= 0:
        raise ValueError("{} samples cannot fill {} bins".format(total_num, bin_num))

    kind = "training" if mode == "train" else "testing"
    for i in range(bin_num):
        start = i * samples_per_bin
        # The last bin takes everything that is left so no sample is dropped.
        end = total_num if i == bin_num - 1 else start + samples_per_bin
        # Distinct tags per bin so later bins do not overwrite earlier files.
        if name is None:
            tag = str(i + 1)
        elif bin_num == 1:
            tag = str(name)
        else:
            tag = "{}_{}".format(name, i + 1)
        batch = {'data': data[start:end, :],
                 'labels': label[start:end],
                 'filenames': fnames[start:end],
                 'batch_label': "{} batch {}".format(kind, tag)}
        with open(os.path.join(savepath, 'data_batch_' + tag), 'wb') as fi:
            pickle.dump(batch, fi)


def imagelist(directory=None, list_file=None, label=0):
    """Write one "<filename> <label>" line for every file in an image folder.

    Args:
        directory: folder to list; defaults to ``figure_path``.
        list_file: output text file; defaults to ``figure_name_label``. It
            is overwritten, so re-running the script does not duplicate
            entries.
        label: class label written for every image (default 0).
    """
    directory = figure_path if directory is None else directory
    list_file = figure_name_label if list_file is None else list_file
    # Sorted for a deterministic order; sub-directories are skipped.
    names = sorted(n for n in os.listdir(directory)
                   if os.path.isfile(os.path.join(directory, n)))
    with open(list_file, "w", encoding='utf-8') as f:
        # Each entry needs its own line for read_data to parse it.
        f.writelines("{} {}\n".format(n, label) for n in names)


if __name__ == '__main__':
    data_path = figure_path
    file_list = figure_name_label
    save_path = batch_save
    imagelist()  # build the "<image name> <label>" list file
    # Convert the images into a pixel matrix plus label and name lists.
    data, label, lst = read_data(file_list, data_path, shape=SHAPE)
    # py2bin(data, label)  # optional: CIFAR-10 binary version
    pickled(save_path, data, label, lst, bin_num=1, name='airbus')  # python (pickle) version
上面的代码在网页中格式显示不佳,下面附上代码截图:
![1299956e30493a01eceefa2f9d5efa20.png](https://img-blog.csdnimg.cn/img_convert/1299956e30493a01eceefa2f9d5efa20.png)
![e1a3798e937663a96e1aec179313d85a.png](https://img-blog.csdnimg.cn/img_convert/e1a3798e937663a96e1aec179313d85a.png)
![b8aa0e44b4b53050b3d36b4233da5ef6.png](https://img-blog.csdnimg.cn/img_convert/b8aa0e44b4b53050b3d36b4233da5ef6.png)
![0d1e9234beff4e23f2a3dd026755b780.png](https://img-blog.csdnimg.cn/img_convert/0d1e9234beff4e23f2a3dd026755b780.png)
![c38059fe6c728223c60593d44fa5b570.png](https://img-blog.csdnimg.cn/img_convert/c38059fe6c728223c60593d44fa5b570.png)
![3bef545d13c7354965fcc845ad0f309a.png](https://img-blog.csdnimg.cn/img_convert/3bef545d13c7354965fcc845ad0f309a.png)
![89e97b72e1b59fb32580db79d3faa8ab.png](https://img-blog.csdnimg.cn/img_convert/89e97b72e1b59fb32580db79d3faa8ab.png)
生成了data_batch_airbus文件
![d691058cc6622d8e3e0cdbe2db85e666.png](https://img-blog.csdnimg.cn/img_convert/d691058cc6622d8e3e0cdbe2db85e666.png)
对data_batch_airbus文件进行提取,如下所示:
![872653bdcd9480d1a56e6b419813c099.png](https://img-blog.csdnimg.cn/img_convert/872653bdcd9480d1a56e6b419813c099.png)
可以得到和之前cifar10解压的数据一样的格式。这种做法中每个batch只包含一类图片,因此在训练时需要对训练数据进行随机打乱。(具体代码可以到我的主页(我的博客)查看。)文字可能难以完全表达清楚,需要的同学可以结合代码自行理解,并不复杂。有需要的同学欢迎收藏、转发,有疑问可以留言@我,有时间会尽量解答。