参考图像增广
此次模型训练用到了CIFAR-10数据集
CIFAR-10数据集中物体的颜色和大小区别显著。
下面展示了CIFAR-10数据集中前32张训练图像。
import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import sys
sys.path.append("d2lzh_pytorch.py")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def show_images(imgs, num_rows, num_cols, scale=2):
figsize = (num_cols * scale, num_rows * scale)
_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
for i in range(num_rows):
for j in range(num_cols):
axes[i][j].imshow(imgs[i * num_cols + j])
axes[i][j].axes.get_xaxis().set_visible(False)
axes[i][j].axes.get_yaxis().set_visible(False)
return axes
all_imges = torchvision.datasets.CIFAR10(train=True, root="./data", download=False)
#若你的文件夹下没有该数据集,则改成 download = True
# all_imges的每一个元素都是(image, label)
show_images([all_imges[i][0] for i in range(32)], 4, 8, scale=0.8)
为了在预测时得到确定的结果,我们通常只将图像增广应用在**训练样本**上,而不在预测时使用含随机操作的图像增广。
在这里我们只使用最简单的随机左右翻转。 此外,我们使用ToTensor将小批量图像转成PyTorch需要的格式,即形状为(批量大小, 通道数, 高, 宽)、值域在0到1之间且类型为32位浮点数。
flip_aug = torchvision.transforms.Compose([
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor()])
no_aug = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()])
def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):
net = net.to(device)
print("training on ", device)
batch_count = 0
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
for X, y in train_iter:
X = X.to(device)
y = y.to(device)
y_hat = net(X)
l = loss(y_hat, y)
optimizer.zero_grad()
l.backward()
optimizer.step()
train_l_sum += l.cpu().item()
train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
n += y.shape[0]
batch_count += 1
test_acc = d2l.evaluate_accuracy(test_iter, net)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
num_workers = 0 if sys.platform.startswith('win32') else 4
def load_cifar10(is_train, augs, batch_size, root="./data"):
dataset = torchvision.datasets.CIFAR10(root=root, train=is_train, transform=augs, download=False)
return DataLoader(dataset, batch_size=batch_size, shuffle=is_train, num_workers=num_workers)
def train_with_data_aug(train_augs, test_augs, lr=0.001):
batch_size, net = 256, d2l.resnet18(10)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss = torch.nn.CrossEntropyLoss()
train_iter = load_cifar10(True, train_augs, batch_size)
test_iter = load_cifar10(False, test_augs, batch_size)
train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=10)
train_with_data_aug(flip_aug, no_aug)
训练结果如下,会比较慢
d2lzh_pytorch