WinClip非官方复现代码学习笔记5

一、程序结构展示

下图是项目中datasets文件夹下面的结构,今天主要讲解的是mvtec.py文件

二、 代码作用

这段代码实现了一个用于加载 MVTEC Anomaly Detection 数据集的函数 load_mvtec,它能够根据指定的类别、k-shot 值和实验索引加载对应的图像数据和真值,并根据训练索引选定特定的训练数据集,以便进行有控制的模型训练。

在这里单独介绍一下“k-shot”值,“k-shot”是指在机器学习和元学习领域中的一个概念。它通常用于描述一个学习任务中每个类别所包含的训练样本数量。"k-shot"中的"k"代表着一个整数,表示每个类别的训练样本数量。例如,在一个5-shot分类任务中,每个类别将有5个训练样本。这意味着模型在学习阶段仅能够从每个类别的这几个样本中学习,而不是像传统的监督学习中那样从大量的训练数据中学习。"k-shot"任务通常用于评估模型在小样本情况下的泛化能力和学习能力。

三、逐行注释

import glob
import os
# 导入了glob和os模块,用于文件路径操作和文件匹配
mvtec_classes = ['carpet', 'grid', 'leather', 'tile', 'wood',
                   'bottle', 'cable', 'capsule', 'hazelnut', 'metal_nut', 'pill',
                   'screw', 'toothbrush', 'transistor', 'zipper']
# 定义了数据集中的类别列表


MVTEC2D_DIR = './datasets/mvtec_anomaly_detection'
# 指定了mvtec2d数据集的根目录
# 定义一个名为load_mvtec的函数,这个函数接受“类别、k_shot值、实验索引”三者作为参数

def load_mvtec(category, k_shot, experiment_indx):
    def load_phase(root_path, gt_path):# 内部定义了一个load_phase函数,用于加载数据集中的一个阶段(训练或测试)
        # 这个函数加载了训练和测试阶段的图像路径、地面真相路径、标签和缺陷类型
        img_tot_paths = []  # 用于存储所有图像文件的路径
        gt_tot_paths = []  # 用于存储所有地面真相文件的路径
        tot_labels = []  # 用于存储所有图像的标签(0代表正常,1代表异常 )
        tot_types = []  # 用于存储所有图像的缺陷类型
        print(root_path)
        print(gt_path)

        defect_types = os.listdir(root_path)  # 使用os.listdir(root_path)获取root_path下的所有子目录(即缺陷类型)

        for defect_type in defect_types:
            if defect_type == 'good':  # 如果是正常图像(“good”):
                img_paths = glob.glob(os.path.join(root_path, defect_type) + "/*.png")  # 使用glob.glob获取该类型下所有.png文件的路径
                print(img_paths)
                img_tot_paths.extend(img_paths)  # 将图像路径添加到img_tot_paths中
                gt_tot_paths.extend([0] * len(img_paths))  # 将与图像数量相同的零值添加到“gt_tot_paths”中,表示为真值
                tot_labels.extend([0] * len(img_paths))  # 将与图像数量相同的零值添加到tot_labels中,表示为正常标签
                tot_types.extend(['good'] * len(img_paths))  # 将与图像数量相同的“good”字符串添加到tot_types中,表示正常类型
            else:  # 补充:len()是一个内置函数,用于返回对象的长度或元素个数
                img_paths = glob.glob(os.path.join(root_path, defect_type) + "/*.png")  # 当为缺陷样本的时候使用glob.glob获取该类型下所有.png文件的路径
                gt_paths = [os.path.join(gt_path, defect_type, os.path.basename(s)[:-4] + '_mask.png') for s in
                            img_paths]  # 构建对应的地面真相路径列表,使用os.path.basename(s)[:-4]截取图像文件名,并在末尾加上_mask.png
                img_paths.sort()  # 对图像路径进行排序
                gt_paths.sort()  # 对真值进行排序
                img_tot_paths.extend(img_paths)  # 将图像路径、真值路径、标签和类型添加到相应的列表中
                gt_tot_paths.extend(gt_paths)
                tot_labels.extend([1] * len(img_paths))
                tot_types.extend([defect_type] * len(img_paths))

        assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
# 确保图像路径列表和地面真相路径列表的长度相同,以确保每个图像都有对应的地面真相
        return img_tot_paths, gt_tot_paths, tot_labels, tot_types
# 返回加载的图像路径、地面真相路径、标签和类型列表
    assert category in mvtec_classes
    assert k_shot in [0, 1, 5, 10]
    assert experiment_indx in [0, 1, 2]

    test_img_path = os.path.join(MVTEC2D_DIR, category, 'test')
    train_img_path = os.path.join(MVTEC2D_DIR, category, 'train')
    ground_truth_path = os.path.join(MVTEC2D_DIR, category, 'ground_truth')

    if k_shot == 0:
        training_indx = []  # 如果k_shot为0,则将训练索引设置为空列表,否则根据指定的实验索引和k_shot从文件中读取训练索引
    else:  # 否则就根据指定的实验索引和k_shot从文件中读取训练索引
        seed_file = os.path.join('./datasets/seeds_mvtec', category, 'selected_samples_per_run.txt')  # 这些索引保存在名为selected_samples_per_run.txt
        with open(seed_file, 'r') as f:
            files = f.readlines()
        begin_str = f'{experiment_indx}-{k_shot}: '  # 文件路径基于指定的类别构建

        training_indx = []
        for line in files:
            if line.count(begin_str) > 0:
                strip_line = line[len(begin_str):-1]
                index = strip_line.split(' ')
                training_indx = index

    train_img_tot_paths, train_gt_tot_paths, train_tot_labels, \
    train_tot_types = load_phase(train_img_path, ground_truth_path)  # 调用load_phase函数加载训练图像和测试图像的路径、地面真相路径、标签和缺陷类型。

    test_img_tot_paths, test_gt_tot_paths, test_tot_labels, \
    test_tot_types = load_phase(test_img_path, ground_truth_path)

    selected_train_img_tot_paths = []  # 选择训练数据子集
    selected_train_gt_tot_paths = []  # 遍历训练图像路径、地面真相路径、标签和缺陷类型,仅将在训练索引中的图像添加到选定的训练数据集中。这样做是为了将数据集限制在特定的子集上,以便进行有控制的训练。
    selected_train_tot_labels = []
    selected_train_tot_types = []

    for img_path, gt_path, label, defect_type in zip(train_img_tot_paths, train_gt_tot_paths, train_tot_labels,
                                                     train_tot_types):  # 遍历训练图像路径、地面真相路径、标签和缺陷类型,仅将在训练索引中的图像添加到选定的训练数据集中。这样做是为了将数据集限制在特定的子集上,以便进行有控制的训练。
        if os.path.basename(img_path[:-4]) in training_indx:
            selected_train_img_tot_paths.append(img_path)
            selected_train_gt_tot_paths.append(gt_path)
            selected_train_tot_labels.append(label)
            selected_train_tot_types.append(defect_type)

    return (selected_train_img_tot_paths, selected_train_gt_tot_paths, selected_train_tot_labels, selected_train_tot_types), \
           (test_img_tot_paths, test_gt_tot_paths, test_tot_labels, test_tot_types)
# 返回选定的训练数据集和完整的测试数据集,以便后续的模型训练和测试使用

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值