基于Pytorch的热轧钢带表面缺陷分类挑战

1. 简介

实现一个完整的图像分类任务,大致需要五个步骤:

  1. 选择开源框架
    目前常用的深度学习框架主要有caffe、tesorflow、pytorch、mxnet、keras、paddlepaddle等。
  2. 构建并读取数据集
    构建或获取数据集,根据选择开源框架进行数据集读取。
  3. 训练模型搭建
    选择合适的网络模型、损失函数以及优化方式,完成整体的训练模型搭建。
  4. 训练并调试参数
    通过训练选定合适参数。
  5. 测试准确率
    在测试集上验证模型的最终性能。

本次实战选择pytorch开源框架,按照上述步骤实现一个基本的图像分类任务,并详细阐述其中的细节。

2. 数据集

2.1 数据集选取

表面缺陷检测是生产制造过程中必不可少的一步,尤其在带钢原料钢卷的轧制工艺过程中形成的表面缺陷是造成废、次品的主要原因,因此必须加强对带钢表面缺陷检测,通过缺陷检测,对于加强轧制工艺管理,剔除废品等都有重要的意义。

本次实战选择的数据库为由东北大学(NEU)发布的热轧钢带表面缺陷数据库,收集了热轧钢带的六种典型表面缺陷,即轧制氧化皮(RS),斑块(Pa),开裂(Cr),点蚀表面( PS),内含物(In)和划痕(Sc)。该数据库包括1,800个灰度图像:六种不同类型的典型表面缺陷,每一类缺陷包含300个样本。

数据库下载地址 NEU-CLS
提取码:175m

下面展示了6中缺陷样本的图像
在这里插入图片描述

2.2 数据集处理

首先需要将数据集分类处理成pytorch可以读取的形式,即是将缺陷图像按类别放置在不同的文件夹中。代码如下:

import os
import shutil

### 数据集根目录
root_dir = '数据集绝对地址'

### 数据集转移目录
shutil_dir = '处理数据集绝对地址'

all_images = os.listdir(root_dir)   #读取所有文件

images_classes= ['Cr', 'In', 'Pa', 'PS', 'RS', 'Sc']

for img in all_images:
    img_shutil_dir = os.path.join(shutil_dir, str(images_classes.index(img[0:2])))
    if not os.path.isdir(img_shutil_dir):
        os.mkdir(img_shutil_dir)
    shutil.copyfile(os.path.join(root_dir, img), os.path.join(img_shutil_dir, img))

运行后,数据集形式如下:每个文件夹中放置的是同类型的缺陷图像。
在这里插入图片描述

2.3 数据集加载

在这一步,需要实现数据集的加载和数据集划分,数据集加载运用ImageFolder()DataLoader(), 数据集划分运用random_spilt(),同时实现数据集加载时的数据增强。
数据增强介绍:数据增强
Pytorch常用图像处理和数据增强方法:Pytorch

import torch.utils.data as Data
import torchvision
import torchvision.transforms as transforms

train_transform = transforms.Compose([
        transforms.RandomResizedCrop(200),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])


dataset = torchvision.datasets.ImageFolder(shutil_dir, transform=train_transform)  #全部训练用例
'''
  按照8 :2 比例切分数据集为训练集和验证集
  
  train_dataset 为训练集,valid_dataset为验证集
'''
train_size = int(0.8*len(dataset))
valid_size = len(dataset)-train_size

train_dataset, valid_dataset = Data.random_split(dataset, [train_size, valid_size])

train_data = Data.DataLoader(train_dataset, batch_size=1, shuffle=True)
valid_data = Data.DataLoader(valid_dataset, batch_size=1, shuffle=False)

本例中的Normalize使用的参数为在ImageNet数据集上计算得到的方差和均值,实际使用时需要重新计算。参考链接:pytorch标准化

3. 训练模型

3.1 网络结构

常用的图像分类网络有VGG、ResNet、ResNext、DenseNet、Inception、ShuffleNet等,
参考链接:
图像分类:常用分类网络结构(附论文下载)
常用的分类网络

在本次实战中,主要选取了ResNet-50经典网络做为训练模型,

import torchvision
import torch.nn as nn

basic_model = torchvision.models.resnet50(pretrained=True)

class resnet_classifier(nn.Module):
    def __init__(self, classnumber=21):
        super(resnet_classifier, self).__init__()

        self.features = nn.Sequential(*list(basic_model.children())[:-1])
        fc_features = basic_model.fc.in_features
        self.classifier = nn.Linear(fc_features, classnumber, bias=False)
    def forward(self, x):
        features = self.features(x)
        features = torch.flatten(features, 1)
        classifier = self.classifier(features)
        return classifier

3.2 损失函数和优化方式

损失函数选择标准的交叉熵损失函数(详细介绍损失函数
优化方式选择Adam优化(详细介绍优化方式)

4. 训练及参数调试

在训练中,在网络结构中加载了预训练模型,可以加快训练速度和提升训练精度,初始学习率设置为1e-4, 在网络结构的特征层和分类层采取不同的学习率,分类层的学习率为特征层的10倍,学习率调整策略为指数衰减。(参考链接学习率调整

model = resnet_classifier()
train_params = [{'params':model.features.parameters(), 'lr':lr},
                {'params':model.classifier.parameters(),'lr':10*lr}]
optimizer = torch.optim.Adam(train_params)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9)

训练和测试代码:

    def training(self,epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_data)
        num_img_tr = len(self.train_data)
        for i, sample in enumerate(tbar):
            img, label = sample
            if self.cuda:
                img = img.cuda()
            self.optimizer.zero_grad()
            output = self.model(img)
            loss = self.Loss(output.cpu(), label)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            train_loss += loss.item()
            ###  记录训练过程  监控loss值
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.batch_size + img.data.shape[0]))
        print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
       	self.model.eval()
        tbar = tqdm(self.valid_data, desc='\r')
        test_loss = 0.0
        train_acc_sum = 0.0
        num_img_tr = len(self.valid_data) * self.batch_size
        for i, sample in enumerate(tbar):
            img, label = sample
            if self.cuda:
                img = img.cuda()
            with torch.no_grad():
                output = model(img)
            loss = self.loss(output.cpu(), label)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            # Add batch sample into evaluator
            train_acc_sum += (output.cpu().argmax(dim=1) == label).sum().cpu().item()
        ### 监控验证过程  记录正确率    
        accuracy = train_acc_sum / num_img_tr
        self.writer.add_scalar('test/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('accuracy', accuracy, epoch)

5. 测试

选用不同的模型和训练参数,对比训练精度,对模型或者超参数进行调整优化。

  • 8
    点赞
  • 64
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
PyTorch是一个流行的深度学习框架,可以用于训练和部署神经网络模型。在缺陷检测方面,PyTorch可以提供便利的工具和库,但也存在一些缺陷。 首先,PyTorch的学习曲线相对较陡。虽然PyTorch提供了丰富的文档和示例代码,但对于初学者来说,学习和理解PyTorch的概念和运作机制可能需要花费较多的时间和精力。 其次,PyTorch在大规模分布式训练方面存在一些局限性。虽然PyTorch支持分布式训练,但其在处理大量数据和大规模模型时,相比其他框架(如TensorFlow)可能表现出较差的性能。 此外,PyTorch在部署模型时相对复杂。虽然PyTorch提供了一些用于部署模型的工具和库,但相较于其他框架(如TensorFlow Serving),PyTorch在部署模型时需要更多的手动配置和管理。 另一个缺陷PyTorch相对较新,社区生态系统相对较小。与其他框架相比,PyTorch的社区贡献和支持相对较少,可能会导致在遇到问题时,很难找到解决方案或得到及时的帮助。 最后,PyTorch在一些特定任务上的性能可能不如其他框架。虽然PyTorch在图像分类和自然语言处理等任务上表现出色,但在一些特定的领域(如语音识别或推荐系统)中,其他框架(如TensorFlow或Keras)可能提供更好的性能和支持。 总的来说,尽管PyTorch深度学习任务中具有很多优势,但它也存在一些缺陷。只有在根据具体任务和需求权衡利弊后,才能做出选择。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值