前言
本篇博客主要是讲述如何,训练自己的分类网络。使用的是Pytorch 框架。所有代码是使用自己学习到的方法拼凑出来的。包含了我之前博客中提到的ResNet,MLP-Mixer,ConvNeXt,ConvMixer。
网络 | 博客链接 |
---|---|
ResNet | 链接 |
MLP-Mixer | 链接 |
ConvNeXt | 链接 |
ConvMixer | 链接 |
项目地址:https://github.com/jiantenggei/torch-classification
猫狗数据集链接:https://pan.baidu.com/s/1y7Vjy3-RhlEFvtH6RY8JMQ 提取码:90f0
一、如何配置
1.运行环境:
仓库中提供了yaml文件,conda虚拟环境导入,torch是GPU版本:
conda env create -f torch.yaml
配置好后检查GPU是否可用:
import torch
#返回当前设备索引
print(torch.cuda.current_device())
#返回GPU的数量
print(torch.cuda.device_count())
#cuda是否可用 True 表示可用
print(torch.cuda.is_available())
2.准备数据集
深度学习图像分类数据集,一般是以文件夹作为类别命,下面以猫狗分类为例。数据集拜访如下所示:
─dataset
├─train
│ └─cats
│ └─xxjj.jpg
│ └─dogs
│ └─xxx.jpg
├─test
│ └─cats
│ └─dogs
cats 和dogs 文件夹下直接是图片即可。
3.生成训练数据索引
且还需要在classes.txt文件 按顺序写下该顺序如下图:
classes = [“cats”, “dogs”] 数据集中文件夹的命名顺序保持一致。 运行txt_annotation.py 生成上图中的
cls_train.txt 和cls_text.txt。
import os
from os import getcwd
#---------------------------------------------------#
# 训练自己的数据集的时候一定要注意修改classes
# 修改成自己数据集所区分的种类
# 种类顺序需要和训练时用到的classes.txt一样
# 生成cls_train.txt, cls_text.txt
#---------------------------------------------------#
classes = ["cats", "dogs"]
sets = [