由于M3DM训练时,占用大量内存
这里选择只训练测试一个类别的数据
将https://www.mvtec.com/company/research/datasets/mvtec-3d-ad/downloads 上下载的cookie数据放入正确的文件夹名下!!
main.py
在main.py函数中是循环处理每个类别的数据
def run_3d_ads(args):
if args.dataset_type=='eyecandies':
classes = eyecandies_classes()
elif args.dataset_type=='mvtec3d':
classes = mvtec3d_classes()
.......
for cls in classes:
model = M3DM(args)
model.fit(cls)
image_rocaucs, pixel_rocaucs, au_pros = model.evaluate(cls)
image_rocaucs_df[cls.title()] = image_rocaucs_df['Method'].map(image_rocaucs)
pixel_rocaucs_df[cls.title()] = pixel_rocaucs_df['Method'].map(pixel_rocaucs)
au_pros_df[cls.title()] = au_pros_df['Method'].map(au_pros)
这里可将此处的其他类别注释掉,仅需训练测试cookie类别【或修改main.py中的循环语句,将cls定死maybe】
dataset.py
def mvtec3d_classes():
return [
# "bagel",
# "cable_gland",
# "carrot",
"cookie",
# "dowel",
# "foam",
# "peach",
# "potato",
# "rope",
# "tire",
]