ANOMALIB第二章:模型训练

ANomalib第二章:模型训练

这里我是基于源码进行的,如果还没安装好环境可以去看看第一章。
anomalib目前支持17种模型(有监督和无监督都有):patchcore、WinCLIP、EfficientAD等等。。。
模型的训练按数据集分成两种:仅使用正常数据训练、正常数据和缺陷数据一起训练。按检测模式也分为两种:分类和分割
这里先使用good,ng数据集和分类方式进行训练。后面再分享其他模式训练,应该都大差不差。
模型选择:patchcore
使用官方默认的超参数

数据集准备

我用的是工业的缺陷数据,大家可以自行寻找一些,比如说mvtec的开源数据并整理一下放到一个dataset目录下

--dataset
|--normal
|	|--normal001.jpg
|	|--...
|--defect
|	|--defect--1.jpg

训练代码

我比较偷懒,直接就在anomalib的开源项目里面新建了myProj文件夹并放入训练代码

--anomalib
|--myproj
|	|--train.py
|	|--dataset

导入数据集后,平台会自动分配训练集、测试集和验证集,不用我们自己分配。
代码如下:

import multiprocessing
from anomalib.data import Folder
from anomalib.models import Patchcore
from anomalib.engine import Engine


def main():
    # Create the datamodule
    datamodule = Folder(
        name="flat_enameled_wire",# 数据集名字,也是训练完的模型和过程日志存放的目录名。每次训练新的模型时记得要改一下,不然会覆盖原来的
        root="dataset",# 这里就是数据集的根目录
        normal_dir="normal",# 正常数据集
        abnormal_dir="defect",# 缺陷数据集
        task="classification",# 分类任务
        train_batch_size=1# 批处理数量
    )

    # Setup the datamodule
    datamodule.setup()

    # Create the model and engine
    model = Patchcore()
    engine = Engine(task="classification")

    # Train a Patchcore model on the given datamodule
    engine.train(datamodule=datamodule, model=model)


if __name__ == '__main__':
    multiprocessing.freeze_support()  # Optional, if your script might be frozen into an executable
    main()

说一下训练轮数默认就是1,就算改了参数还是1,因为1就是最优的训练轮数了。
训练完成后,平台会自动生成一个result目录,目录下会有模型目录然后是刚刚我们自定的name目录,直接打开latest(目录超级深,更套娃似的)。然后会在image目录里面看到我们模型在测试集上的推理结果,有原图和分类概率和heatmap,非常直观。

用训练好的模型进行推理

训练好的模型会保存在weights目录下,名字就是model.ckpt
推理代码如下

import multiprocessing
from anomalib.data import Folder
# Import the model and engine
from anomalib.models import Patchcore
from anomalib.engine import Engine


def main():
    # Create the datamodule
    datamodule = Folder(
        name="flat_enameled_wire_infer2",
        root="E:\\proj\\ai\\anomalib\\myProj\\infer_dataset2",
        normal_dir="normal",
        abnormal_dir="defect",
        task="classification",
    )

    # Setup the datamodule
    datamodule.setup()

    model = Patchcore()
    engine = Engine(task="classification")

    predictions = engine.predict(
        datamodule=datamodule,
        model=model,
        ckpt_path="E:\\proj\\ai\\anomalib\\myProj\\results\\Patchcore\\flat_enameled_wire_2\\latest\\weights"
                  "\\lightning\\model.ckpt",
    )


if __name__ == '__main__':
    multiprocessing.freeze_support()  # Optional, if your script might be frozen into an executable
    main()

和训练代码大差不差。如果是分割任务则会给出mask。分类任务没有,只有heatmap和概率。
推理结果

  • 5
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值