一、程序结构展示
今天讲解的是datasets目录下的最后一个python文件—visa.py文件,这个文件也是数据预处理的一部分哦~
二、代码作用
实现了一个函数load_visa,用于加载 VISA 数据集的图像路径、标签和类型信息,并根据给定的类别、k_shot(每个类别的样本数)、实验索引来返回训练集和测试集的相关信息。
VISA(Visual Inspection with Semi-supervised Attention)数据集是一个用于视觉检验的数据集,旨在为缺陷检测和图像分割等任务提供训练和测试数据。该数据集由清华大学智能化工程研究所于2021年发布。
VISA数据集主要包含了工业生产中常见的不同类别的物体图像,这些物体包括了各种不同的产品、零件和组件。数据集中的图像被标注为正常或者带有特定类型的缺陷,以便用于训练和评估缺陷检测模型。
以下是VISA数据集的一些特点和内容:
-
多样性:VISA数据集涵盖了多种不同的物体和场景,包括了电路板、工业产品、食品和其他制造业中常见的物品。
-
标注信息:每张图像都配有对应的标签,用于指示图像中是否存在缺陷以及缺陷的类型。这些标签可以帮助模型进行监督学习和评估。
-
图像分辨率:图像的分辨率通常较高,以保证对缺陷进行准确的检测和分割。
-
用途:VISA数据集可以用于训练和评估图像分类、目标检测、物体识别、缺陷检测和图像分割等各种视觉任务的模型。
三、逐行注释
import glob # 用于文件通配符匹配
import os # 用于操作文件和目录
import random # 用于随机数的生成
visa_classes = ['candle', 'capsules', 'cashew', 'chewinggum',
'fryum', 'macaroni1', 'macaroni2',
'pcb1', 'pcb2', 'pcb3','pcb4', 'pipe_fryum'] # 定义了一个包含VISA数据集中所有的列表“visa_classes”
# VISA_DIR = '../datasets/VisA_pytorch/1cls'
VISA_DIR = './datasets/VisA_pytorch/1cls' # 设置visa数据集的路径
def load_visa(category, k_shot, experiment_indx): # 定义一个名为load_visa,这个函数啊接受三个参数:”类别“(category)、”每个类别的样本数目“(k-shot)、”实验索引“(index)
def load_phase(root_path, gt_path): # 定义一个内部函数(load_phase): 这个内部函数负责加载单个阶段(训练或者测试)的图像路径、标签和类型信息。接受两个参数:图像文件夹路径、标签文件夹路径
img_tot_paths = [] # 初始化了一些空列表
gt_tot_paths = []
tot_labels = []
tot_types = []
defect_types = os.listdir(root_path) # 遍历每个缺陷类型
for defect_type in defect_types: # 根据缺陷加载图像路径,并相应地加载标签路径
if defect_type == 'good': # 如果为正常图片,则直接加载图像路径,并将标签设为0
img_paths = glob.glob(os.path.join(root_path, defect_type) + "/*.JPG")
img_tot_paths.extend(img_paths)
gt_tot_paths.extend([0] * len(img_paths))
tot_labels.extend([0] * len(img_paths))
tot_types.extend(['good'] * len(img_paths))
else: # 否则加载图像和相应的标签,并将标签设为1
img_paths = glob.glob(os.path.join(root_path, defect_type) + "/*.JPG")
gt_paths = [os.path.join(gt_path, defect_type, os.path.basename(s)[:-4] + '.png') for s in
img_paths]
img_paths.sort()
gt_paths.sort()
img_tot_paths.extend(img_paths)
gt_tot_paths.extend(gt_paths)
tot_labels.extend([1] * len(img_paths))
tot_types.extend([defect_type] * len(img_paths))
# 函数确保图像和标签的数量一致,并返回加载的路径、标签和类型信息
assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
return img_tot_paths, gt_tot_paths, tot_labels, tot_types
assert category in visa_classes # 对传入的参数进行了断言验证,确保它们的合法性
assert k_shot in [0, 1, 5, 10]
assert experiment_indx in [0, 1, 2]
# 构造训练集和测试集路径
test_img_path = os.path.join(VISA_DIR, category, 'test')
train_img_path = os.path.join(VISA_DIR, category, 'train')
ground_truth_path = os.path.join(VISA_DIR, category, 'ground_truth')
# 调用load_phase函数分别加载它们的图像路径、标签路径和类型信息
train_img_tot_paths, train_gt_tot_paths, train_tot_labels, \
train_tot_types = load_phase(train_img_path, ground_truth_path)
test_img_tot_paths, test_gt_tot_paths, test_tot_labels, \
test_tot_types = load_phase(test_img_path, ground_truth_path)
selected_train_img_tot_paths = []
selected_train_gt_tot_paths = []
selected_train_tot_labels = []
selected_train_tot_types = []
# 选择训练样本
if k_shot > 0: # 如果”k_shot“大于0,从训练集中随机选择”k_shot“个样本,并保存相应的路径、标签和类型信息
full_index = range(len(train_img_tot_paths))
selected_index = random.sample(full_index, k_shot)
selected_train_img_tot_paths = [train_img_tot_paths[k] for k in selected_index]
selected_train_gt_tot_paths = [train_gt_tot_paths[k] for k in selected_index]
selected_train_tot_labels = [train_tot_labels[k] for k in selected_index]
selected_train_tot_types = [train_tot_types[k] for k in selected_index]
return (selected_train_img_tot_paths, selected_train_gt_tot_paths, selected_train_tot_labels, selected_train_tot_types), \
(test_img_tot_paths, test_gt_tot_paths, test_tot_labels, test_tot_types)
# 返回了选择的训练样本信息和完整的测试集信息