Paper:DiffusionAD: Norm-guided One-step Denoising Diffusion for Anomaly Detection
Code:https://github.com/HuiZhang0812/DiffusionAD?tab=readme-ov-file
文献解读:【2023工业图像异常检测文献】DiffusionAD: 基于规范引导单步去噪的扩散模型异常检测方法
1、Introduction
DiffusionAD由两个主要部分组成:重建子网络和分割子网络。重建子网络使用扩散模型将输入图像转换为噪声图像,然后预测噪声以重建无异常的图像。分割子网络接收输入图像和重建图像,预测像素级的异常分数。
- 重建子网络(Reconstruction Sub-network):
- 通过扩散模型实现,将重建过程重新定义为从噪声到规范的范式。
- 首先,输入图像被逐渐添加高斯噪声,直到变成噪声图像。
- 然后,使用U-Net架构的网络预测噪声,并使用预测的噪声重建图像。
- 分割子网络(Segmentation Sub-network):
- 使用U-Net架构,包括编码器、解码器和跳跃连接。
- 输入是输入图像和重建图像的通道级连接。
- 通过比较输入图像和重建图像之间的不一致性和共同性来预测像素级的异常分数。
- 规范引导单步去噪(Norm-guided One-step Denoising):
- 为了提高实时推理的速度,提出了单步去噪方法,通过一次网络推断直接预测噪声并重建图像。
- 引入规范引导范式,结合不同噪声尺度的优势,以处理不同类型的异常并提高重建质量。
2、Framework
项目文件说明:
-
args: 可能包含命令行参数的定义或配置,用来设置训练过程中的各种选项。
-
data: 存放数据集相关文件夹,例如下载的数据、预处理的数据等。
-
imgs: 可能存放一些示例图像或结果图像,用于展示算法的效果。
-
models: 包含模型相关的代码,可能包括模型架构的定义、权重加载/保存等功能。
-
README.md: 项目的介绍文档,包含了对 DiffusionAD 的概述、安装步骤、数据准备方法以及如何开始使用该项目的信息。
-
eval.py: 评估脚本,用于加载已训练好的模型并对测试数据进行评估,输出性能指标或生成可视化结果。
-
requirements.txt: 列出了运行项目所需的所有 Python 依赖包及版本信息。用户可以通过 pip install -r requirements.txt 命令一次性安装所有必要的库。
-
train.py: 训练脚本,定义了整个训练流程,包括数据加载、模型构建、损失函数定义、优化器选择以及训练循环等。
3、Environments
PyTorch 2.0.0
Python 3.8(ubuntu20.04)
Cuda 11.8
(扩散模型对硬件要求较高,建议使用4090以上显卡)
4、Datasets
数据集使用的是MVTecAD (MVTEC ANOMALY DETECTION DATASET)
数据来源https://www.mvtec.com/company/research/datasets/mvtec-ad,将其放在某个位置datapath。确保其遵循以下数据树结构:
MVTec-AD
|-- carpet
|-----|----- thresh
|-----|----- ground_truth
|-----|----- test
|-----|--------|------ good
|-----|--------|------ ...
|-----|----- train
|-----|--------|------ good
|-- cable
|-----|----- DISthresh
|-----|----- ground_truth
|-----|----- test
|-----|--------|------ good
|-----|--------|------ ...
|-----|----- train
|-----|--------|------ good
|-- ...
总共有15个子数据集:bottle(瓶子),cable(电缆),capsule(胶囊),carpet(地毯),grid(网格),hazelnut(危险螺母), leather(皮革),metal_nut(金属螺母),pill(药丸 ),screw(螺丝 ),tile(瓷砖),toothbrush(牙刷 ),transistor(变压器),wood(木材),zipper(拉链)。
数据集中的前景图像 thresh 或者 DISthresh 可以通过 Github 上下载。
其中,异常源图像集Describable Textures 数据集 (DTD) 也需要从Github上下载下来。
5、Requirements
安装所需依赖,执行:pip install -r requirements.txt
没有完全提及。。。运行时,出现报错需要安装指定包,pip install 包名 ,就可解决了。。。
6、Train and Evaluation of the Model
1)修改 args.json 中指定数据集路径 (MVTec-AD,VisA)、anomaly_source_path (DTD)
2)训练模型,运行:python train.py
3)评估和测试模型,运行:python eval.py