MobileViT(PyTorch版)
图像识别是最适合初学者入门的项目,本文将介绍如何通过自制数据集进行图像分类训练;
想要实现这个项目,至少需要掌握python基础编程知识,我们这里选择的算法是MobileViT(苹果开源算法),采用了 CNN 和 Transformer 的混合架构,少量样本也能训练不错的效果;
GitHub官方(需翻墙):GitHub - apple/ml-cvnets at d38a116fe134a8cd5db18670764fdaafd39a5d4f
原文参考:15.1 MobileViT网络讲解_哔哩哔哩_bilibili
前期准备
百度网盘:https://pan.baidu.com/s/1aXJkukZcjRJJJoq7eI3dvg?pwd=9orm
外网实在太慢,我把项目的代码、数据集以及模型全部保存到了百度网盘里,223MB大小,请先下载;另外此代码是原作者从官方剥离出来的(为了简洁明了)
模型训练
花卉数据集在data/flower_photos/目录下,如要训练自己的数据集,替换原数据集即可
运行cmd,打开项目目录,运行训练文件
cd C:\Users\administrator\Desktop\MobileViT
python train.py
训练10个epoch准确率就达到了90%,训练好的模型保存在weights/目录下
模型测试
测试图片路径:predict.py文件img_path参数
运行测试文件
python predict.py
93%的概率是郁金香,测试成功
自制数据集
如需自制数据集,可运行crawler.py,进行网络图片爬虫
# 安装依赖库
pip install fake-useragent
# 爬虫程序
python crawler.py
输入关键字,回车,大概爬取1000多张图片后程序会自动停止
替换原花卉数据集,然后训练即可