transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
Mixup实现,在train方法中。需要导入包:from torchtoolbox.tools import mixup_data, mixup_criterion
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
data, labels_a, labels_b, lam = mixup_data(data, target, alpha)
optimizer.zero_grad()
output = model(data)
loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)
loss.backward()
optimizer.step()
print_loss = loss.data.item()
===============================================================
使用tree命令,打印项目结构
MPViT_demo
├─data
│ ├─test
│ └─train
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ ├─Shepherds Purse
│ ├─Small-flowered Cranesbill
│ └─Sugar beet
├─dataset
│ ├─ init.py
│ └─ dataset.py
├─Model
│ └─mpvit.py
├─ test1.py
├─ test2.py
└─ train.py
mpvit.py:来自官方的代码中。
train.py:本文定义。
dataset.py:本文定义
test1.py:本文定义
test2.py:本文定义
==============================================================
数据集选用植物幼苗分类,总共12类。数据集连接如下:
链接:https://pan.baidu.com/s/1TOLSNj9JE4-MFhU0Yv8TNQ
提取码:syng
在工程的根目录新建data文件夹,获取数据集后,将trian和test解压放到data文件夹下面,如下图:
=================================================================
从官方的链接中找到mpvit.py文件,将其放入Model文件夹中。如图:
==========