Paper:SimpleNet: A Simple Network for Image Anomaly Detection and Localization
Code:https://github.com/donaldrr/simplenet
文献解读:【2023工业图像异常检测文献】SimpleNet: 简单的图像异常检测和定位网络
1、Introduction
SimpleNet 使用预训练的主干进行常规特征提取,然后使用特征适配器将特征转移到目标域。然后,通过在适应的正态特征上加入高斯噪声,简单地生成异常特征,在这些特征上训练一个由几层 MLP 组成的简单判别器来判别异常。
2、Framework
项目结构与文件说明:
- datasets/: 存放数据集的目录,项目中提到的MVTec AD数据集应该存储在此。
- imgs/: 存储图像文件的目录,可能用于可视化或示例。
- backbones.py: 包含了模型的主干网络定义,通常用于特征提取。
- common.py: 可能包含了一些通用函数或类,供其他模块调用。
- main.py: 项目入口文件,定义了如何运行整个程序。
- metrics.py: 包含了评估指标的实现,用于量化模型性能。
- resnet.py: 定义了ResNet模型的实现,可能用于作为模型的一部分。
- run.sh: 一个Shell脚本,用于配置和运行训练过程。
- simplenet.py: 核心模块,实现了SimpleNet框架的主要逻辑。
- utils.py: 工具函数集合,可能包括数据处理、文件操作等辅助功能。
- run.sh: 一个Shell脚本,用于配置和运行训练过程。
3、Environments
PyTorch 2.0.0
Python 3.8(ubuntu20.04)
Cuda 11.8
4、Datasets
数据集使用的是MVTecAD (MVTEC ANOMALY DETECTION DATASET)
数据来源https://www.mvtec.com/company/research/datasets/mvtec-ad,将其放在某个位置datapath。确保其遵循以下数据树结构:
mvtec
|-- bottle
|-----|----- ground_truth
|-----|----- test
|-----|--------|------ good
|-----|--------|------ broken_large
|-----|--------|------ ...
|-----|----- train
|-----|--------|------ good
|-- cable
|-- ...
总共有15个子数据集:bottle(瓶子),cable(电缆),capsule(胶囊),carpet(地毯),grid(网格),hazelnut(危险螺母), leather(皮革),metal_nut(金属螺母),pill(药丸 ),screw(螺丝 ),tile(瓷砖),toothbrush(牙刷 ),transistor(变压器),wood(木材),zipper(拉链)。
5、Requirements
pip install 包名 ,安装以下模块
torch>=1.12.1
torchvision>=0.13.1
numpy>=1.22.4
opencv-python>=4.5.1
没有完全提及。。。运行时,出现报错需要安装指定包,pip install 包名 ,就可解决了。。。
6、Parameters
run.sh 中参数解释:
datapath=/root/autodl-tmp/mvtec_anomaly_detection
数据集路径,指向 MVTec AD 数据集的根目录。
datasets=('screw' 'pill' 'capsule' 'carpet' 'grid' 'tile' 'wood' 'zipper' 'cable' 'toothbrush' 'transistor' 'metal_nut' 'bottle' 'hazelnut' 'leather')
需要处理的数据集列表,包含 MVTec AD 数据集中所有的子数据集。
dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '"${dataset}"; done))
生成一个包含所有数据集标志的数组,每个数据集对应一个 -d 参数。
主命令
python3 main.py
运行主脚本 main.py。
GPU 和随机种子
--gpu 0
使用第 0 块 GPU 进行计算。
--seed 0
设置随机种子为 0,以确保实验的可重复性。
日志和结果保存
--log_group simplenet_mvtec
日志组名称,用于标识不同的实验配置。
--log_project MVTecAD_Results
日志项目名称,用于组织不同实验的日志。
--results_path results
结果目录,用于保存实验结果。
--run_name run
实验运行的名称,用于区分不同的实验。
网络配置
net
指定使用网络配置部分。
-b wideresnet50
使用 WideResNet50 作为骨干网络。
-le layer2 -le layer3
使用骨干网络的第 2 层和第 3 层作为特征提取层。
--pretrain_embed_dimension 1536
预训练嵌入的维度为 1536。
--target_embed_dimension 1536
目标嵌入的维度为 1536。
--patchsize 3
图像块的大小为 3x3。
--meta_epochs 40
元训练的轮数为 40。
--embedding_size 256
嵌入向量的大小为 256。
--gan_epochs 4
GAN 训练的轮数为 4。
--noise_std 0.015
添加噪声的标准差为 0.015。
--dsc_hidden 1024
判别器隐藏层的大小为 1024。
--dsc_layers 2
判别器的层数为 2。
--dsc_margin .5
判别器的边界值为 0.5。
--pre_proj 1
是否使用预投影,1 表示使用。
数据集配置
dataset
指定使用数据集配置部分。
--batch_size 8
批次大小为 8。
--resize 329
将图像缩放到 329x329 大小。
--imagesize 288
将图像裁剪到 288x288 大小。
"${dataset_flags[@]}"
传递所有数据集标志。
mvtec
指定数据集类型为 MVTec AD。
$datapath
传递数据集路径。
报错①:在metrics.py文件中,出现对应库中没有该属性,可能原因是一些库已经更新,删除了对应的属性,所以需要做出对应修改
修改代码如下:
# binary_amaps = np.zeros_like(amaps, dtype=np.bool)
binary_amaps = np.zeros_like(amaps, dtype=np.bool_)
# df = df.append({"pro": np.mean(pros), "fpr": fpr, "threshold": th}, ignore_index=True)
df = pd.concat([df, pd.DataFrame([{"pro": np.mean(pros), "fpr": fpr, "threshold": th}])], ignore_index=True)
7、Results
没运行完。。。