mvtec2d提供了一系列用于处理2D图像的函数或方法,比如图像增强、预处理、特征提取 .
1.__all__
是一个特殊的变量,它用于定义模块的公共接口
__all__ = ['MVTec2D', 'mvtec2d_classes']
这表示MVTec2D
类和mvtec2d_classes
列表是这个模块的公共接口部分。当其他模块使用from my_module import *
语句导入时,只有MVTec2D
和mvtec2d_classes
会被导入,而模块中的其他名称则不会被导入。
2. mvtec2d_classes()函数:获取在mvtec2d中要检测的数据类
def mvtec2d_classes():
return ["bottle",
"cable", "capsule", "carpet", "grid", "hazelnut",
"leather", "metal_nut", "pill", "screw", "tile",
"toothbrush", "transistor", "wood", "zipper"]
3.class MVTec2D(Dataset)类 3.1 初始化 __init__
def __init__(self, data_path, learning_mode='centralized', phase='train',
data_transform=None, num_task=15):
self.data_path = data_path
self.learning_mode = learning_mode
self.phase = phase
self.class_name = mvtec2d_classes()
self.img_transform = data_transform[0]
self.mask_transform = data_transform[1]
assert set(self.class_name) <= set(mvtec2d_classes())
##断言来确保 self.class_name 中的所有类名都存在于 mvtec2d_classes() 返回的类名集合中。
self.num_task = num_task
self.class_in_task = []
self.imgs_list = []
self.labels_list = []
self.masks_list = []
self.task_ids_list = []
# mark each sample task id
self.sample_num_in_task = []
self.sample_indices_in_task = []
# load dataset
self.load_dataset()
self.allocate_task_data()
3.1.1 data_transform值的由来
img_transform = T.Compose([T.Resize((args['data_size'], args['data_size'])),重新裁剪尺寸值为(args['data_size'], args['data_size'])
T.CenterCrop(args['data_crop_size']), ##中心化裁剪
T.ToTensor(), ###转换为张量 转换后的张量会将图像的像素值从 [0, 255] 范围归一化到 [0, 1] 范围
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ##标准化 mean,std是RGB图像的均值和标准差
])
mask_transform = T.Compose([T.Resize(args['mask_size']),
T.CenterCrop(args['mask_crop_size']),
T.ToTensor(),
])
3.2__getitem__()函数
def __getitem__(self, idx):
img_src, label, mask, task_id = self.imgs_list[idx], self.labels_list[idx], self.masks_list[idx], self.task_ids_list[idx]
img = Image.open(img_src).convert('RGB')
img = self.img_transform(img)
if label == 0:
if isinstance(img, tuple):
mask = torch.zeros([1, img[0].shape[1], img[0].shape[2]])
else:
mask = torch.zeros([1, img.shape[1], img.shape[2]])
else:
mask = Image.open(mask)
mask = self.mask_transform(mask)
return {
'img': img, 'label': label, 'mask': mask, 'task_id': task_id, 'img_src': img_src,
}
-
__getitem__(self, idx)
: 这是一个特殊的方法,当实例被索引时会被调用,例如dataset[idx]
。 -
img = Image.open(img_src).convert('RGB')
: 使用PIL库(Python Imaging Library)打开图片,并将其转换为RGB格式 -
img = Image.open(img_src).convert('RGB')
: 使用PIL库(Python Imaging Library)打开图片,并将其转换为RGB格式。 -
img = self.img_transform(img)
: 对图片应用一个变换,这个变换通常包括调整大小、裁剪、归一化等操作,以便图片可以被模型处理。 -
if label == 0:
: 如果标签是0,这通常表示图片没有需要分割的目标或者没有异常。 -
if isinstance(img, tuple):
: 这里的检查是因为img_transform
可能返回一个元组,例如在图像增强技术中可能会返回多个版本的图像。 -
mask = torch.zeros(...)
: 如果标签是0,创建一个全零的掩码张量,这个张量的形状与图像相同。 -
else:
: 如果标签不是0,说明图片中有需要分割的目标或者有异常。 -
mask = Image.open(mask)
: 打开掩码图片。 -
mask = self.mask_transform(mask)
: 对掩码图片应用变换,将其转换为模型可以处理的格式。 -
return { ... }
: 返回一个包含图像、标签、掩码、任务ID和图片路径的字典。这个字典将被传递给数据加载器,然后用于训练或验证模型。
3.3__len__():
在Python中,__len__
方法是一个特殊的方法,它被定义在自定义类中,用于返回该类实例的“长度”。当对实例使用len()
函数时,就会调用这个方法。
def __len__(self):
return len(self.imgs_list)
3.4 load_dataset(self)
def load_dataset(self):
"""目的是加载数据集,并为每个样本分配一个任务 ID"""
# input x, label y, [0, 1], good is 0 and bad is 1, mask is ground truth
# train directory: only good cases
# test directory: bad and good cases
# ground truth directory: only bad case
# get classes in each task group
# If num_task is 15, each task contain each class
self.class_in_task = self.split_chunks(self.class_name, self.num_task)
# get data
for id, class_in_task in enumerate(self.class_in_task):
x, y, mask = [], [], []
for class_name in class_in_task:
img_dir = os.path.join(self.data_path, class_name, self.phase)##'D:\\DATA\\mvtec2d\\bottle\\train'
gt_dir = os.path.join(self.data_path, class_name, 'ground_truth') ##'D:\\DATA\\mvtec2d\\bottle\\ground_truth'
img_types = sorted(os.listdir(img_dir)) ##'good'
for img_type in img_types:
# load images
img_type_dir = os.path.join(img_dir, img_type)
if not os.path.isdir(img_type_dir):
continue
img_path_list = sorted([os.path.join(img_type_dir, f)
for f in os.listdir(img_type_dir)
if f.endswith('.png')])
x.extend(img_path_list)
if img_type == 'good':
y.extend([0] * len(img_path_list))
mask.extend([None] * len(img_path_list))
else:
y.extend([1] * len(img_path_list))
gt_type_dir = os.path.join(gt_dir, img_type)
img_name_list = [os.path.splitext(os.path.basename(f))[0] for f in img_path_list]
gt_path_list = [os.path.join(gt_type_dir, img_fname + '_mask.png')
for img_fname in img_name_list]
mask.extend(gt_path_list)
task_id = [id for i in range(len(x))]
self.sample_num_in_task.append(len(x))
self.imgs_list.extend(x)
self.labels_list.extend(y)
self.masks_list.extend(mask)
self.task_ids_list.extend(task_id)
extend()函数是在末尾加上数据。
os.listdir(x)函数是在x路径下列出所有的子目录并以列表的形式保存。
sorted()是对数据排序,原来的变量顺序不变。
os.path.join(x,y,z)是将x,y,z的路径合并成一个路径。
os.iddir(x) 判断x是不是目录.
os.path.splitext(os.path.basename(path))
-
os.path.basename(path)
:这个函数返回路径path
的最后一部分,即基础文件名。例如,对于路径/home/user/document.txt
,basename
会返回document.txt
。 -
os.path.splitext(path)
:这个函数将路径path
分割成两部分,即文件名和文件的扩展名。它返回一个元组(root, ext)
,其中root
是不包含扩展名的文件名,ext
是文件的扩展名,包括点号.
。例如,对于路径document.txt
,splitext
会返回一个元组('document', '.txt')
3.5 allocate_task_data(self)
def allocate_task_data(self):
"""作用是将数据样本按任务分配到不同的索引组中,并对每个任务中的索引进行随机打乱"""
start = 0
for num in self.sample_num_in_task:
end = start + num
indice = [i for i in range(start, end)]
random.shuffle(indice)
self.sample_indices_in_task.append(indice)
start = end
random.shuffle(x)将x的顺序打乱。
3.6 split_chunks(arr,m)
# split the arr into n chunks
@staticmethod
def split_chunks(arr, m):
"""将数组分割成 m 个子数组,每个子数组尽可能大小相等"""
n = int(math.ceil(len(arr) / float(m)))
return [arr[i:i + n] for i in range(0, len(arr), n)]