我们之前用到的用于训练网络的数据集大部份为常用的经典数据集,可以 通过 TensorFlow 几行代码即可完成数据集的下载、加载以及预处理工作,这无疑大大的提升了算 法的研究效率,对于刚入门的新手来说比较友好。然而在实际应用中,针对于不同的应用场景,算法的数据集也各不相同。因此我们就需要自定义数据集,来完场网络的训练。下面我们就要对水泥裂缝图片进行数据集的制作:
- 第一步、收集裂缝的图片
笔者收集到了五类裂缝图片,分别放在五个文件夹当中,如下图所示:
这五张图片分别取自五类照片集中。
- 自定义数据集加载流
第一步、创建编码表
样本的类别一般以字符串类型的类别名区分,但是对于神经网络来说,首先需要将类 别进行数字编码,转换成 one-hot 编码。考虑𝑁类的数据集,我们将每 个类别随机编码为𝑙 ∈ [0,𝑁 − 1]的数字,类别名与数字的映射关系一旦创建,一 般不能变动。
实战如下,首先按序遍历 liefeng 根目录下的所有子目录,对每个子目标,利用类别 名作为编码表的 key,编码表的长度作为类别的标签数字保存进name_ label字典对象。
def load_liefeng(root, mode='train'):
# 创建数字编码表
name_label = {
}
# 遍历根目录下的子文件夹,并排序,保证映射关系固定
for name in sorted(os.listdir(os.path.join(root))):
# 跳过非文件夹
if not os.path.isdir(os.path.join(root, name)):
continue
# 给每个类别编码一个数字
name_label[name] = len(name_label.keys())
...
第二步、创建路径-标签表格
编码表确定后,我们需要根据实际数据祈祷存储方式获得每个样本的存储路径以及他 的标签数字,分别表示为 images 和 labels 2 个 List 对象。其中 images List 存储了每个样本 的路径字符串,labels 存储了样本的类别数字,两者长度一致。 我们将 images 和 labels 信息存储在 csv 格式的文件中,其中 csv 文件格式是一种以逗 号符号分隔数据的纯文本文件格式,可以使用记事本或者 MS Excel 软件打开。通过将所有 样本信息存储在一个 csv 文件中有诸多好处,比如可以直接进行数据集的划分,可以随机 采样 batch 等等。csv 文件中可以保存数据集所有样本的信息,也可以根据 train-val-test 分 别创建 3 个 csv 文件。最终产生的 csv 文件内容如下图所示,每一行的第一个元素保存 了当前样本的存储路径,第二个元素保存了样本的类别数字。
csv 文件创建实现如下,遍历 liefeng根目录下的所有图片,保存图片的路径,并根 据编码表获得其编码数字,保存到 csv 文件中:
def load_csv(root, filename, name_label):
# 从 csv 文件返回 images,labels 列表
# root:数据集根目录,filename:csv 文件名, name_label:类别名编码表
if not os.path.exists(os.path.join(root, filename)):
# 如果 csv 文件不存在,则创建
images = []
for name in name_label.keys(): # 遍历所有子目录,获得所有的图片
images += glob.glob(os.path.join(root, name, '*.jpg'))
print(len(images), images)
random.shuffle(images) # 随机打散顺序
# 创建 csv 文件,并存储图片路径及其 label 信息
with open(os.path.join(root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images:
name = img.split(os.sep)[-2]
label = name_label[name]
writer.writerow([img, label])
print('written into csv file:', filename)
...
创建完 csv 文件后,下一次只需要从 csv 文件中读取样本路径和标签信息即可:
def load_csv(root, filename, name2label):
…
# 此时已经有 csv 文件,直接读取
images, labels = [], []
with open(os.path.join(root, filename)) as f:
reader = csv.reader(f)
for row in reader:
img, label = row
label = int(label)
images.append(img)
labels.append(label)
# 返回图片路径 list 和标签 list