WinClip非官方复现代码学习笔记7

一、程序结构展示

今天讲解的是datasets目录下的最后一个python文件—visa.py文件,这个文件也是数据预处理的一部分哦~

二、代码作用

实现了一个函数load_visa,用于加载 VISA 数据集的图像路径、标签和类型信息,并根据给定的类别、k_shot(每个类别的样本数)、实验索引来返回训练集和测试集的相关信息。

VISA(Visual Inspection with Semi-supervised Attention)数据集是一个用于视觉检验的数据集,旨在为缺陷检测和图像分割等任务提供训练和测试数据。该数据集由清华大学智能化工程研究所于2021年发布。

VISA数据集主要包含了工业生产中常见的不同类别的物体图像,这些物体包括了各种不同的产品、零件和组件。数据集中的图像被标注为正常或者带有特定类型的缺陷,以便用于训练和评估缺陷检测模型。

以下是VISA数据集的一些特点和内容:

  1. 多样性:VISA数据集涵盖了多种不同的物体和场景,包括了电路板、工业产品、食品和其他制造业中常见的物品。

  2. 标注信息:每张图像都配有对应的标签,用于指示图像中是否存在缺陷以及缺陷的类型。这些标签可以帮助模型进行监督学习和评估。

  3. 图像分辨率:图像的分辨率通常较高,以保证对缺陷进行准确的检测和分割。

  4. 用途:VISA数据集可以用于训练和评估图像分类、目标检测、物体识别、缺陷检测和图像分割等各种视觉任务的模型。

三、逐行注释

import glob  # 用于文件通配符匹配
import os  # 用于操作文件和目录
import random  # 用于随机数的生成

visa_classes = ['candle', 'capsules', 'cashew', 'chewinggum',
                   'fryum', 'macaroni1', 'macaroni2',
                'pcb1', 'pcb2', 'pcb3','pcb4', 'pipe_fryum']  # 定义了一个包含VISA数据集中所有的列表“visa_classes”

# VISA_DIR = '../datasets/VisA_pytorch/1cls'
VISA_DIR = './datasets/VisA_pytorch/1cls'  # 设置visa数据集的路径

def load_visa(category, k_shot, experiment_indx):  # 定义一个名为load_visa,这个函数啊接受三个参数:”类别“(category)、”每个类别的样本数目“(k-shot)、”实验索引“(index)
    def load_phase(root_path, gt_path):  # 定义一个内部函数(load_phase): 这个内部函数负责加载单个阶段(训练或者测试)的图像路径、标签和类型信息。接受两个参数:图像文件夹路径、标签文件夹路径
        img_tot_paths = []  # 初始化了一些空列表
        gt_tot_paths = []
        tot_labels = []
        tot_types = []

        defect_types = os.listdir(root_path)  # 遍历每个缺陷类型
        for defect_type in defect_types:  # 根据缺陷加载图像路径,并相应地加载标签路径
            if defect_type == 'good':  # 如果为正常图片,则直接加载图像路径,并将标签设为0
                img_paths = glob.glob(os.path.join(root_path, defect_type) + "/*.JPG")
                img_tot_paths.extend(img_paths)
                gt_tot_paths.extend([0] * len(img_paths))
                tot_labels.extend([0] * len(img_paths))
                tot_types.extend(['good'] * len(img_paths))
            else:  # 否则加载图像和相应的标签,并将标签设为1
                img_paths = glob.glob(os.path.join(root_path, defect_type) + "/*.JPG")
                gt_paths = [os.path.join(gt_path, defect_type, os.path.basename(s)[:-4] + '.png') for s in
                            img_paths]
                img_paths.sort()
                gt_paths.sort()
                img_tot_paths.extend(img_paths)
                gt_tot_paths.extend(gt_paths)
                tot_labels.extend([1] * len(img_paths))
                tot_types.extend([defect_type] * len(img_paths))
                # 函数确保图像和标签的数量一致,并返回加载的路径、标签和类型信息
        assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
        return img_tot_paths, gt_tot_paths, tot_labels, tot_types

    assert category in visa_classes  # 对传入的参数进行了断言验证,确保它们的合法性
    assert k_shot in [0, 1, 5, 10]
    assert experiment_indx in [0, 1, 2]
# 构造训练集和测试集路径
    test_img_path = os.path.join(VISA_DIR, category, 'test')
    train_img_path = os.path.join(VISA_DIR, category, 'train')
    ground_truth_path = os.path.join(VISA_DIR, category, 'ground_truth')

# 调用load_phase函数分别加载它们的图像路径、标签路径和类型信息
    train_img_tot_paths, train_gt_tot_paths, train_tot_labels, \
    train_tot_types = load_phase(train_img_path, ground_truth_path)

    test_img_tot_paths, test_gt_tot_paths, test_tot_labels, \
    test_tot_types = load_phase(test_img_path, ground_truth_path)

    selected_train_img_tot_paths = []
    selected_train_gt_tot_paths = []
    selected_train_tot_labels = []
    selected_train_tot_types = []
# 选择训练样本

    if k_shot > 0:  # 如果”k_shot“大于0,从训练集中随机选择”k_shot“个样本,并保存相应的路径、标签和类型信息
        full_index = range(len(train_img_tot_paths))

        selected_index = random.sample(full_index, k_shot)
        selected_train_img_tot_paths = [train_img_tot_paths[k] for k in selected_index]
        selected_train_gt_tot_paths = [train_gt_tot_paths[k] for k in selected_index]
        selected_train_tot_labels = [train_tot_labels[k] for k in selected_index]
        selected_train_tot_types = [train_tot_types[k] for k in selected_index]

    return (selected_train_img_tot_paths, selected_train_gt_tot_paths, selected_train_tot_labels, selected_train_tot_types), \
           (test_img_tot_paths, test_gt_tot_paths, test_tot_labels, test_tot_types)
#  返回了选择的训练样本信息和完整的测试集信息

  • 7
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值