【2023工业图像异常检测代码复现】SimpleNet: 简单的图像异常检测和定位网络

Paper:SimpleNet: A Simple Network for Image Anomaly Detection and Localization
Code:https://github.com/donaldrr/simplenet

文献解读:【2023工业图像异常检测文献】SimpleNet: 简单的图像异常检测和定位网络

1、Introduction

SimpleNet 使用预训练的主干进行常规特征提取,然后使用特征适配器将特征转移到目标域。然后,通过在适应的正态特征上加入高斯噪声,简单地生成异常特征,在这些特征上训练一个由几层 MLP 组成的简单判别器来判别异常。

在这里插入图片描述

2、Framework

项目结构与文件说明:

  • datasets/: 存放数据集的目录,项目中提到的MVTec AD数据集应该存储在此。
  • imgs/: 存储图像文件的目录,可能用于可视化或示例。
  • backbones.py: 包含了模型的主干网络定义,通常用于特征提取。
  • common.py: 可能包含了一些通用函数或类,供其他模块调用。
  • main.py: 项目入口文件,定义了如何运行整个程序。
  • metrics.py: 包含了评估指标的实现,用于量化模型性能。
  • resnet.py: 定义了ResNet模型的实现,可能用于作为模型的一部分。
  • run.sh: 一个Shell脚本,用于配置和运行训练过程。
  • simplenet.py: 核心模块,实现了SimpleNet框架的主要逻辑。
  • utils.py: 工具函数集合,可能包括数据处理、文件操作等辅助功能。
  • run.sh: 一个Shell脚本,用于配置和运行训练过程。

3、Environments

PyTorch 2.0.0
Python 3.8(ubuntu20.04)
Cuda 11.8

4、Datasets

数据集使用的是MVTecAD (MVTEC ANOMALY DETECTION DATASET)

数据来源https://www.mvtec.com/company/research/datasets/mvtec-ad,将其放在某个位置datapath。确保其遵循以下数据树结构:

mvtec
|-- bottle
|-----|----- ground_truth
|-----|----- test
|-----|--------|------ good
|-----|--------|------ broken_large
|-----|--------|------ ...
|-----|----- train
|-----|--------|------ good
|-- cable
|-- ...

总共有15个子数据集:bottle(瓶子),cable(电缆),capsule(胶囊),carpet(地毯),grid(网格),hazelnut(危险螺母), leather(皮革),metal_nut(金属螺母),pill(药丸 ),screw(螺丝 ),tile(瓷砖),toothbrush(牙刷 ),transistor(变压器),wood(木材),zipper(拉链)。

5、Requirements

pip install 包名 ,安装以下模块

torch>=1.12.1
torchvision>=0.13.1
numpy>=1.22.4
opencv-python>=4.5.1

没有完全提及。。。运行时,出现报错需要安装指定包,pip install 包名 ,就可解决了。。。

6、Parameters

run.sh 中参数解释:

datapath=/root/autodl-tmp/mvtec_anomaly_detection
数据集路径,指向 MVTec AD 数据集的根目录。

datasets=('screw' 'pill' 'capsule' 'carpet' 'grid' 'tile' 'wood' 'zipper' 'cable' 'toothbrush' 'transistor' 'metal_nut' 'bottle' 'hazelnut' 'leather')
需要处理的数据集列表,包含 MVTec AD 数据集中所有的子数据集。

dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '"${dataset}"; done))
生成一个包含所有数据集标志的数组,每个数据集对应一个 -d 参数。

主命令

python3 main.py
运行主脚本 main.py。

GPU 和随机种子

--gpu 0
使用第 0 块 GPU 进行计算。

--seed 0
设置随机种子为 0,以确保实验的可重复性。

日志和结果保存

--log_group simplenet_mvtec
日志组名称,用于标识不同的实验配置。

--log_project MVTecAD_Results
日志项目名称,用于组织不同实验的日志。

--results_path results
结果目录,用于保存实验结果。

--run_name run
实验运行的名称,用于区分不同的实验。

网络配置

net
指定使用网络配置部分。

-b wideresnet50
使用 WideResNet50 作为骨干网络。

-le layer2 -le layer3
使用骨干网络的第 2 层和第 3 层作为特征提取层。

--pretrain_embed_dimension 1536
预训练嵌入的维度为 1536。

--target_embed_dimension 1536
目标嵌入的维度为 1536。

--patchsize 3
图像块的大小为 3x3。

--meta_epochs 40
元训练的轮数为 40。

--embedding_size 256
嵌入向量的大小为 256。

--gan_epochs 4
GAN 训练的轮数为 4。

--noise_std 0.015
添加噪声的标准差为 0.015。

--dsc_hidden 1024
判别器隐藏层的大小为 1024。

--dsc_layers 2
判别器的层数为 2。

--dsc_margin .5
判别器的边界值为 0.5。

--pre_proj 1
是否使用预投影,1 表示使用。

数据集配置

dataset
指定使用数据集配置部分。

--batch_size 8
批次大小为 8。

--resize 329
将图像缩放到 329x329 大小。

--imagesize 288
将图像裁剪到 288x288 大小。

"${dataset_flags[@]}"
传递所有数据集标志。

mvtec
指定数据集类型为 MVTec AD。

$datapath
传递数据集路径。

报错①:在metrics.py文件中,出现对应库中没有该属性,可能原因是一些库已经更新,删除了对应的属性,所以需要做出对应修改

在这里插入图片描述

修改代码如下:

# binary_amaps = np.zeros_like(amaps, dtype=np.bool)
binary_amaps = np.zeros_like(amaps, dtype=np.bool_)

# df = df.append({"pro": np.mean(pros), "fpr": fpr, "threshold": th}, ignore_index=True)
df = pd.concat([df, pd.DataFrame([{"pro": np.mean(pros), "fpr": fpr, "threshold": th}])], ignore_index=True)

7、Results

没运行完。。。

在这里插入图片描述

### 关于SimpleNet复现方法 为了成功复现SimpleNet模型,建议按照官方GitHub仓库提供的指南进行操作。项目地址提供了详细的安装说明以及运行环境配置指导[^1]。 #### 安装依赖库 确保Python版本不低于3.6,并通过pip工具安装必要的软件包。通常情况下,`requirements.txt`文件会列出所有必需的第三方模块及其具体版本号: ```bash pip install -r requirements.txt ``` #### 数据准备 下载并解压所需的数据集,特别是用于验证算法效果的标准测试集合如MVTec AD。根据文档指示调整路径设置以便程序能够访问这些资源[^3]。 #### 训练过程 启动训练脚本之前,请确认已正确设置了CUDA可见设备编号其他超参数选项。对于初次使用者来说,默认配置往往已经过优化处理可以直接使用: ```python import tensorflow as tf from simplenet import SimpleNet if __name__ == "__main__": gpus = tf.config.experimental.list_physical_devices('GPU') if gpus: try: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) except RuntimeError as e: print(e) model = SimpleNet() model.train() # 开始训练流程 ``` 此段代码展示了如何初始化SimpleNet类实例对象并通过调用其成员函数完成整个学习周期[^4]。 #### 测试评估 当训练完成后,可以通过加载保存下来的权重文件来进行预测性能评测工作。注意这里涉及到的具体API接口定义可以在源码中找到对应的实现逻辑。 ```python model.load_weights('./checkpoints/best_model.hdf5') # 加载最优模型参数 test_images = ... # 获取待测样本列表 predictions = model.predict(test_images) # 执行前向传播获取输出概率分布 ``` 上述步骤概括了从零搭建至最终部署SimpleNet的整体思路技术要点。希望这些建议能帮助到想要深入了解该领域的朋友。
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值