mvtec2d.py代码讲解

        mvtec2d提供了一系列用于处理2D图像的函数或方法,比如图像增强、预处理、特征提取 .

1.__all__是一个特殊的变量,它用于定义模块的公共接口

__all__ = ['MVTec2D', 'mvtec2d_classes']

        这表示MVTec2D类和mvtec2d_classes列表是这个模块的公共接口部分。当其他模块使用from my_module import *语句导入时,只有MVTec2Dmvtec2d_classes会被导入,而模块中的其他名称则不会被导入。

2. mvtec2d_classes()函数:获取在mvtec2d中要检测的数据类

def mvtec2d_classes():
    return ["bottle",
    "cable", "capsule", "carpet", "grid", "hazelnut",
             "leather", "metal_nut", "pill", "screw", "tile",
            "toothbrush", "transistor", "wood", "zipper"]
3.class MVTec2D(Dataset)类
3.1 初始化  __init__
    def __init__(self, data_path, learning_mode='centralized', phase='train', 
                 data_transform=None, num_task=15):

        self.data_path = data_path
        self.learning_mode = learning_mode
        self.phase = phase
        self.class_name = mvtec2d_classes()
        self.img_transform = data_transform[0]
        self.mask_transform = data_transform[1] 
        assert set(self.class_name) <= set(mvtec2d_classes())
##断言来确保 self.class_name 中的所有类名都存在于 mvtec2d_classes() 返回的类名集合中。
        self.num_task = num_task 
        self.class_in_task = []

        self.imgs_list = []
        self.labels_list = []
        self.masks_list = []
        self.task_ids_list = []
        
        # mark each sample task id
        self.sample_num_in_task = []
        self.sample_indices_in_task = []

        # load dataset
        self.load_dataset()
        self.allocate_task_data()

3.1.1 data_transform值的由来

        img_transform = T.Compose([T.Resize((args['data_size'], args['data_size'])),重新裁剪尺寸值为(args['data_size'], args['data_size'])
                                    T.CenterCrop(args['data_crop_size']),  ##中心化裁剪
                                    T.ToTensor(),  ###转换为张量  转换后的张量会将图像的像素值从 [0, 255] 范围归一化到 [0, 1] 范围
                                    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])   ##标准化 mean,std是RGB图像的均值和标准差
                                    ])

        mask_transform = T.Compose([T.Resize(args['mask_size']),
                                    T.CenterCrop(args['mask_crop_size']),
                                    T.ToTensor(),
                                    ])

3.2__getitem__()函数

    def __getitem__(self, idx):
        img_src, label, mask, task_id = self.imgs_list[idx], self.labels_list[idx], self.masks_list[idx], self.task_ids_list[idx]

        img = Image.open(img_src).convert('RGB')
        img = self.img_transform(img)

        if label == 0:
            if isinstance(img, tuple):
                mask = torch.zeros([1, img[0].shape[1], img[0].shape[2]])
            else:
                mask = torch.zeros([1, img.shape[1], img.shape[2]])
        else:
            mask = Image.open(mask)
            mask = self.mask_transform(mask)

        return {
            'img': img, 'label': label, 'mask': mask, 'task_id': task_id, 'img_src': img_src,
        }
  • __getitem__(self, idx): 这是一个特殊的方法,当实例被索引时会被调用,例如dataset[idx]

  • img = Image.open(img_src).convert('RGB'): 使用PIL库(Python Imaging Library)打开图片,并将其转换为RGB格式

  • img = Image.open(img_src).convert('RGB'): 使用PIL库(Python Imaging Library)打开图片,并将其转换为RGB格式。

  • img = self.img_transform(img): 对图片应用一个变换,这个变换通常包括调整大小、裁剪、归一化等操作,以便图片可以被模型处理。

  • if label == 0:: 如果标签是0,这通常表示图片没有需要分割的目标或者没有异常。

  • if isinstance(img, tuple):: 这里的检查是因为img_transform可能返回一个元组,例如在图像增强技术中可能会返回多个版本的图像。

  • mask = torch.zeros(...): 如果标签是0,创建一个全零的掩码张量,这个张量的形状与图像相同。

  • else:: 如果标签不是0,说明图片中有需要分割的目标或者有异常。

  • mask = Image.open(mask): 打开掩码图片。

  • mask = self.mask_transform(mask): 对掩码图片应用变换,将其转换为模型可以处理的格式。

  • return { ... }: 返回一个包含图像、标签、掩码、任务ID和图片路径的字典。这个字典将被传递给数据加载器,然后用于训练或验证模型。

3.3__len__():

        在Python中,__len__方法是一个特殊的方法,它被定义在自定义类中,用于返回该类实例的“长度”。当对实例使用len()函数时,就会调用这个方法。

    def __len__(self):
        return len(self.imgs_list)

3.4 load_dataset(self)

    def load_dataset(self):
        """目的是加载数据集,并为每个样本分配一个任务 ID"""
        # input x, label y, [0, 1], good is 0 and bad is 1, mask is ground truth
        # train directory: only good cases
        # test directory: bad and good cases 
        # ground truth directory: only bad case

        # get classes in each task group
        # If num_task is 15, each task contain each class
        self.class_in_task = self.split_chunks(self.class_name, self.num_task)
        # get data
        for id, class_in_task in enumerate(self.class_in_task):
            x, y, mask = [], [], []
            for class_name in class_in_task:
                img_dir = os.path.join(self.data_path, class_name, self.phase)##'D:\\DATA\\mvtec2d\\bottle\\train'
                gt_dir = os.path.join(self.data_path, class_name, 'ground_truth') ##'D:\\DATA\\mvtec2d\\bottle\\ground_truth'

                img_types = sorted(os.listdir(img_dir))  ##'good'
                for img_type in img_types:

                    # load images
                    img_type_dir = os.path.join(img_dir, img_type)
                    if not os.path.isdir(img_type_dir):
                        continue
                    img_path_list = sorted([os.path.join(img_type_dir, f)
                                            for f in os.listdir(img_type_dir)
                                            if f.endswith('.png')])
                    x.extend(img_path_list)

                    if img_type == 'good':
                        y.extend([0] * len(img_path_list))
                        mask.extend([None] * len(img_path_list))
                    else:
                        y.extend([1] * len(img_path_list))
                        gt_type_dir = os.path.join(gt_dir, img_type)
                        img_name_list = [os.path.splitext(os.path.basename(f))[0] for f in img_path_list]
                        gt_path_list = [os.path.join(gt_type_dir, img_fname + '_mask.png')
                                        for img_fname in img_name_list]
                        mask.extend(gt_path_list)
            
            task_id = [id for i in range(len(x))]
            self.sample_num_in_task.append(len(x))

            self.imgs_list.extend(x)
            self.labels_list.extend(y)
            self.masks_list.extend(mask)
            self.task_ids_list.extend(task_id) 

extend()函数是在末尾加上数据。

os.listdir(x)函数是在x路径下列出所有的子目录并以列表的形式保存。

sorted()是对数据排序,原来的变量顺序不变。

os.path.join(x,y,z)是将x,y,z的路径合并成一个路径。

os.iddir(x) 判断x是不是目录.

os.path.splitext(os.path.basename(path))

  • os.path.basename(path):这个函数返回路径path的最后一部分,即基础文件名。例如,对于路径/home/user/document.txtbasename会返回document.txt

  • os.path.splitext(path):这个函数将路径path分割成两部分,即文件名和文件的扩展名。它返回一个元组(root, ext),其中root是不包含扩展名的文件名,ext是文件的扩展名,包括点号.。例如,对于路径document.txtsplitext会返回一个元组('document', '.txt')

3.5 allocate_task_data(self)

    def allocate_task_data(self):
        """作用是将数据样本按任务分配到不同的索引组中,并对每个任务中的索引进行随机打乱"""
        start = 0
        for num in self.sample_num_in_task:
            end = start + num
            indice = [i for i in range(start, end)]
            random.shuffle(indice)
            self.sample_indices_in_task.append(indice)
            start = end

random.shuffle(x)将x的顺序打乱。

3.6 split_chunks(arr,m)

    # split the arr into n chunks
    @staticmethod
    def split_chunks(arr, m):
        """将数组分割成 m 个子数组,每个子数组尽可能大小相等"""
        n = int(math.ceil(len(arr) / float(m)))
        return [arr[i:i + n] for i in range(0, len(arr), n)]

  • 26
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值