前言: 接续到前面所做工作,在从网上爬取到适量所需图像后,扩充数据集,进行数据增强,然后将所有获得的数据,保存为TFRecord文件,即构建出自己的数据集。
说明: 计划整个过程基于tensorflow实现,不使用更高层的keras工具。因此数据增强也是使用numpy、opencv、tensorflow或其他一些图像处理工具实现,而没有使用keras中的数据增强函数(虽然会很方便)。
step2: 数据预处理
数据增强的方式和实现总结见: 深度学习之数据增强(数据集扩充)方式和实现总结
代码实现:
import cv2
import tensorflow as tf
#import numpy
import os
import random
base_dir = '/home/zhangwei/workfiles/deeplearning/dogVScat/'
input_dir = '/home/zhangwei/workfiles/deeplearning/dogVScat/data/'
aug_dir = '/home/zhangwei/workfiles/deeplearning/dogVScat/aug_data/'
info_name = 'Info.txt'
record_name = 'dogVScat.tfrecord'
'''
补充下Python常使用到的操作
os.path.exit(file_path) # 判断是否存在该文件
op.path.listdir() # 获取路径名
os.path.dirname(input_dir) # 罗列出基础路径
os.path.join() # 路径合并
file_path.split() # 分离路径
'''
# 0.定义数据增强方案,采用简单随机的水平/垂直翻转,转置,缩放,移位等操作
def augmentation(img):
pass
# 1.对数据集进行扩充,扩充后的图像保存在.../dogVScat/train_data/下。还包含一些预处理:乱序,标签信息统计,重命名等。
def dataAugmentation(input_dir, aug_dir):
filename = os.listdir(input_dir) # 获取路径下所有文件名
for i,name in enumerate(filename):
img = cv2.imread( input_dir + name)
# 具体的增强操作方案(本例中只简单进行概率性转置+缩放)
height,width = img.shape[:2]
_img = cv2.resize(img, (int(0.5*width), int(0.5*height)),interpolation = cv2.INTER_NEAREST)
if random.randint(0,1) == 1:
_img = cv2.flip(_img, -1)
label = name[:3]
cv2.imwrite(aug_dir + 'aug_' + name, _img)
print('数据增强处理进度:%d / %d \n'%(i, len(filename)))
# 统计至Info.txt文件中
f = open(base_dir + info_name, 'w+')
filename.extend(os.listdir(aug_dir))
random.shuffle(filename)
for i,name in enumerate(filename):
label = 0 if name.split('_')[-2] == 'cat' else 1
imgDir = aug_dir if name.split('_')[0] == 'aug' else input_dir
string = '%s %d\n'%(imgDir + name, label)
f.write(string)
f.close()
# 3.保存为TFRecord文件
## 定义所需整型属性:可以不用定义,下面直接写也可。以下定义同理。
def _int64_feature(value):
return tf.train.Feature(int64_list = tf.train.Int64List(value = value))
## 定义所需字符串属性
def _bytes_feature(value):
return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
## 定义所需实数型属性
def _float_featurn(value):
return tf.train.Feature(float_list = tf.train.FloatList(value = value))
## 生成TFRecord文件
def createTFRecord():
image_list = []
label_list = []
with open(base_dir + info_name, 'r') as f:
for line in f.readlines():
image_list.append(line.rstrip().split(' ')[0])
label_list.append(line.rstrip().split(' ')[1])
writer = tf.io.TFRecordWriter(base_dir + record_name)
for i,[image, label] in enumerate(zip(image_list, label_list)):
img = cv2.imread(image)
img_data = img.tostring()
name = bytes('cat' if label==0 else 'dog',encoding = 'utf-8')
shape = img.shape
print('TFRecord文件生成处理进度:%d / %d'%(i, len(image_list)))
example = tf.train.Example(
features = tf.train.Features(
feature = {
'name': _bytes_feature(name),
'shape': _int64_feature(list(img.shape)),
'data': _int64_feature(list(img_data))
}))
writer.write(example.SerializeToString())
writer.close
if __name__ == '__main__':
print('程序开始!')
dataAugmentation(input_dir, aug_dir)
createTFRecord()
print('完成!')
说明:
- 代码书写很不规范,可加入多线程处理。