│ └─train
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ ├─Shepherds Purse
│ ├─Small-flowered Cranesbill
│ └─Sugar beet
├─dataset
│ ├─ init.py
│ └─ dataset.py
├─Model
│ └─mpvit.py
├─ test1.py
├─ test2.py
└─ train.py
mpvit.py:来自官方的代码中。
train.py:本文定义。
dataset.py:本文定义
test1.py:本文定义
test2.py:本文定义
==============================================================
数据集选用植物幼苗分类,总共12类。数据集连接如下:
链接:https://pan.baidu.com/s/1TOLSNj9JE4-MFhU0Yv8TNQ
提取码:syng
在工程的根目录新建data文件夹,获取数据集后,将trian和test解压放到data文件夹下面,如下图:
=================================================================
从官方的链接中找到mpvit.py文件,将其放入Model文件夹中。如图:
======================================================================
模型用到了timm库,如果没有需要安装,执行命令:
pip install timm
新建train_connext.py文件,导入所需要的包:
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData
from torch.autograd import Variable
from Model.mpvit import mpvit_tiny
from torchtoolbox.tools import mixup_data, mixup_criterion
from torchtoolbox.transform import Cutout
========================