一、程序结构展示
下图是项目中datasets文件夹下面的结构,今天主要讲解的是mvtec.py文件
二、 代码作用
这段代码实现了一个用于加载 MVTEC Anomaly Detection 数据集的函数 load_mvtec
,它能够根据指定的类别、k-shot 值和实验索引加载对应的图像数据和真值,并根据训练索引选定特定的训练数据集,以便进行有控制的模型训练。
在这里单独介绍一下“k-shot”值,“k-shot”是指在机器学习和元学习领域中的一个概念。它通常用于描述一个学习任务中每个类别所包含的训练样本数量。"k-shot"中的"k"代表着一个整数,表示每个类别的训练样本数量。例如,在一个5-shot分类任务中,每个类别将有5个训练样本。这意味着模型在学习阶段仅能够从每个类别的这几个样本中学习,而不是像传统的监督学习中那样从大量的训练数据中学习。"k-shot"任务通常用于评估模型在小样本情况下的泛化能力和学习能力。
三、逐行注释
import glob
import os
# 导入了glob和os模块,用于文件路径操作和文件匹配
mvtec_classes = ['carpet', 'grid', 'leather', 'tile', 'wood',
'bottle', 'cable', 'capsule', 'hazelnut', 'metal_nut', 'pill',
'screw', 'toothbrush', 'transistor', 'zipper']
# 定义了数据集中的类别列表
MVTEC2D_DIR = './datasets/mvtec_anomaly_detection'
# 指定了mvtec2d数据集的根目录
# 定义一个名为load_mvtec的函数,这个函数接受“类别、k_shot值、实验索引”三者作为参数
def load_mvtec(category, k_shot, experiment_indx):
def load_phase(root_path, gt_path):# 内部定义了一个load_phase函数,用于加载数据集中的一个阶段(训练或测试)
# 这个函数加载了训练和测试阶段的图像路径、地面真相路径、标签和缺陷类型
img_tot_paths = [] # 用于存储所有图像文件的路径
gt_tot_paths = [] # 用于存储所有地面真相文件的路径
tot_labels = [] # 用于存储所有图像的标签(0代表正常,1代表异常 )
tot_types = [] # 用于存储所有图像的缺陷类型
print(root_path)
print(gt_path)
defect_types = os.listdir(root_path) # 使用os.listdir(root_path)获取root_path下的所有子目录(即缺陷类型)
for defect_type in defect_types:
if defect_type == 'good': # 如果是正常图像(“good”):
img_paths = glob.glob(os.path.join(root_path, defect_type) + "/*.png") # 使用glob.glob获取该类型下所有.png文件的路径
print(img_paths)
img_tot_paths.extend(img_paths) # 将图像路径添加到img_tot_paths中
gt_tot_paths.extend([0] * len(img_paths)) # 将与图像数量相同的零值添加到“gt_tot_paths”中,表示为真值
tot_labels.extend([0] * len(img_paths)) # 将与图像数量相同的零值添加到tot_labels中,表示为正常标签
tot_types.extend(['good'] * len(img_paths)) # 将与图像数量相同的“good”字符串添加到tot_types中,表示正常类型
else: # 补充:len()是一个内置函数,用于返回对象的长度或元素个数
img_paths = glob.glob(os.path.join(root_path, defect_type) + "/*.png") # 当为缺陷样本的时候使用glob.glob获取该类型下所有.png文件的路径
gt_paths = [os.path.join(gt_path, defect_type, os.path.basename(s)[:-4] + '_mask.png') for s in
img_paths] # 构建对应的地面真相路径列表,使用os.path.basename(s)[:-4]截取图像文件名,并在末尾加上_mask.png
img_paths.sort() # 对图像路径进行排序
gt_paths.sort() # 对真值进行排序
img_tot_paths.extend(img_paths) # 将图像路径、真值路径、标签和类型添加到相应的列表中
gt_tot_paths.extend(gt_paths)
tot_labels.extend([1] * len(img_paths))
tot_types.extend([defect_type] * len(img_paths))
assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
# 确保图像路径列表和地面真相路径列表的长度相同,以确保每个图像都有对应的地面真相
return img_tot_paths, gt_tot_paths, tot_labels, tot_types
# 返回加载的图像路径、地面真相路径、标签和类型列表
assert category in mvtec_classes
assert k_shot in [0, 1, 5, 10]
assert experiment_indx in [0, 1, 2]
test_img_path = os.path.join(MVTEC2D_DIR, category, 'test')
train_img_path = os.path.join(MVTEC2D_DIR, category, 'train')
ground_truth_path = os.path.join(MVTEC2D_DIR, category, 'ground_truth')
if k_shot == 0:
training_indx = [] # 如果k_shot为0,则将训练索引设置为空列表,否则根据指定的实验索引和k_shot从文件中读取训练索引
else: # 否则就根据指定的实验索引和k_shot从文件中读取训练索引
seed_file = os.path.join('./datasets/seeds_mvtec', category, 'selected_samples_per_run.txt') # 这些索引保存在名为selected_samples_per_run.txt
with open(seed_file, 'r') as f:
files = f.readlines()
begin_str = f'{experiment_indx}-{k_shot}: ' # 文件路径基于指定的类别构建
training_indx = []
for line in files:
if line.count(begin_str) > 0:
strip_line = line[len(begin_str):-1]
index = strip_line.split(' ')
training_indx = index
train_img_tot_paths, train_gt_tot_paths, train_tot_labels, \
train_tot_types = load_phase(train_img_path, ground_truth_path) # 调用load_phase函数加载训练图像和测试图像的路径、地面真相路径、标签和缺陷类型。
test_img_tot_paths, test_gt_tot_paths, test_tot_labels, \
test_tot_types = load_phase(test_img_path, ground_truth_path)
selected_train_img_tot_paths = [] # 选择训练数据子集
selected_train_gt_tot_paths = [] # 遍历训练图像路径、地面真相路径、标签和缺陷类型,仅将在训练索引中的图像添加到选定的训练数据集中。这样做是为了将数据集限制在特定的子集上,以便进行有控制的训练。
selected_train_tot_labels = []
selected_train_tot_types = []
for img_path, gt_path, label, defect_type in zip(train_img_tot_paths, train_gt_tot_paths, train_tot_labels,
train_tot_types): # 遍历训练图像路径、地面真相路径、标签和缺陷类型,仅将在训练索引中的图像添加到选定的训练数据集中。这样做是为了将数据集限制在特定的子集上,以便进行有控制的训练。
if os.path.basename(img_path[:-4]) in training_indx:
selected_train_img_tot_paths.append(img_path)
selected_train_gt_tot_paths.append(gt_path)
selected_train_tot_labels.append(label)
selected_train_tot_types.append(defect_type)
return (selected_train_img_tot_paths, selected_train_gt_tot_paths, selected_train_tot_labels, selected_train_tot_types), \
(test_img_tot_paths, test_gt_tot_paths, test_tot_labels, test_tot_types)
# 返回选定的训练数据集和完整的测试数据集,以便后续的模型训练和测试使用