工业缺陷检测项目实战(四)——基于HRNet的陶瓷缺陷检测

基于HRNet的陶瓷缺陷检测

1.原理:
参考大佬们的文章
HRNet: HRNet原理.

2.数据集准备和代码
数据下载链接:https://aistudio.baidu.com/aistudio/datasetdetail/32615
代码下载链接:https://gitee.com/wxyfmq123456/HRNet-Image-Classification?_from=gitee_search
3.原图二值化
这里数据集已经提供了二值化图像的png,我们需要用png图像进行训练。因为原图像特别不明显。
在这里插入图片描述
总共6个类别。
4. 参数配置
(1) 数据存放位置:
在这里插入图片描述
(2) 数据存放方式:
在这里插入图片描述
每个文件夹代表一个类型,里面的图片全部都是二值化图片(.png),原图(.jpg)可以删去或者备份在其他地方。

(3) 修改代码
在这里插入图片描述
打开cls_hrnet.py,修改

self.classifier = nn.Linear(2048, 1000)

self.classifier = nn.Linear(2048, 6)

也就是后面的参数为类别数

(4) 选择配置文件
在这里插入图片描述
我们选择第一个,即:
cls_hrnet_w18_sgd_lr5e-2_wd1e-4_bs32_x100.yaml

修改里面的参数,将

DATASET:
  DATASET: 'imagenet'
  DATA_FORMAT: 'jpg'
  ROOT: 'data/imagenet/'
  TEST_SET: 'val'
  TRAIN_SET: 'train'

修改为

DATASET:
  DATASET: 'data'
  DATA_FORMAT: 'png'
  ROOT: 'imagenet'
  TEST_SET: 'val'
  TRAIN_SET: 'train'

原因是我们设置的路径跟代码的不一样。
其他参数,比如迭代次数epoch,bath_size等,可以自己调参。

(5) 在train.py中增加如下代码

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    )
 
 
    #以下为增加的代码,上面几行是原有的代码
    #print(train_dataset.classes)  #根据分的文件夹的名字来确定的类别
    with open("class.txt","w") as f1:
        for classname in train_dataset.classes:
            f1.write(classname + "\n")
 
    #print(train_dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
    with open("classToIndex.txt", "w") as f2:
        for key, value in train_dataset.class_to_idx.items():
            f2.write(str(key) + " " + str(value) + '\n')
 
    #print(train_dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别

可以保持类别对应的index。

5.训练

python  tools/train.py --cfg  experiments/cls_hrnet_w18_sgd_lr5e-2_wd1e-4_bs32_x100.yaml

训练完在output文件夹里面有
权重偏置文件:final_state.pth.tar

6.测试
测试的图片,注意这里还是读取验证集vaild,所以为了测试一张图片,我们可以把验证集的图片变为一张,放在例如名为test的文件夹里面,路径如图:
在这里插入图片描述
在HRNet-Image-Classification-master\lib\core\function.py里面的def validate函数,添加

print('class:{}'.format(output.argmax(1)))

以打印识别的类别index。
运行:

python  tools/vaild.py --cfg  experiments/cls_hrnet_w18_sgd_lr5e-2_wd1e-4_bs32_x100.yaml

会打印出所属类别的index。

有Bug需要源代码的可以私信我。继续下一个项目的学习实战。

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

夏融化了这季节

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值