数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
Cutout(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
Mixup实现,在train方法中。需要导入包:from torchtoolbox.tools import mixup_data, mixup_criterion
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
data, labels_a, labels_b, lam = mixup_data(data, target, alpha)
optimizer.zero_grad()
output = model(data)
loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)
loss.backward()
optimizer.step()
print_loss = loss.data.item()
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData
from torch.autograd import Variable
from torchvision.models import mobilenet_v2
from torchtoolbox.tools import mixup_data, mixup_criterion
from torchtoolbox.transform import Cutout
设置学习率、BatchSize、epoch等参数,判断环境中是否存在GPU,如果没有则使用CPU。建议使用GPU,CPU太慢了。由于mobilenetv2模型很小,4G显存的GPU就可以设置BatchSize为16。
设置全局参数
modellr = 1e-4
BATCH_SIZE = 16
EPOCHS = 300
DEVICE = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’)
数据处理比较简单,加入了Cutout、做了Resize和归一化。对于Normalize和std的值,这个一般是通用的设置,而且在后面的测试中要保持一致。
数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
Cutout(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
将数据集解压后放到data文件夹下面,如图:
然后我们在dataset文件夹下面新建 init.py和dataset.py,在datasets.py文件夹写入下面的代码:
coding:utf8
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
from sklearn.model_selection import train_test_split
Labels = {‘Black-grass’: 0, ‘Charlock’: 1, ‘Cleavers’: 2, ‘Common Chickweed’: 3,
‘Common wheat’: 4, ‘