1. 简介
实现一个完整的图像分类任务,大致需要五个步骤:
- 选择开源框架
目前常用的深度学习框架主要有caffe、tesorflow、pytorch、mxnet、keras、paddlepaddle等。 - 构建并读取数据集
构建或获取数据集,根据选择开源框架进行数据集读取。 - 训练模型搭建
选择合适的网络模型、损失函数以及优化方式,完成整体的训练模型搭建。 - 训练并调试参数
通过训练选定合适参数。 - 测试准确率
在测试集上验证模型的最终性能。
本次实战选择pytorch开源框架,按照上述步骤实现一个基本的图像分类任务,并详细阐述其中的细节。
2. 数据集
2.1 数据集选取
表面缺陷检测是生产制造过程中必不可少的一步,尤其在带钢原料钢卷的轧制工艺过程中形成的表面缺陷是造成废、次品的主要原因,因此必须加强对带钢表面缺陷检测,通过缺陷检测,对于加强轧制工艺管理,剔除废品等都有重要的意义。
本次实战选择的数据库为由东北大学(NEU)发布的热轧钢带表面缺陷数据库,收集了热轧钢带的六种典型表面缺陷,即轧制氧化皮(RS),斑块(Pa),开裂(Cr),点蚀表面( PS),内含物(In)和划痕(Sc)。该数据库包括1,800个灰度图像:六种不同类型的典型表面缺陷,每一类缺陷包含300个样本。
数据库下载地址 NEU-CLS
提取码:175m
下面展示了6中缺陷样本的图像
2.2 数据集处理
首先需要将数据集分类处理成pytorch可以读取的形式,即是将缺陷图像按类别放置在不同的文件夹中。代码如下:
import os
import shutil
### 数据集根目录
root_dir = '数据集绝对地址'
### 数据集转移目录
shutil_dir = '处理数据集绝对地址'
all_images = os.listdir(root_dir) #读取所有文件
images_classes= ['Cr', 'In', 'Pa', 'PS', 'RS', 'Sc']
for img in all_images:
img_shutil_dir = os.path.join(shutil_dir, str(images_classes.index(img[0:2])))
if not os.path.isdir(img_shutil_dir):
os.mkdir(img_shutil_dir)
shutil.copyfile(os.path.join(root_dir, img), os.path.join(img_shutil_dir, img))
运行后,数据集形式如下:每个文件夹中放置的是同类型的缺陷图像。
2.3 数据集加载
在这一步,需要实现数据集的加载和数据集划分,数据集加载运用ImageFolder()
和DataLoader()
, 数据集划分运用random_spilt()
,同时实现数据集加载时的数据增强。
数据增强介绍:数据增强
Pytorch常用图像处理和数据增强方法:Pytorch
import torch.utils.data as Data
import torchvision
import torchvision.transforms as transforms
train_transform