ANomalib第二章
ANomalib第二章:模型训练
这里我是基于源码进行的,如果还没安装好环境可以去看看第一章。
anomalib目前支持17种模型(有监督和无监督都有):patchcore、WinCLIP、EfficientAD等等。。。
模型的训练按数据集分成两种:仅使用正常数据训练、正常数据和缺陷数据一起训练。按检测模式也分为两种:分类和分割
这里先使用good,ng数据集和分类方式进行训练。后面再分享其他模式训练,应该都大差不差。
模型选择:patchcore
使用官方默认的超参数
数据集准备
我用的是工业的缺陷数据,大家可以自行寻找一些,比如说mvtec的开源数据并整理一下放到一个dataset目录下
--dataset
|--normal
| |--normal001.jpg
| |--...
|--defect
| |--defect--1.jpg
训练代码
我比较偷懒,直接就在anomalib的开源项目里面新建了myProj文件夹并放入训练代码
--anomalib
|--myproj
| |--train.py
| |--dataset
导入数据集后,平台会自动分配训练集、测试集和验证集,不用我们自己分配。
代码如下:
import multiprocessing
from anomalib.data import Folder
from anomalib.models import Patchcore
from anomalib.engine import Engine
def main():
# Create the datamodule
datamodule = Folder(
name="flat_enameled_wire",# 数据集名字,也是训练完的模型和过程日志存放的目录名。每次训练新的模型时记得要改一下,不然会覆盖原来的
root="dataset",# 这里就是数据集的根目录
normal_dir="normal",# 正常数据集
abnormal_dir="defect",# 缺陷数据集
task="classification",# 分类任务
train_batch_size=1# 批处理数量
)
# Setup the datamodule
datamodule.setup()
# Create the model and engine
model = Patchcore()
engine = Engine(task="classification")
# Train a Patchcore model on the given datamodule
engine.train(datamodule=datamodule, model=model)
if __name__ == '__main__':
multiprocessing.freeze_support() # Optional, if your script might be frozen into an executable
main()
说一下训练轮数默认就是1,就算改了参数还是1,因为1就是最优的训练轮数了。
训练完成后,平台会自动生成一个result目录,目录下会有模型目录然后是刚刚我们自定的name目录,直接打开latest(目录超级深,更套娃似的)。然后会在image目录里面看到我们模型在测试集上的推理结果,有原图和分类概率和heatmap,非常直观。
用训练好的模型进行推理
训练好的模型会保存在weights目录下,名字就是model.ckpt
推理代码如下
import multiprocessing
from anomalib.data import Folder
# Import the model and engine
from anomalib.models import Patchcore
from anomalib.engine import Engine
def main():
# Create the datamodule
datamodule = Folder(
name="flat_enameled_wire_infer2",
root="E:\\proj\\ai\\anomalib\\myProj\\infer_dataset2",
normal_dir="normal",
abnormal_dir="defect",
task="classification",
)
# Setup the datamodule
datamodule.setup()
model = Patchcore()
engine = Engine(task="classification")
predictions = engine.predict(
datamodule=datamodule,
model=model,
ckpt_path="E:\\proj\\ai\\anomalib\\myProj\\results\\Patchcore\\flat_enameled_wire_2\\latest\\weights"
"\\lightning\\model.ckpt",
)
if __name__ == '__main__':
multiprocessing.freeze_support() # Optional, if your script might be frozen into an executable
main()
和训练代码大差不差。如果是分割任务则会给出mask。分类任务没有,只有heatmap和概率。