欢迎大家来到图像分类专栏,本篇基于pytorch完成一个多类别图像分类实战。
作者&编辑 | 郭冰洋
1 简介
实现一个完整的图像分类任务,大致需要分为五个步骤:
1、选择开源框架
目前常用的深度学习框架主要包括tensorflow、caffe、pytorch、mxnet等;
2、构建并读取数据集
根据任务需求搜集相关图像搭建相应的数据集,常见的方式包括:网络爬虫、实地拍摄、公共数据使用等。随后根据所选开源框架读取数据集。
3、框架搭建
选择合适的网络模型、损失函数以及优化方式,以完成整体框架的搭建
4、训练并调试参数
通过训练选定合适超参数
5、测试准确率
在测试集上验证模型的最终性能
本文利用Pytorch框架,按照上述结构实现一个基本的图像分类任务,并详细阐述其中的细节及注意事项。
2 数据集
本次实战选择的数据集为Kaggle竞赛中的细胞数据集,共包含9961个训练样本,2491个测试样本,可以分为嗜曙红细胞、淋巴细胞、单核细胞、中性白细胞4个类别,图片大小为320x240。
Pytorch中封装了相应的数据读取的类函数,通过调用torch.utils.data.Datasets函数,则可以实现读取功能。
__init__()模块用来定义相关的参数,__len__()模块用来获取训练样本个数,__getitem__()模块则用来获取每张具体的图片,在读取图片时其可以通过opencv库、PIL库等进行