在数据量较少时，可以采用 mixup 数据增强。具体实现代码如下：将同一批次中的两张图像按随机权重 λ ~ Beta(α, α) 线性叠加，损失也按相同权重对两组标签分别计算后加权求和。
# Mixup hyper-parameter: the mixing weight is drawn as lam ~ Beta(alpha, alpha).
# alpha <= 0 disables mixing (lam = 1), following the reference mixup recipe.
alpha = 1.0
loss_fun = nn.CrossEntropyLoss()
loss_fun = loss_fun.to(device)
for x, data in enumerate(tqdm(train_dataloader)):  # tqdm wraps the loader to show a progress bar
    imgs, targets = data
    imgs = imgs.to(device)
    targets = targets.to(device)
    batch_size = imgs.size(0)

    # Per-batch mixing coefficient. np.random.beta requires positive
    # parameters, so fall back to "no mixing" when alpha <= 0. Cast to a
    # plain float so the arithmetic with tensors below is unambiguous.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0

    # Pair every sample with a random partner from the same batch. Build the
    # permutation on the images' device so indexing stays on that device.
    index = torch.randperm(batch_size, device=imgs.device)
    images_a, images_b = imgs, imgs[index]
    labels_a, labels_b = targets, targets[index]
    mixed_images = lam * images_a + (1 - lam) * images_b

    outputs = model(mixed_images)
    _, preds = torch.max(outputs, 1)
    # Mixup loss: the same convex combination, applied to the two label sets.
    loss = lam * loss_fun(outputs, labels_a) + (1 - lam) * loss_fun(outputs, labels_b)

    # Mixup-weighted accuracy for this batch, as a fraction in [0, 1].
    # Divide by the real batch size (not a hard-coded 100) so the value is
    # correct for any batch size, including a smaller final batch.
    acc = (
        lam * preds.eq(labels_a).sum().float()
        + (1 - lam) * preds.eq(labels_b).sum().float()
    ).cpu() / batch_size
    # NOTE(review): no backward pass / optimizer step is visible in this
    # snippet; presumably loss.backward() and optimizer.step() follow — confirm.